diff --git a/awsutil/options.go b/awsutil/options.go index ad3fd71..d3da48f 100644 --- a/awsutil/options.go +++ b/awsutil/options.go @@ -41,7 +41,7 @@ type options struct { withMaxRetries *int withRegion string withHttpClient *http.Client - withTimeout time.Duration + withValidityCheckTimeout time.Duration } func getDefaultOptions() options { @@ -162,11 +162,11 @@ func WithHttpClient(with *http.Client) Option { } } -// WithTimeout allows passing a timeout for operations that can wait +// WithValidityCheckTimeout allows passing a timeout for operations that can wait // on success. -func WithTimeout(with time.Duration) Option { +func WithValidityCheckTimeout(with time.Duration) Option { return func(o *options) error { - o.withTimeout = with + o.withValidityCheckTimeout = with return nil } } diff --git a/awsutil/options_test.go b/awsutil/options_test.go index 47b5a7d..65fe352 100644 --- a/awsutil/options_test.go +++ b/awsutil/options_test.go @@ -113,9 +113,9 @@ func Test_GetOpts(t *testing.T) { require.NoError(t, err) assert.Equal(t, &opts.withHttpClient, &client) }) - t.Run("withTimeout", func(t *testing.T) { - opts, err := getOpts(WithTimeout(time.Second)) + t.Run("withValidityCheckTimeout", func(t *testing.T) { + opts, err := getOpts(WithValidityCheckTimeout(time.Second)) require.NoError(t, err) - assert.Equal(t, opts.withTimeout, time.Second) + assert.Equal(t, opts.withValidityCheckTimeout, time.Second) }) } diff --git a/awsutil/rotate.go b/awsutil/rotate.go index 2ef68a9..0fbc7d2 100644 --- a/awsutil/rotate.go +++ b/awsutil/rotate.go @@ -23,7 +23,9 @@ import ( // if the old one could not be deleted. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername +// WithAwsSession, WithUsername, WithValidityCheckTimeout. Note that WithValidityCheckTimeout +// here, when non-zero, controls the WithValidityCheckTimeout option on access key +// creation. See CreateAccessKey for more details. func (c *CredentialsConfig) RotateKeys(opt ...Option) error { if c.AccessKey == "" || c.SecretKey == "" { return errors.New("cannot rotate credentials when either access_key or secret_key is empty") @@ -62,11 +64,14 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error { // CreateAccessKey creates a new access/secret key pair. // // Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername +// WithAwsSession, WithUsername, WithValidityCheckTimeout +// +// When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on +// the created credentials to be valid and ready for use. func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKeyOutput, error) { opts, err := getOpts(opt...) if err != nil { - return nil, fmt.Errorf("error reading options in RotateKeys: %w", err) + return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err) } sess := opts.withAwsSession @@ -114,6 +119,20 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey") } + // Check the credentials to make sure they are usable. We only do + // this if withValidityCheckTimeout is non-zero to ensue that we don't + // immediately fail due to eventual consistency. + if opts.withValidityCheckTimeout != 0 { + newC := &CredentialsConfig{ + AccessKey: *createAccessKeyRes.AccessKey.AccessKeyId, + SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey, + } + + if _, err := newC.GetCallerIdentity(WithValidityCheckTimeout(opts.withValidityCheckTimeout)); err != nil { + return nil, fmt.Errorf("error verifying new credentials: %w", err) + } + } + return createAccessKeyRes, nil } @@ -204,7 +223,7 @@ func (c *CredentialsConfig) GetSession(opt ...Option) (*session.Session, error) // account and user ID. // // Supported options: WithEnvironmentCredentials, -// WithSharedCredentials, WithAwsSession, WithTimeout +// WithSharedCredentials, WithAwsSession, WithValidityCheckTimeout func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIdentityOutput, error) { opts, err := getOpts(opt...) if err != nil { @@ -225,8 +244,7 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden } delay := time.Second - maxDelay := time.Second * 30 - timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withTimeout) + timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withValidityCheckTimeout) defer cancel() for { cid, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) @@ -241,19 +259,13 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden case <-timeoutCtx.Done(): // Format our error based on how we were called. - if opts.withTimeout == 0 { + if opts.withValidityCheckTimeout == 0 { // There was no timeout, just return the error unwrapped. return nil, err } // Otherwise, return the error wrapped in a timeout error. - return nil, fmt.Errorf("timeout after %s waiting for success: %w", opts.withTimeout, err) - } - - // exponential backoff, multiply delay by 2, limit by maxDelay - delay *= 2 - if delay > maxDelay { - delay = maxDelay + return nil, fmt.Errorf("timeout after %s waiting for success: %w", opts.withValidityCheckTimeout, err) } } } diff --git a/awsutil/rotate_test.go b/awsutil/rotate_test.go index 80d5969..858b34b 100644 --- a/awsutil/rotate_test.go +++ b/awsutil/rotate_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +const testRotationWaitTimeout = time.Second * 30 + func TestRotation(t *testing.T) { require, assert := require.New(t), assert.New(t) @@ -32,7 +34,7 @@ func TestRotation(t *testing.T) { } // Create an initial key - out, err := credsConfig.CreateAccessKey(WithUsername(username)) + out, err := credsConfig.CreateAccessKey(WithUsername(username), WithValidityCheckTimeout(testRotationWaitTimeout)) require.NoError(err) require.NotNil(out) @@ -49,8 +51,7 @@ func TestRotation(t *testing.T) { WithSecretKey(secretKey), ) require.NoError(err) - time.Sleep(10 * time.Second) - require.NoError(c.RotateKeys()) + require.NoError(c.RotateKeys(WithValidityCheckTimeout(testRotationWaitTimeout))) assert.NotEqual(accessKey, c.AccessKey) assert.NotEqual(secretKey, c.SecretKey) cleanupKey = &c.AccessKey @@ -115,7 +116,7 @@ func TestCallerIdentityErrorNoTimeout(t *testing.T) { require.Implements((*awserr.Error)(nil), err) } -func TestCallerIdentityErrorWithTimeout(t *testing.T) { +func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { require := require.New(t) c := &CredentialsConfig{ @@ -123,7 +124,7 @@ func TestCallerIdentityErrorWithTimeout(t *testing.T) { SecretKey: "badagain", } - _, err := c.GetCallerIdentity(WithTimeout(time.Second * 10)) + _, err := c.GetCallerIdentity(WithValidityCheckTimeout(time.Second * 10)) require.NotNil(err) require.True(strings.HasPrefix(err.Error(), "timeout after 10s waiting for success")) err = errors.Unwrap(err)