Skip to content

Commit

Permalink
awsutil: add GetCallerIdentity to Create/Rotate (#10)
Browse files Browse the repository at this point in the history
* awsutil: add GetCallerIdentity to Create/Rotate

This adds GetCallerIdentity to CreateAccessKey and RotateKeys. The
latter is done simply by passing WithTimeout through to the inner
CreateAccessKey on Rotate. As such, both paths are conditional on a
non-zero WithTimeout option being given - this is to ensure that we
don't try to immediately verify, which is likely to always fail.

* awsutil: add some comments, remove exponential backoff

This does the following:

* Adds some comments for WithTimeout to RotateKeys and CreateAccessKey
  to explain exactly what these options do.

* Removes the exponential backoff. Even with the quadratic-time backoff
  we're still probably looking at a number of seconds of delay that's
  probably unnecessary. Retrying every 1s should not generate that much
  traffic and only need to be necessary for a few retries.

* Change WithTimeout to WithValidityCheckTimeout
  • Loading branch information
vancluever authored Sep 20, 2021
1 parent 29a9a55 commit 8ef7ccd
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 26 deletions.
8 changes: 4 additions & 4 deletions awsutil/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type options struct {
withMaxRetries *int
withRegion string
withHttpClient *http.Client
withTimeout time.Duration
withValidityCheckTimeout time.Duration
}

func getDefaultOptions() options {
Expand Down Expand Up @@ -162,11 +162,11 @@ func WithHttpClient(with *http.Client) Option {
}
}

// WithTimeout allows passing a timeout for operations that can wait
// WithValidityCheckTimeout allows passing a timeout for operations that can wait
// on success.
func WithTimeout(with time.Duration) Option {
func WithValidityCheckTimeout(with time.Duration) Option {
return func(o *options) error {
o.withTimeout = with
o.withValidityCheckTimeout = with
return nil
}
}
6 changes: 3 additions & 3 deletions awsutil/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ func Test_GetOpts(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, &opts.withHttpClient, &client)
})
t.Run("withTimeout", func(t *testing.T) {
opts, err := getOpts(WithTimeout(time.Second))
t.Run("withValidityCheckTimeout", func(t *testing.T) {
opts, err := getOpts(WithValidityCheckTimeout(time.Second))
require.NoError(t, err)
assert.Equal(t, opts.withTimeout, time.Second)
assert.Equal(t, opts.withValidityCheckTimeout, time.Second)
})
}
40 changes: 26 additions & 14 deletions awsutil/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
// if the old one could not be deleted.
//
// Supported options: WithEnvironmentCredentials, WithSharedCredentials,
// WithAwsSession, WithUsername
// WithAwsSession, WithUsername, WithValidityCheckTimeout. Note that WithValidityCheckTimeout
// here, when non-zero, controls the WithValidityCheckTimeout option on access key
// creation. See CreateAccessKey for more details.
func (c *CredentialsConfig) RotateKeys(opt ...Option) error {
if c.AccessKey == "" || c.SecretKey == "" {
return errors.New("cannot rotate credentials when either access_key or secret_key is empty")
Expand Down Expand Up @@ -62,11 +64,14 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error {
// CreateAccessKey creates a new access/secret key pair.
//
// Supported options: WithEnvironmentCredentials, WithSharedCredentials,
// WithAwsSession, WithUsername
// WithAwsSession, WithUsername, WithValidityCheckTimeout
//
// When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on
// the created credentials to be valid and ready for use.
func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKeyOutput, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options in RotateKeys: %w", err)
return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err)
}

sess := opts.withAwsSession
Expand Down Expand Up @@ -114,6 +119,20 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey
return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey")
}

// Check the credentials to make sure they are usable. We only do
// this if withValidityCheckTimeout is non-zero to ensue that we don't
// immediately fail due to eventual consistency.
if opts.withValidityCheckTimeout != 0 {
newC := &CredentialsConfig{
AccessKey: *createAccessKeyRes.AccessKey.AccessKeyId,
SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey,
}

if _, err := newC.GetCallerIdentity(WithValidityCheckTimeout(opts.withValidityCheckTimeout)); err != nil {
return nil, fmt.Errorf("error verifying new credentials: %w", err)
}
}

return createAccessKeyRes, nil
}

Expand Down Expand Up @@ -204,7 +223,7 @@ func (c *CredentialsConfig) GetSession(opt ...Option) (*session.Session, error)
// account and user ID.
//
// Supported options: WithEnvironmentCredentials,
// WithSharedCredentials, WithAwsSession, WithTimeout
// WithSharedCredentials, WithAwsSession, WithValidityCheckTimeout
func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIdentityOutput, error) {
opts, err := getOpts(opt...)
if err != nil {
Expand All @@ -225,8 +244,7 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden
}

delay := time.Second
maxDelay := time.Second * 30
timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withTimeout)
timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withValidityCheckTimeout)
defer cancel()
for {
cid, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
Expand All @@ -241,19 +259,13 @@ func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIden

case <-timeoutCtx.Done():
// Format our error based on how we were called.
if opts.withTimeout == 0 {
if opts.withValidityCheckTimeout == 0 {
// There was no timeout, just return the error unwrapped.
return nil, err
}

// Otherwise, return the error wrapped in a timeout error.
return nil, fmt.Errorf("timeout after %s waiting for success: %w", opts.withTimeout, err)
}

// exponential backoff, multiply delay by 2, limit by maxDelay
delay *= 2
if delay > maxDelay {
delay = maxDelay
return nil, fmt.Errorf("timeout after %s waiting for success: %w", opts.withValidityCheckTimeout, err)
}
}
}
11 changes: 6 additions & 5 deletions awsutil/rotate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/stretchr/testify/require"
)

const testRotationWaitTimeout = time.Second * 30

func TestRotation(t *testing.T) {
require, assert := require.New(t), assert.New(t)

Expand All @@ -32,7 +34,7 @@ func TestRotation(t *testing.T) {
}

// Create an initial key
out, err := credsConfig.CreateAccessKey(WithUsername(username))
out, err := credsConfig.CreateAccessKey(WithUsername(username), WithValidityCheckTimeout(testRotationWaitTimeout))
require.NoError(err)
require.NotNil(out)

Expand All @@ -49,8 +51,7 @@ func TestRotation(t *testing.T) {
WithSecretKey(secretKey),
)
require.NoError(err)
time.Sleep(10 * time.Second)
require.NoError(c.RotateKeys())
require.NoError(c.RotateKeys(WithValidityCheckTimeout(testRotationWaitTimeout)))
assert.NotEqual(accessKey, c.AccessKey)
assert.NotEqual(secretKey, c.SecretKey)
cleanupKey = &c.AccessKey
Expand Down Expand Up @@ -115,15 +116,15 @@ func TestCallerIdentityErrorNoTimeout(t *testing.T) {
require.Implements((*awserr.Error)(nil), err)
}

func TestCallerIdentityErrorWithTimeout(t *testing.T) {
func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) {
require := require.New(t)

c := &CredentialsConfig{
AccessKey: "bad",
SecretKey: "badagain",
}

_, err := c.GetCallerIdentity(WithTimeout(time.Second * 10))
_, err := c.GetCallerIdentity(WithValidityCheckTimeout(time.Second * 10))
require.NotNil(err)
require.True(strings.HasPrefix(err.Error(), "timeout after 10s waiting for success"))
err = errors.Unwrap(err)
Expand Down

0 comments on commit 8ef7ccd

Please sign in to comment.