From 4cf2beae567ef66801205362dd0e825e5b22e855 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Thu, 3 Dec 2020 09:17:34 -0800 Subject: [PATCH 1/5] Add ExpiryWindow and ExpiryWindowJitterFrac to CredentialsCache --- aws/credential_cache.go | 74 +++++++++++++++- aws/credential_cache_bench_test.go | 8 +- aws/credential_cache_test.go | 137 ++++++++++++++++++++++------- 3 files changed, 176 insertions(+), 43 deletions(-) diff --git a/aws/credential_cache.go b/aws/credential_cache.go index 650470d1986..48833e6e0e0 100644 --- a/aws/credential_cache.go +++ b/aws/credential_cache.go @@ -3,17 +3,74 @@ package aws import ( "context" "sync/atomic" + "time" + sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" "github.com/aws/aws-sdk-go-v2/internal/sync/singleflight" ) +// CredentialsCacheOptions are the options +type CredentialsCacheOptions struct { + // Provider is the CredentialProvider implementation to be wrapped by the CredentialCache. + Provider CredentialsProvider + + // ExpiryWindow will allow the credentials to trigger refreshing prior to + // the credentials actually expiring. This is beneficial so race conditions + // with expiring credentials do not cause request to fail unexpectedly + // due to ExpiredTokenException exceptions. + // + // An ExpiryWindow of 10s would cause calls to IsExpired() to return true + // 10 seconds before the credentials are actually expired. + // + // If ExpiryWindow is 0 or less it will be ignored. + ExpiryWindow time.Duration + + // ExpiryWindowJitterFrac provides a mechanism for randomizing the expiration of credentials + // within the configured ExpiryWindow by a random percentage. Valid values are between 0.0 and 1.0. + // + // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac is 0.5 then credentials will be set to + // expire between 30 to 60 seconds prior to their actual expiration time. + // + // If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored. + // If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window. + // If ExpiryWindowJitterFrac < 0 the value will be treated as 0. + // If ExpiryWindowJitterFrac > 1 the value will be treated as 1. + ExpiryWindowJitterFrac float64 +} + // CredentialsCache provides caching and concurrency safe credentials retrieval // via the provider's retrieve method. type CredentialsCache struct { - Provider CredentialsProvider + options CredentialsCacheOptions + creds atomic.Value + sf singleflight.Group +} + +// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider is expected to not be nil. A variadic +// list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for +// configuration of credential expiry window and jitter. +func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache { + options := CredentialsCacheOptions{ + Provider: provider, + } - creds atomic.Value - sf singleflight.Group + for _, fn := range optFns { + fn(&options) + } + + if options.ExpiryWindow < 0 { + options.ExpiryWindow = 0 + } + + if options.ExpiryWindowJitterFrac < 0 { + options.ExpiryWindowJitterFrac = 0 + } else if options.ExpiryWindowJitterFrac > 1 { + options.ExpiryWindowJitterFrac = 1 + } + + return &CredentialsCache{ + options: options, + } } // Retrieve returns the credentials. If the credentials have already been @@ -41,8 +98,17 @@ func (p *CredentialsCache) singleRetrieve() (interface{}, error) { return *creds, nil } - creds, err := p.Provider.Retrieve(context.TODO()) + creds, err := p.options.Provider.Retrieve(context.TODO()) if err == nil { + if creds.CanExpire { + randFloat64, err := sdkrand.CryptoRandFloat64() + if err != nil { + return Credentials{}, err + } + jitter := time.Duration(randFloat64 * p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow)) + creds.Expires = creds.Expires.Add(-(p.options.ExpiryWindow - jitter)) + } + p.creds.Store(&creds) } diff --git a/aws/credential_cache_bench_test.go b/aws/credential_cache_bench_test.go index 1aa562c0623..b8bd09d0075 100644 --- a/aws/credential_cache_bench_test.go +++ b/aws/credential_cache_bench_test.go @@ -21,9 +21,7 @@ func BenchmarkCredentialsCache_Retrieve(b *testing.B) { cases := []int{1, 10, 100, 500, 1000, 10000} for _, c := range cases { b.Run(strconv.Itoa(c), func(b *testing.B) { - p := CredentialsCache{ - Provider: provider, - } + p := NewCredentialsCache(provider) var wg sync.WaitGroup wg.Add(c) for i := 0; i < c; i++ { @@ -59,9 +57,7 @@ func BenchmarkCredentialsCache_Retrieve_Invalidate(b *testing.B) { for _, expRate := range expRates { for _, c := range cases { b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) { - p := CredentialsCache{ - Provider: provider, - } + p := NewCredentialsCache(provider) var wg sync.WaitGroup wg.Add(c) for i := 0; i < c; i++ { diff --git a/aws/credential_cache_test.go b/aws/credential_cache_test.go index e09106e9258..b9d350c0fc5 100644 --- a/aws/credential_cache_test.go +++ b/aws/credential_cache_test.go @@ -42,15 +42,13 @@ func TestCredentialsCache_Cache(t *testing.T) { } var called bool - p := &CredentialsCache{ - Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { - if called { - t.Fatalf("expect provider.Retrieve to only be called once") - } - called = true - return expect, nil - }), - } + p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { + if called { + t.Fatalf("expect provider.Retrieve to only be called once") + } + called = true + return expect, nil + })) for i := 0; i < 2; i++ { creds, err := p.Retrieve(context.Background()) @@ -108,12 +106,10 @@ func TestCredentialsCache_Expires(t *testing.T) { for _, c := range cases { var called int - p := &CredentialsCache{ - Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { - called++ - return c.Creds(), nil - }), - } + p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { + called++ + return c.Creds(), nil + })) p.Retrieve(context.Background()) p.Retrieve(context.Background()) @@ -131,13 +127,92 @@ func TestCredentialsCache_Expires(t *testing.T) { } } -func TestCredentialsCache_Error(t *testing.T) { - p := &CredentialsCache{ - Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { - return Credentials{}, fmt.Errorf("failed") - }), +func TestCredentialsCache_ExpireTime(t *testing.T) { + orig := sdk.NowTime + defer func() { sdk.NowTime = orig }() + var mockTime time.Time + sdk.NowTime = func() time.Time { return mockTime } + + cases := map[string]struct { + ExpireTime time.Time + ExpiryWindow time.Duration + JitterFrac float64 + Validate func(t *testing.T, v time.Time) + }{ + "no expire window": { + Validate: func(t *testing.T, v time.Time) { + t.Helper() + if e, a := mockTime, v; !e.Equal(a) { + t.Errorf("expect %v, got %v", e, a) + } + }, + }, + "expire window": { + ExpireTime: mockTime.Add(100), + ExpiryWindow: 50, + Validate: func(t *testing.T, v time.Time) { + t.Helper() + if e, a := mockTime.Add(50), v; !e.Equal(a) { + t.Errorf("expect %v, got %v", e, a) + } + }, + }, + "expire window with jitter": { + ExpireTime: mockTime.Add(100), + JitterFrac: 0.5, + ExpiryWindow: 50, + Validate: func(t *testing.T, v time.Time) { + t.Helper() + max := mockTime.Add(75) + min := mockTime.Add(50) + if v.Before(min) { + t.Errorf("expect %v to be before %s", v, min) + } + if v.After(max) { + t.Errorf("expect %v to be after %s", v, max) + } + }, + }, + "no expire window with jitter": { + ExpireTime: mockTime, + JitterFrac: 0.5, + Validate: func(t *testing.T, v time.Time) { + t.Helper() + if e, a := mockTime, v; !e.Equal(a) { + t.Errorf("expect %v, got %v", e, a) + } + }, + }, } + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { + return Credentials{ + AccessKeyID: "accessKey", + SecretAccessKey: "secretKey", + CanExpire: true, + Expires: tt.ExpireTime, + }, nil + }), func(options *CredentialsCacheOptions) { + options.ExpiryWindow = tt.ExpiryWindow + options.ExpiryWindowJitterFrac = tt.JitterFrac + }) + + credentials, err := p.Retrieve(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + tt.Validate(t, credentials.Expires) + }) + } +} + +func TestCredentialsCache_Error(t *testing.T) { + p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { + return Credentials{}, fmt.Errorf("failed") + })) + creds, err := p.Retrieve(context.Background()) if err == nil { t.Fatalf("expect error, not none") @@ -156,16 +231,14 @@ func TestCredentialsCache_Race(t *testing.T) { SecretAccessKey: "secret", } var called bool - p := &CredentialsCache{ - Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { - time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) - if called { - t.Fatalf("expect provider.Retrieve only called once") - } - called = true - return expect, nil - }), - } + p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) { + time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) + if called { + t.Fatalf("expect provider.Retrieve only called once") + } + called = true + return expect, nil + })) var wg sync.WaitGroup wg.Add(100) @@ -206,9 +279,7 @@ func TestCredentialsCache_RetrieveConcurrent(t *testing.T) { stub := &stubConcurrentProvider{ done: make(chan struct{}), } - provider := CredentialsCache{ - Provider: stub, - } + provider := NewCredentialsCache(stub) var wg sync.WaitGroup wg.Add(2) From c7250daedb0f2379726767bf099c9bcba082d980 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Thu, 3 Dec 2020 09:18:38 -0800 Subject: [PATCH 2/5] Remove ExpiryWindow from individual providers --- credentials/ec2rolecreds/provider.go | 13 +------------ credentials/endpointcreds/provider.go | 14 +------------- credentials/processcreds/provider.go | 14 +------------- credentials/stscreds/assume_role_provider.go | 13 +------------ credentials/stscreds/web_identity_provider.go | 14 +------------- 5 files changed, 5 insertions(+), 63 deletions(-) diff --git a/credentials/ec2rolecreds/provider.go b/credentials/ec2rolecreds/provider.go index b2f20d950c5..df80a510bb8 100644 --- a/credentials/ec2rolecreds/provider.go +++ b/credentials/ec2rolecreds/provider.go @@ -46,17 +46,6 @@ type Options struct { // // If nil, the provider will default to the ec2imds client. Client GetMetadataAPIClient - - // ExpiryWindow will allow the credentials to trigger refreshing prior to - // the credentials actually expiring. This is beneficial so race conditions - // with expiring credentials do not cause request to fail unexpectedly - // due to ExpiredTokenException exceptions. - // - // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true - // 10 seconds before the credentials are actually expired. - // - // If ExpiryWindow is 0 or less it will be ignored. - ExpiryWindow time.Duration } // New returns an initialized Provider value configured to retrieve @@ -102,7 +91,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { Source: ProviderName, CanExpire: true, - Expires: roleCreds.Expiration.Add(-p.options.ExpiryWindow), + Expires: roleCreds.Expiration, } return creds, nil diff --git a/credentials/endpointcreds/provider.go b/credentials/endpointcreds/provider.go index 39819ec3384..39454bf86c2 100644 --- a/credentials/endpointcreds/provider.go +++ b/credentials/endpointcreds/provider.go @@ -33,7 +33,6 @@ import ( "context" "fmt" "net/http" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client" @@ -65,17 +64,6 @@ type HTTPClient interface { // Options is structure of configurable options for Provider type Options struct { - // ExpiryWindow will allow the credentials to trigger refreshing prior to - // the credentials actually expiring. This is beneficial so race conditions - // with expiring credentials do not cause request to fail unexpectedly - // due to ExpiredTokenException exceptions. - // - // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true - // 10 seconds before the credentials are actually expired. - // - // If ExpiryWindow is 0 or less it will be ignored. - ExpiryWindow time.Duration - // Endpoint to retrieve credentials from. Required Endpoint string @@ -134,7 +122,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { if resp.Expiration != nil { creds.CanExpire = true - creds.Expires = resp.Expiration.Add(-p.options.ExpiryWindow) + creds.Expires = *resp.Expiration } return creds, nil diff --git a/credentials/processcreds/provider.go b/credentials/processcreds/provider.go index 0bac613db37..3921da34cd7 100644 --- a/credentials/processcreds/provider.go +++ b/credentials/processcreds/provider.go @@ -55,18 +55,6 @@ type Provider struct { // Options is the configuration options for configuring the Provider. type Options struct { - // ExpiryWindow will allow the credentials to trigger refreshing prior to - // the credentials actually expiring. This is beneficial so race conditions - // with expiring credentials do not cause request to fail unexpectedly - // due to ExpiredTokenException exceptions. - // - // For example, an ExpiryWindow of 10s would cause calls to the - // Credentials.IsExpired() method to return true 10 seconds before the - // credentials would of actually expired. - // - // If ExpiryWindow is 0 or less, it will be ignored. - ExpiryWindow time.Duration - // Timeout limits the time a process can run. Timeout time.Duration } @@ -213,7 +201,7 @@ func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { // Handle expiration if resp.Expiration != nil { creds.CanExpire = true - creds.Expires = (*resp.Expiration).Add(-p.options.ExpiryWindow) + creds.Expires = *resp.Expiration } return creds, nil diff --git a/credentials/stscreds/assume_role_provider.go b/credentials/stscreds/assume_role_provider.go index 1e047093972..e79fa66e414 100644 --- a/credentials/stscreds/assume_role_provider.go +++ b/credentials/stscreds/assume_role_provider.go @@ -215,17 +215,6 @@ type AssumeRoleOptions struct { // If both TokenCode and TokenProvider is set, TokenProvider will be used and // TokenCode is ignored. TokenProvider func() (string, error) - - // ExpiryWindow will allow the credentials to trigger refreshing prior to - // the credentials actually expiring. This is beneficial so race conditions - // with expiring credentials do not cause request to fail unexpectedly - // due to ExpiredTokenException exceptions. - // - // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true - // 10 seconds before the credentials are actually expired. - // - // If ExpiryWindow is 0 or less it will be ignored. - ExpiryWindow time.Duration } // NewAssumeRoleProvider constructs and returns a credentials provider that @@ -291,6 +280,6 @@ func (p *AssumeRoleProvider) Retrieve(ctx context.Context) (aws.Credentials, err Source: ProviderName, CanExpire: true, - Expires: resp.Credentials.Expiration.Add(-p.options.ExpiryWindow), + Expires: *resp.Credentials.Expiration, }, nil } diff --git a/credentials/stscreds/web_identity_provider.go b/credentials/stscreds/web_identity_provider.go index 9e3e5e8adce..7854a3228c0 100644 --- a/credentials/stscreds/web_identity_provider.go +++ b/credentials/stscreds/web_identity_provider.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "strconv" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/retry" @@ -46,17 +45,6 @@ type WebIdentityRoleOptions struct { // Session name, if you wish to uniquely identify this session. RoleSessionName string - // ExpiryWindow will allow the credentials to trigger refreshing prior to - // the credentials actually expiring. This is beneficial so race conditions - // with expiring credentials do not cause request to fail unexpectedly - // due to ExpiredTokenException exceptions. - // - // So a ExpiryWindow of 10s would cause calls to IsExpired() to return true - // 10 seconds before the credentials are actually expired. - // - // If ExpiryWindow is 0 or less it will be ignored. - ExpiryWindow time.Duration - // The Amazon Resource Names (ARNs) of the IAM managed policies that you // want to use as managed session policies. The policies must exist in the // same account as the role. @@ -133,7 +121,7 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials SessionToken: aws.ToString(resp.Credentials.SessionToken), Source: WebIdentityProviderName, CanExpire: true, - Expires: resp.Credentials.Expiration.Add(-p.options.ExpiryWindow), + Expires: *resp.Credentials.Expiration, } return value, nil } From 88a44f8c3493dc8d0665a247a1b24cd5e2e954b6 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Thu, 3 Dec 2020 09:18:56 -0800 Subject: [PATCH 3/5] Document internal random Float64 boundaries --- internal/rand/rand.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/rand/rand.go b/internal/rand/rand.go index ad3caf6b658..9791ea590b5 100644 --- a/internal/rand/rand.go +++ b/internal/rand/rand.go @@ -16,7 +16,7 @@ var Reader io.Reader var floatMaxBigInt = big.NewInt(1 << 53) -// Float64 returns a float64 read from an io.Reader source. +// Float64 returns a float64 read from an io.Reader source. The returned float will be between [0.0, 1.0). func Float64(reader io.Reader) (float64, error) { bi, err := rand.Int(reader, floatMaxBigInt) if err != nil { From 28697aa3e48a643759cba4fadcdfc6134b0a7065 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Thu, 3 Dec 2020 18:04:14 -0800 Subject: [PATCH 4/5] Implement Feedback --- aws/context.go | 22 ++++++++++++++++++++++ aws/credential_cache.go | 21 ++++++++++++--------- 2 files changed, 34 insertions(+), 9 deletions(-) create mode 100644 aws/context.go diff --git a/aws/context.go b/aws/context.go new file mode 100644 index 00000000000..4d8e26ef321 --- /dev/null +++ b/aws/context.go @@ -0,0 +1,22 @@ +package aws + +import ( + "context" + "time" +) + +type suppressedContext struct { + context.Context +} + +func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +func (s *suppressedContext) Done() <-chan struct{} { + return nil +} + +func (s *suppressedContext) Err() error { + return nil +} diff --git a/aws/credential_cache.go b/aws/credential_cache.go index 48833e6e0e0..1203281019a 100644 --- a/aws/credential_cache.go +++ b/aws/credential_cache.go @@ -11,8 +11,6 @@ import ( // CredentialsCacheOptions are the options type CredentialsCacheOptions struct { - // Provider is the CredentialProvider implementation to be wrapped by the CredentialCache. - Provider CredentialsProvider // ExpiryWindow will allow the credentials to trigger refreshing prior to // the credentials actually expiring. This is beneficial so race conditions @@ -20,7 +18,8 @@ type CredentialsCacheOptions struct { // due to ExpiredTokenException exceptions. // // An ExpiryWindow of 10s would cause calls to IsExpired() to return true - // 10 seconds before the credentials are actually expired. + // 10 seconds before the credentials are actually expired. This can cause an + // increased number of requests to refresh the credentials to occur. // // If ExpiryWindow is 0 or less it will be ignored. ExpiryWindow time.Duration @@ -41,6 +40,9 @@ type CredentialsCacheOptions struct { // CredentialsCache provides caching and concurrency safe credentials retrieval // via the provider's retrieve method. type CredentialsCache struct { + // provider is the CredentialProvider implementation to be wrapped by the CredentialCache. + provider CredentialsProvider + options CredentialsCacheOptions creds atomic.Value sf singleflight.Group @@ -50,9 +52,7 @@ type CredentialsCache struct { // list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for // configuration of credential expiry window and jitter. func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache { - options := CredentialsCacheOptions{ - Provider: provider, - } + options := CredentialsCacheOptions{} for _, fn := range optFns { fn(&options) @@ -69,6 +69,7 @@ func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *C } return &CredentialsCache{ + provider: provider, options: options, } } @@ -84,7 +85,9 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) { return *creds, nil } - resCh := p.sf.DoChan("", p.singleRetrieve) + resCh := p.sf.DoChan("", func() (interface{}, error) { + return p.singleRetrieve(&suppressedContext{ctx}) + }) select { case res := <-resCh: return res.Val.(Credentials), res.Err @@ -93,12 +96,12 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) { } } -func (p *CredentialsCache) singleRetrieve() (interface{}, error) { +func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) { if creds := p.getCreds(); creds != nil { return *creds, nil } - creds, err := p.options.Provider.Retrieve(context.TODO()) + creds, err := p.provider.Retrieve(ctx) if err == nil { if creds.CanExpire { randFloat64, err := sdkrand.CryptoRandFloat64() From c3e6b358426aa9955f6eb05ed16bfd87c21160c6 Mon Sep 17 00:00:00 2001 From: Sean McGrail Date: Fri, 4 Dec 2020 14:53:16 -0800 Subject: [PATCH 5/5] Fix config CredentialsCache usage --- config/resolve_credentials.go | 28 +++++++---- config/resolve_test.go | 22 ++++----- credentials/ec2rolecreds/integration_test.go | 2 +- credentials/ec2rolecreds/provider.go | 10 ++-- credentials/ec2rolecreds/provider_test.go | 52 ++++---------------- 5 files changed, 44 insertions(+), 70 deletions(-) diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index 98705d3e035..e0ceab59054 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -24,7 +24,6 @@ const ( var ( ecsContainerEndpoint = "http://169.254.170.2" // not constant to allow for swapping during unit-testing - ) // resolveCredentials extracts a credential provider from slice of config sources. @@ -66,7 +65,7 @@ func resolveCredentialProvider(cfg *aws.Config, cfgs configs) (bool, error) { return false, nil } - cfg.Credentials = &aws.CredentialsCache{Provider: credProvider} + cfg.Credentials = wrapWithCredentialsCache(credProvider) return true, nil } @@ -100,7 +99,7 @@ func resolveCredentialChain(cfg *aws.Config, configs configs) (err error) { } // Wrap the resolved provider in a cache so the SDK will cache credentials. - cfg.Credentials = &aws.CredentialsCache{Provider: cfg.Credentials} + cfg.Credentials = wrapWithCredentialsCache(cfg.Credentials) return nil } @@ -198,7 +197,6 @@ func resolveLocalHTTPCredProvider(cfg *aws.Config, endpointURL, authToken string func resolveHTTPCredProvider(cfg *aws.Config, url, authToken string, configs configs) error { optFns := []func(*endpointcreds.Options){ func(options *endpointcreds.Options) { - options.ExpiryWindow = 5 * time.Minute if len(authToken) != 0 { options.AuthorizationToken = authToken } @@ -217,7 +215,9 @@ func resolveHTTPCredProvider(cfg *aws.Config, url, authToken string, configs con provider := endpointcreds.New(url, optFns...) - cfg.Credentials = provider + cfg.Credentials = wrapWithCredentialsCache(provider, func(options *aws.CredentialsCacheOptions) { + options.ExpiryWindow = 5 * time.Minute + }) return nil } @@ -264,11 +264,11 @@ func resolveEC2RoleCredentials(cfg *aws.Config, configs configs) error { } }) - provider := ec2rolecreds.New(ec2rolecreds.Options{ - ExpiryWindow: 5 * time.Minute, - }, optFns...) + provider := ec2rolecreds.New(optFns...) - cfg.Credentials = provider + cfg.Credentials = wrapWithCredentialsCache(provider, func(options *aws.CredentialsCacheOptions) { + options.ExpiryWindow = 5*time.Minute + }) return nil } @@ -399,3 +399,13 @@ func credsFromAssumeRole(cfg *aws.Config, sharedCfg *SharedConfig, configs confi return nil } + +// wrapWithCredentialsCache will wrap provider with an aws.CredentialsCache with the provided options if the provider is not already a aws.CredentialsCache. +func wrapWithCredentialsCache(provider aws.CredentialsProvider, optFns ...func(options *aws.CredentialsCacheOptions)) aws.CredentialsProvider { + _, ok := provider.(*aws.CredentialsCache) + if ok { + return provider + } + + return aws.NewCredentialsCache(provider, optFns...) +} diff --git a/config/resolve_test.go b/config/resolve_test.go index 22f125e172d..cf36d435f26 100644 --- a/config/resolve_test.go +++ b/config/resolve_test.go @@ -130,28 +130,24 @@ func TestResolveCredentialsProvider(t *testing.T) { t.Fatalf("expected %v, got %v", e, a) } - cache, ok := cfg.Credentials.(*aws.CredentialsCache) + _, ok := cfg.Credentials.(*aws.CredentialsCache) if !ok { t.Fatalf("expect resolved credentials to be wrapped in cache, was not, %T", cfg.Credentials) } - p := cache.Provider.(credentials.StaticCredentialsProvider) - if e, a := "AKID", p.Value.AccessKeyID; e != a { - t.Errorf("expect %v key, got %v", e, a) - } - if e, a := "SECRET", p.Value.SecretAccessKey; e != a { - t.Errorf("expect %v secret, got %v", e, a) - } - if e, a := "valid", p.Value.Source; e != a { - t.Errorf("expect %v provider name, got %v", e, a) - } - creds, err := cfg.Credentials.Retrieve(context.Background()) if err != nil { t.Fatalf("expect no error, got %v", err) } + + if e, a := "AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v key, got %v", e, a) + } + if e, a := "SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v secret, got %v", e, a) + } if e, a := "valid", creds.Source; e != a { - t.Errorf("expect %v creds, got %v", e, a) + t.Errorf("expect %v provider name, got %v", e, a) } } diff --git a/credentials/ec2rolecreds/integration_test.go b/credentials/ec2rolecreds/integration_test.go index 061ae54e6be..378a376ad31 100644 --- a/credentials/ec2rolecreds/integration_test.go +++ b/credentials/ec2rolecreds/integration_test.go @@ -8,7 +8,7 @@ import ( ) func TestInteg_RetrieveCredentials(t *testing.T) { - provider := New(Options{}) + provider := New() creds, err := provider.Retrieve(context.Background()) if err != nil { diff --git a/credentials/ec2rolecreds/provider.go b/credentials/ec2rolecreds/provider.go index df80a510bb8..6eceafa7300 100644 --- a/credentials/ec2rolecreds/provider.go +++ b/credentials/ec2rolecreds/provider.go @@ -50,15 +50,17 @@ type Options struct { // New returns an initialized Provider value configured to retrieve // credentials from EC2 Instance Metadata service. -func New(options Options, optFns ...func(*Options)) *Provider { - if options.Client == nil { - options.Client = ec2imds.New(ec2imds.Options{}) - } +func New(optFns ...func(*Options)) *Provider { + options := Options{} for _, fn := range optFns { fn(&options) } + if options.Client == nil { + options.Client = ec2imds.New(ec2imds.Options{}) + } + return &Provider{ options: options, } diff --git a/credentials/ec2rolecreds/provider_test.go b/credentials/ec2rolecreds/provider_test.go index af284f2c9f7..a80a1d149fb 100644 --- a/credentials/ec2rolecreds/provider_test.go +++ b/credentials/ec2rolecreds/provider_test.go @@ -67,12 +67,12 @@ func TestProvider(t *testing.T) { orig := sdk.NowTime defer func() { sdk.NowTime = orig }() - p := New(Options{ - Client: mockClient{ + p := New(func(options *Options) { + options.Client = mockClient{ roleName: "RoleName", failAssume: false, expireOn: "2014-12-16T01:51:37Z", - }, + } }) creds, err := p.Retrieve(context.Background()) @@ -99,12 +99,12 @@ func TestProvider(t *testing.T) { } func TestProvider_FailAssume(t *testing.T) { - p := New(Options{ - Client: mockClient{ + p := New(func(options *Options) { + options.Client = mockClient{ roleName: "RoleName", failAssume: true, expireOn: "2014-12-16T01:51:37Z", - }, + } }) creds, err := p.Retrieve(context.Background()) @@ -143,12 +143,12 @@ func TestProvider_IsExpired(t *testing.T) { orig := sdk.NowTime defer func() { sdk.NowTime = orig }() - p := New(Options{ - Client: mockClient{ + p := New(func(options *Options) { + options.Client = mockClient{ roleName: "RoleName", failAssume: false, expireOn: "2014-12-16T01:51:37Z", - }, + } }) sdk.NowTime = func() time.Time { @@ -171,37 +171,3 @@ func TestProvider_IsExpired(t *testing.T) { t.Errorf("expect to be expired") } } - -func TestProvider_ExpiryWindowIsExpired(t *testing.T) { - orig := sdk.NowTime - defer func() { sdk.NowTime = orig }() - - p := New(Options{ - Client: mockClient{ - roleName: "RoleName", - failAssume: false, - expireOn: "2014-12-16T01:51:37Z", - }, - ExpiryWindow: time.Hour, - }) - - sdk.NowTime = func() time.Time { - return time.Date(2014, 12, 16, 0, 40, 37, 0, time.UTC) - } - - creds, err := p.Retrieve(context.Background()) - if err != nil { - t.Fatalf("expect no error, got %v", err) - } - if creds.Expired() { - t.Errorf("expect not to be expired") - } - - sdk.NowTime = func() time.Time { - return time.Date(2014, 12, 16, 1, 30, 37, 0, time.UTC) - } - - if !creds.Expired() { - t.Errorf("expect to be expired") - } -}