Skip to content

Commit

Permalink
Retry on transient failures during AWS IAM auth login attempts (#8727)
Browse files Browse the repository at this point in the history
* use retryer for failed aws auth attempts

* fixes from testing
  • Loading branch information
tyrannosaurus-becks authored Apr 13, 2020
1 parent 2de996b commit 78ed5f3
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/fullsailor/pkcs7"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/awsutil"
Expand All @@ -35,6 +37,10 @@ const (
iamAuthType = "iam"
ec2AuthType = "ec2"
ec2EntityType = "ec2_instance"

// Retry configuration
retryWaitMin = 500 * time.Millisecond
retryWaitMax = 30 * time.Second
)

func (b *backend) pathLogin() *framework.Path {
Expand Down Expand Up @@ -1198,6 +1204,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,

endpoint := "https://sts.amazonaws.com"

maxRetries := awsClient.DefaultRetryerMaxNumRetries
if config != nil {
if config.IAMServerIdHeaderValue != "" {
err = validateVaultHeaderValue(headers, parsedUrl, config.IAMServerIdHeaderValue)
Expand All @@ -1208,9 +1215,12 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
if config.STSEndpoint != "" {
endpoint = config.STSEndpoint
}
if config.MaxRetries >= 0 {
maxRetries = config.MaxRetries
}
}

callerID, err := submitCallerIdentityRequest(method, endpoint, parsedUrl, body, headers)
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
}
Expand Down Expand Up @@ -1555,18 +1565,31 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse,
return result, err
}

func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
func submitCallerIdentityRequest(ctx context.Context, maxRetries int, method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) {
// NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy
// The protection against this is that this method will only call the endpoint specified in the
// client config (defaulting to sts.amazonaws.com), so it would require a Vault admin to override
// the endpoint to talk to alternate web addresses
request := buildHttpRequest(method, endpoint, parsedUrl, body, headers)
retryableReq, err := retryablehttp.FromRequest(request)
if err != nil {
return nil, err
}
retryableReq = retryableReq.WithContext(ctx)
client := cleanhttp.DefaultClient()
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
retryingClient := &retryablehttp.Client{
HTTPClient: client,
RetryWaitMin: retryWaitMin,
RetryWaitMax: retryWaitMax,
RetryMax: maxRetries,
CheckRetry: retryablehttp.DefaultRetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
}

response, err := client.Do(request)
response, err := retryingClient.Do(retryableReq)
if err != nil {
return nil, errwrap.Wrapf("error making request: {{err}}", err)
}
Expand Down

0 comments on commit 78ed5f3

Please sign in to comment.