diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 7c31cfb3bb5a..5bbb74c832e1 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -16,11 +16,13 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + awsClient "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/iam" "github.com/fullsailor/pkcs7" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-retryablehttp" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/awsutil" @@ -35,6 +37,10 @@ const ( iamAuthType = "iam" ec2AuthType = "ec2" ec2EntityType = "ec2_instance" + + // Retry configuration + retryWaitMin = 500 * time.Millisecond + retryWaitMax = 30 * time.Second ) func (b *backend) pathLogin() *framework.Path { @@ -1199,6 +1205,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, endpoint := "https://sts.amazonaws.com" + maxRetries := awsClient.DefaultRetryerMaxNumRetries if config != nil { if config.IAMServerIdHeaderValue != "" { err = validateVaultHeaderValue(headers, parsedUrl, config.IAMServerIdHeaderValue) @@ -1209,9 +1216,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, if config.STSEndpoint != "" { endpoint = config.STSEndpoint } + if config.MaxRetries >= 0 { + maxRetries = config.MaxRetries + } } - callerID, err := submitCallerIdentityRequest(method, endpoint, parsedUrl, body, headers) + callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil } @@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, return result, err } -func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) { +func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) { // NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy // The protection against this is that this method will only call the endpoint specified in the // client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override // the endpoint to talk to alternate web addresses request := buildHttpRequest(method, endpoint, parsedUrl, body, headers) + retryableReq, err := retryablehttp.FromRequest(request) + if err != nil { + return nil, err + } + retryableReq = retryableReq.WithContext(ctx) client := cleanhttp.DefaultClient() client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } + retryingClient := &retryablehttp.Client{ + HTTPClient: client, + RetryWaitMin: retryWaitMin, + RetryWaitMax: retryWaitMax, + RetryMax: maxRetries, + CheckRetry: retryablehttp.DefaultRetryPolicy, + Backoff: retryablehttp.DefaultBackoff, + } - response, err := client.Do(request) + response, err := retryingClient.Do(retryableReq) if err != nil { return nil, errwrap.Wrapf("error making request: {{err}}", err) }