diff --git a/awsutil/README.md b/awsutil/README.md new file mode 100644 index 0000000..0055205 --- /dev/null +++ b/awsutil/README.md @@ -0,0 +1,37 @@ +# AWSUTIL - Go library for generating aws credentials + +*NOTE*: This is version 2 of the library. The `v0` branch contains version 0, +which may be needed for legacy applications or while transitioning to version 2. + +## Usage + +Following is an example usage of generating AWS credentials with static user credentials + +```go + +// AWS access keys for an IAM user can be used as your AWS credentials. +// This is an example of an access key and secret key +var accessKey = "AKIAIOSFODNN7EXAMPLE" +var secretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + +// Access key IDs beginning with AKIA are long-term access keys. A long-term +// access key should be supplied when generating static credentials. +config, err := awsutil.NewCredentialsConfig( + awsutil.WithAccessKey(accessKey), + awsutil.WithSecretKey(secretKey), +) +if err != nil { + return err +} + +s3Client := s3.NewFromConfig(config) + +``` + +## Contributing to v0 + +To push a bug fix or feature for awsutil `v0`, branch out from the [awsutil/v0](https://github.com/hashicorp/go-secure-stdlib/tree/awsutil/v0) branch. +Commit the code changes you want to this new branch and open a PR. Make sure the PR +is configured so that the base branch is set to `awsutil/v0` and not `main`. Once the PR +is reviewed, feel free to merge it into the `awsutil/v0` branch. When creating a new +release, validate that the `Target` branch is `awsutil/v0` and the tag is `awsutil/v0.x.x`. \ No newline at end of file diff --git a/awsutil/clients.go b/awsutil/clients.go index d273777..5aca413 100644 --- a/awsutil/clients.go +++ b/awsutil/clients.go @@ -4,90 +4,98 @@ package awsutil import ( - "errors" + "context" "fmt" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" ) // IAMAPIFunc is a factory function for returning an IAM interface, -// useful for supplying mock interfaces for testing IAM. The session -// is passed into the function in the same way as done with the -// standard iam.New() constructor. -type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) +// useful for supplying mock interfaces for testing IAM. +type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error) + +// IAMClient represents an iam.Client +type IAMClient interface { + CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) + DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) + ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error) + GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error) +} // STSAPIFunc is a factory function for returning a STS interface, -// useful for supplying mock interfaces for testing STS. The session -// is passed into the function in the same way as done with the -// standard sts.New() constructor. -type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) +// useful for supplying mock interfaces for testing STS. +type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error) + +// STSClient represents an sts.Client +type STSClient interface { + AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) + GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) +} // IAMClient returns an IAM client. // -// Supported options: WithSession, WithIAMAPIFunc. +// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpointResolver. // // If WithIAMAPIFunc is supplied, the included function is used as // the IAM client constructor instead. This can be used for Mocking // the IAM API. -func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { +func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) { opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("error reading options: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) + cfg := opts.withAwsConfig + if cfg == nil { + cfg, err = c.GenerateCredentialChain(ctx, opt...) if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) + return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) } } if opts.withIAMAPIFunc != nil { - return opts.withIAMAPIFunc(sess) + return opts.withIAMAPIFunc(cfg) } - client := iam.New(sess) - if client == nil { - return nil, errors.New("could not obtain iam client from session") + var iamOpts []func(*iam.Options) + if c.IAMEndpointResolver != nil { + iamOpts = append(iamOpts, iam.WithEndpointResolverV2(c.IAMEndpointResolver)) } - return client, nil + return iam.NewFromConfig(*cfg, iamOpts...), nil } // STSClient returns a STS client. // -// Supported options: WithSession, WithSTSAPIFunc. +// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpointResolver. // // If WithSTSAPIFunc is supplied, the included function is used as // the STS client constructor instead. This can be used for Mocking // the STS API. -func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { +func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) { opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("error reading options: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) + cfg := opts.withAwsConfig + if cfg == nil { + cfg, err = c.GenerateCredentialChain(ctx, opt...) if err != nil { - return nil, fmt.Errorf("error calling GetSession: %w", err) + return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) } } if opts.withSTSAPIFunc != nil { - return opts.withSTSAPIFunc(sess) + return opts.withSTSAPIFunc(cfg) } - client := sts.New(sess) - if client == nil { - return nil, errors.New("could not obtain sts client from session") + var stsOpts []func(*sts.Options) + if c.STSEndpointResolver != nil { + stsOpts = append(stsOpts, sts.WithEndpointResolverV2(c.STSEndpointResolver)) } - return client, nil + return sts.NewFromConfig(*cfg, stsOpts...), nil } diff --git a/awsutil/clients_test.go b/awsutil/clients_test.go index 9a0d605..de3d30e 100644 --- a/awsutil/clients_test.go +++ b/awsutil/clients_test.go @@ -4,31 +4,24 @@ package awsutil import ( + "context" "errors" "fmt" "testing" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/stretchr/testify/require" ) const testOptionErr = "test option error" -const testBadClientType = "badclienttype" - -func testWithBadClientType(o *options) error { - o.withClientType = testBadClientType - return nil -} func TestCredentialsConfigIAMClient(t *testing.T) { cases := []struct { name string credentialsConfig *CredentialsConfig opts []Option - require func(t *testing.T, actual iamiface.IAMAPI) + require func(t *testing.T, actual IAMClient) requireErr string }{ { @@ -37,17 +30,11 @@ func TestCredentialsConfigIAMClient(t *testing.T) { opts: []Option{MockOptionErr(errors.New(testOptionErr))}, requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), }, - { - name: "session error", - credentialsConfig: &CredentialsConfig{}, - opts: []Option{testWithBadClientType}, - requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), - }, { name: "with mock IAM session", credentialsConfig: &CredentialsConfig{}, opts: []Option{WithIAMAPIFunc(NewMockIAM())}, - require: func(t *testing.T, actual iamiface.IAMAPI) { + require: func(t *testing.T, actual IAMClient) { t.Helper() require := require.New(t) require.Equal(&MockIAM{}, actual) @@ -57,10 +44,10 @@ func TestCredentialsConfigIAMClient(t *testing.T) { name: "no mock client", credentialsConfig: &CredentialsConfig{}, opts: []Option{}, - require: func(t *testing.T, actual iamiface.IAMAPI) { + require: func(t *testing.T, actual IAMClient) { t.Helper() require := require.New(t) - require.IsType(&iam.IAM{}, actual) + require.IsType(&iam.Client{}, actual) }, }, } @@ -69,7 +56,7 @@ func TestCredentialsConfigIAMClient(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { require := require.New(t) - actual, err := tc.credentialsConfig.IAMClient(tc.opts...) + actual, err := tc.credentialsConfig.IAMClient(context.TODO(), tc.opts...) if tc.requireErr != "" { require.EqualError(err, tc.requireErr) return @@ -86,7 +73,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) { name string credentialsConfig *CredentialsConfig opts []Option - require func(t *testing.T, actual stsiface.STSAPI) + require func(t *testing.T, actual STSClient) requireErr string }{ { @@ -95,17 +82,11 @@ func TestCredentialsConfigSTSClient(t *testing.T) { opts: []Option{MockOptionErr(errors.New(testOptionErr))}, requireErr: fmt.Sprintf("error reading options: %s", testOptionErr), }, - { - name: "session error", - credentialsConfig: &CredentialsConfig{}, - opts: []Option{testWithBadClientType}, - requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType), - }, { name: "with mock STS session", credentialsConfig: &CredentialsConfig{}, opts: []Option{WithSTSAPIFunc(NewMockSTS())}, - require: func(t *testing.T, actual stsiface.STSAPI) { + require: func(t *testing.T, actual STSClient) { t.Helper() require := require.New(t) require.Equal(&MockSTS{}, actual) @@ -115,10 +96,10 @@ func TestCredentialsConfigSTSClient(t *testing.T) { name: "no mock client", credentialsConfig: &CredentialsConfig{}, opts: []Option{}, - require: func(t *testing.T, actual stsiface.STSAPI) { + require: func(t *testing.T, actual STSClient) { t.Helper() require := require.New(t) - require.IsType(&sts.STS{}, actual) + require.IsType(&sts.Client{}, actual) }, }, } @@ -127,7 +108,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { require := require.New(t) - actual, err := tc.credentialsConfig.STSClient(tc.opts...) + actual, err := tc.credentialsConfig.STSClient(context.TODO(), tc.opts...) if tc.requireErr != "" { require.EqualError(err, tc.requireErr) return diff --git a/awsutil/error.go b/awsutil/error.go index c15b322..c02ae4c 100644 --- a/awsutil/error.go +++ b/awsutil/error.go @@ -6,7 +6,7 @@ package awsutil import ( "errors" - awsRequest "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go-v2/aws/retry" multierror "github.com/hashicorp/go-multierror" ) @@ -15,10 +15,10 @@ var ErrUpstreamRateLimited = errors.New("upstream rate limited") // CheckAWSError will examine an error and convert to a logical error if // appropriate. If no appropriate error is found, return nil func CheckAWSError(err error) error { - // IsErrorThrottle will check if the error returned is one that matches - // known request limiting errors: - // https://github.com/aws/aws-sdk-go/blob/488d634b5a699b9118ac2befb5135922b4a77210/aws/request/retryer.go#L35 - if awsRequest.IsErrorThrottle(err) { + retryErr := retry.ThrottleErrorCode{ + Codes: retry.DefaultThrottleErrorCodes, + } + if retryErr.IsErrorThrottle(err).Bool() { return ErrUpstreamRateLimited } return nil diff --git a/awsutil/error_test.go b/awsutil/error_test.go index 71c1971..4715fb1 100644 --- a/awsutil/error_test.go +++ b/awsutil/error_test.go @@ -7,7 +7,7 @@ import ( "fmt" "testing" - "github.com/aws/aws-sdk-go/aws/awserr" + awserr "github.com/aws/smithy-go" multierror "github.com/hashicorp/go-multierror" ) @@ -23,12 +23,16 @@ func Test_CheckAWSError(t *testing.T) { }, { Name: "Upstream throttle error", - Err: awserr.New("Throttling", "", nil), + Err: MockAWSThrottleErr(), Expected: ErrUpstreamRateLimited, }, { - Name: "Upstream RequestLimitExceeded", - Err: awserr.New("RequestLimitExceeded", "Request rate limited", nil), + Name: "Upstream RequestLimitExceeded", + Err: &MockAWSErr{ + Code: "RequestLimitExceeded", + Message: "Request rate limited", + Fault: awserr.FaultServer, + }, Expected: ErrUpstreamRateLimited, }, } @@ -50,7 +54,7 @@ func Test_CheckAWSError(t *testing.T) { } func Test_AppendRateLimitedError(t *testing.T) { - awsErr := awserr.New("Throttling", "", nil) + throttleErr := MockAWSThrottleErr() testCases := []struct { Name string Err error @@ -63,8 +67,8 @@ func Test_AppendRateLimitedError(t *testing.T) { }, { Name: "Upstream throttle error", - Err: awsErr, - Expected: multierror.Append(awsErr, ErrUpstreamRateLimited), + Err: throttleErr, + Expected: multierror.Append(throttleErr, ErrUpstreamRateLimited), }, { Name: "Nil", diff --git a/awsutil/generate_credentials.go b/awsutil/generate_credentials.go index 89d9996..923cb53 100644 --- a/awsutil/generate_credentials.go +++ b/awsutil/generate_credentials.go @@ -4,24 +4,20 @@ package awsutil import ( - "encoding/base64" - "encoding/json" + "context" "fmt" - "io/ioutil" "net/http" "os" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/defaults" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" - "github.com/pkg/errors" ) const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID" @@ -36,11 +32,11 @@ type CredentialsConfig struct { // The session token if it is being used SessionToken string - // The IAM endpoint to use; if not set will use the default - IAMEndpoint string + // The IAM endpoint resolver to use; if not set will use the default + IAMEndpointResolver iam.EndpointResolverV2 - // The STS endpoint to use; if not set will use the default - STSEndpoint string + // The STS endpoint resolver to use; if not set will use the default + STSEndpointResolver sts.EndpointResolverV2 // If specified, the region will be provided to the config of the // EC2RoleProvider's client. This may be useful if you want to e.g. reuse @@ -87,10 +83,11 @@ type CredentialsConfig struct { Logger hclog.Logger } -// NewCredentialsConfig creates a CredentialsConfig with the provided options. +// GenerateCredentialChain uses the config to generate a credential chain +// suitable for creating AWS sessions and clients. // -// Supported options: WithAccessKey, WithSecretKey, WithLogger, WithStsEndpoint, -// WithIamEndpoint, WithMaxRetries, WithRegion, WithHttpClient, WithRoleArn, +// Supported options: WithAccessKey, WithSecretKey, WithLogger, WithStsEndpointResolver, +// WithIamEndpointResolver, WithMaxRetries, WithRegion, WithHttpClient, WithRoleArn, // WithRoleSessionName, WithRoleExternalId, WithRoleTags, WithWebIdentityTokenFile, // WithWebIdentityToken. func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) { @@ -100,14 +97,14 @@ func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) { } c := &CredentialsConfig{ - AccessKey: opts.withAccessKey, - SecretKey: opts.withSecretKey, - Logger: opts.withLogger, - STSEndpoint: opts.withStsEndpoint, - IAMEndpoint: opts.withIamEndpoint, - MaxRetries: opts.withMaxRetries, - RoleExternalId: opts.withRoleExternalId, - RoleTags: opts.withRoleTags, + AccessKey: opts.withAccessKey, + SecretKey: opts.withSecretKey, + Logger: opts.withLogger, + STSEndpointResolver: opts.withStsEndpointResolver, + IAMEndpointResolver: opts.withIamEndpointResolver, + MaxRetries: opts.withMaxRetries, + RoleExternalId: opts.withRoleExternalId, + RoleTags: opts.withRoleTags, } c.Region = opts.withRegion @@ -167,42 +164,22 @@ func (c *CredentialsConfig) log(level hclog.Level, msg string, args ...interface } } -// GenerateCredentialChain uses the config to generate a credential chain -// suitable for creating AWS sessions and clients. -// -// Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithSkipWebIdentityValidity -func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials.Credentials, error) { - opts, err := getOpts(opt...) - if err != nil { - return nil, fmt.Errorf("error reading options in GenerateCredentialChain: %w", err) - } - - var providers []credentials.Provider +func (c *CredentialsConfig) generateAwsConfigOptions(opts options) []func(*config.LoadOptions) error { + var cfgOpts []func(*config.LoadOptions) error - // Have one or the other but not both and not neither - if (c.AccessKey != "" && c.SecretKey == "") || (c.AccessKey == "" && c.SecretKey != "") { - return nil, fmt.Errorf("static AWS client credentials haven't been properly configured (the access key or secret key were provided but not both)") + if c.Region != "" { + cfgOpts = append(cfgOpts, config.WithRegion(c.Region)) } - // Add the static credential provider - if c.AccessKey != "" && c.SecretKey != "" { - providers = append(providers, &credentials.StaticProvider{ - Value: credentials.Value{ - AccessKeyID: c.AccessKey, - SecretAccessKey: c.SecretKey, - SessionToken: c.SessionToken, - }, - }) - c.log(hclog.Debug, "added static credential provider", "AccessKey", c.AccessKey) + + if c.MaxRetries != nil { + cfgOpts = append(cfgOpts, config.WithRetryMaxAttempts(*c.MaxRetries)) } - // Add the environment credential provider - if opts.withEnvironmentCredentials { - providers = append(providers, &credentials.EnvProvider{}) - c.log(hclog.Debug, "added environment variable credential provider") + if c.HTTPClient != nil { + cfgOpts = append(cfgOpts, config.WithHTTPClient(c.HTTPClient)) } - // Add the shared credentials provider + // Add the shared credentials if opts.withSharedCredentials { profile := os.Getenv("AWS_PROFILE") if profile != "" { @@ -211,231 +188,119 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials if c.Profile == "" { c.Profile = "default" } - providers = append(providers, &credentials.SharedCredentialsProvider{ - Filename: c.Filename, - Profile: c.Profile, - }) - c.log(hclog.Debug, "added shared credential provider") + cfgOpts = append(cfgOpts, config.WithSharedConfigProfile(c.Profile)) + cfgOpts = append(cfgOpts, config.WithSharedCredentialsFiles([]string{c.Filename})) + c.log(hclog.Debug, "added shared profile credential provider") } - // Add the assume role provider - roleARN := c.RoleARN - if roleARN == "" { - roleARN = os.Getenv("AWS_ROLE_ARN") - } - tokenPath := c.WebIdentityTokenFile - if tokenPath == "" { - tokenPath = os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") - } - roleSessionName := c.RoleSessionName - if roleSessionName == "" { - roleSessionName = os.Getenv("AWS_ROLE_SESSION_NAME") + // Add the static credential + if c.AccessKey != "" && c.SecretKey != "" { + staticCred := credentials.NewStaticCredentialsProvider(c.AccessKey, c.SecretKey, c.SessionToken) + cfgOpts = append(cfgOpts, config.WithCredentialsProvider(staticCred)) + c.log(hclog.Debug, "added static credential provider", "AccessKey", c.AccessKey) } - if roleARN != "" { - if tokenPath != "" { + + // Add the assume role provider + if c.RoleARN != "" { + if c.WebIdentityTokenFile != "" { // this session is only created to create the WebIdentityRoleProvider, variables used to // assume a role are pulled from values provided in options. If the option values are // not set, then the provider will default to using the environment variables. - c.log(hclog.Debug, "adding web identity provider", "roleARN", roleARN) - sess, err := session.NewSession() - if err != nil { - return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider") - } - webIdentityProvider := stscreds.NewWebIdentityRoleProvider(sts.New(sess), roleARN, roleSessionName, tokenPath) - - if opts.withSkipWebIdentityValidity { - // Add the web identity role credential provider without - // generating credentials to check validity first - providers = append(providers, webIdentityProvider) - } else { - // Check if the webIdentityProvider can successfully retrieve - // credentials (via sts:AssumeRole), and warn if there's a problem. - if _, err := webIdentityProvider.Retrieve(); err != nil { - c.log(hclog.Warn, "error assuming role", "roleARN", roleARN, "tokenPath", tokenPath, "sessionName", roleSessionName, "err", err) - } else { - // Add the web identity role credential provider - providers = append(providers, webIdentityProvider) - } - } + webIdentityRoleCred := config.WithWebIdentityRoleCredentialOptions(func(options *stscreds.WebIdentityRoleOptions) { + options.RoleARN = c.RoleARN + options.RoleSessionName = c.RoleSessionName + options.TokenRetriever = stscreds.IdentityTokenFile(c.WebIdentityTokenFile) + }) + cfgOpts = append(cfgOpts, webIdentityRoleCred) + c.log(hclog.Debug, "added web identity provider", "roleARN", c.RoleARN) } else if c.WebIdentityToken != "" { - c.log(hclog.Debug, "adding web identity provider with token", "roleARN", roleARN) - sess, err := session.NewSession() - if err != nil { - return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider with token") - } - webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken)) - - if opts.withSkipWebIdentityValidity { - // Add the web identity role credential provider without - // generating credentials to check validity first - providers = append(providers, webIdentityProvider) - } else { - // Check if the webIdentityProvider can successfully retrieve - // credentials (via sts:AssumeRole), and warn if there's a problem. - if _, err := webIdentityProvider.Retrieve(); err != nil { - c.log(hclog.Warn, "error assuming role with WebIdentityToken", "roleARN", roleARN, "sessionName", roleSessionName, "err", err) - } else { - // Add the web identity role credential provider - providers = append(providers, webIdentityProvider) - } - } + webIdentityRoleCred := config.WithWebIdentityRoleCredentialOptions(func(options *stscreds.WebIdentityRoleOptions) { + options.RoleARN = c.RoleARN + options.RoleSessionName = c.RoleSessionName + options.TokenRetriever = FetchTokenContents(c.WebIdentityToken) + }) + cfgOpts = append(cfgOpts, webIdentityRoleCred) + c.log(hclog.Debug, "added web identity provider with token", "roleARN", c.RoleARN) } else { // this session is only created to create the AssumeRoleProvider, variables used to // assume a role are pulled from values provided in options. If the option values are // not set, then the provider will default to using the environment variables. - c.log(hclog.Debug, "adding ec2-instance role provider", "roleARN", roleARN) - sess, err := session.NewSession() - if err != nil { - return nil, errors.Wrap(err, "error creating a new session for ec2 instance role credentials") - } - assumedRoleCredentials := stscreds.NewCredentials(sess, roleARN, func(p *stscreds.AssumeRoleProvider) { - p.RoleSessionName = roleSessionName - if c.RoleExternalId != "" { - p.ExternalID = aws.String(c.RoleExternalId) - } - if len(c.RoleTags) != 0 { - p.Tags = []*sts.Tag{} - for k, v := range c.RoleTags { - p.Tags = append(p.Tags, &sts.Tag{ - Key: aws.String(k), - Value: aws.String(v), - }) - } + assumeRoleCred := config.WithAssumeRoleCredentialOptions(func(options *stscreds.AssumeRoleOptions) { + options.RoleARN = c.RoleARN + options.RoleSessionName = c.RoleSessionName + options.ExternalID = aws.String(c.RoleExternalId) + for k, v := range c.RoleTags { + options.Tags = append(options.Tags, types.Tag{ + Key: aws.String(k), + Value: aws.String(v), + }) } }) - // Check if the credentials are successfully retrieved - // (via sts:AssumeRole), and warn if there's a problem. - creds, err := assumedRoleCredentials.Get() - if err != nil { - c.log(hclog.Warn, "error assuming role", "roleARN", roleARN, "sessionName", roleSessionName, "err", err) - } else { - providers = append(providers, &credentials.StaticProvider{ - Value: creds, - }) - } + cfgOpts = append(cfgOpts, assumeRoleCred) + c.log(hclog.Debug, "added ec2-instance role provider", "roleARN", c.RoleARN) } } - // Add the remote provider - def := defaults.Get() - if c.Region != "" { - def.Config.Region = aws.String(c.Region) + return cfgOpts +} + +// GenerateCredentialChain uses the config to generate a credential chain +// suitable for creating AWS clients. This will by default load configuration +// values from environment variables and append additional configuration options +// provided to the CredentialsConfig. +// +// Supported options: WithSharedCredentials, WithCredentialsProvider +func (c *CredentialsConfig) GenerateCredentialChain(ctx context.Context, opt ...Option) (*aws.Config, error) { + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("error reading options in GenerateCredentialChain: %w", err) } - // We are taking care of this in the New() function but for legacy reasons - // we also set this here - if c.HTTPClient != nil { - def.Config.HTTPClient = c.HTTPClient - _, checkFullURI := os.LookupEnv("AWS_CONTAINER_CREDENTIALS_FULL_URI") - _, checkRelativeURI := os.LookupEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") - if !checkFullURI && !checkRelativeURI { - // match the sdk defaults from https://github.com/aws/aws-sdk-go/pull/3066 - def.Config.HTTPClient.Timeout = 1 * time.Second - def.Config.MaxRetries = aws.Int(2) - } + + // Have one or the other but not both and not neither + if (c.AccessKey != "" && c.SecretKey == "") || (c.AccessKey == "" && c.SecretKey != "") { + return nil, fmt.Errorf("static AWS client credentials haven't been properly configured (the access key or secret key were provided but not both)") } - providers = append(providers, defaults.RemoteCredProvider(*def.Config, def.Handlers)) + awsConfig, err := config.LoadDefaultConfig(ctx, c.generateAwsConfigOptions(opts)...) + if err != nil { + return nil, fmt.Errorf("failed to load SDK's default configurations with given credential options") + } - // Create the credentials required to access the API. - creds := credentials.NewChainCredentials(providers) - if creds == nil { - return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, web identity or instance metadata") + if opts.withCredentialsProvider != nil { + awsConfig.Credentials = opts.withCredentialsProvider } - return creds, nil + return &awsConfig, nil } -func RetrieveCreds(accessKey, secretKey, sessionToken string, logger hclog.Logger) (*credentials.Credentials, error) { +func RetrieveCreds(ctx context.Context, accessKey, secretKey, sessionToken string, logger hclog.Logger, opt ...Option) (*aws.Config, error) { credConfig := CredentialsConfig{ AccessKey: accessKey, SecretKey: secretKey, SessionToken: sessionToken, Logger: logger, } - creds, err := credConfig.GenerateCredentialChain() + creds, err := credConfig.GenerateCredentialChain(ctx, opt...) if err != nil { return nil, err } if creds == nil { return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata") } - - _, err = creds.Get() + _, err = creds.Credentials.Retrieve(ctx) if err != nil { return nil, fmt.Errorf("failed to retrieve credentials from credential chain: %w", err) } return creds, nil } -// GenerateLoginData populates the necessary data to send to the Vault server for generating a token -// This is useful for other API clients to use -func GenerateLoginData(creds *credentials.Credentials, headerValue, configuredRegion string, logger hclog.Logger) (map[string]interface{}, error) { - loginData := make(map[string]interface{}) - - // Use the credentials we've found to construct an STS session - region, err := GetRegion(configuredRegion) - if err != nil { - logger.Warn(fmt.Sprintf("defaulting region to %q due to %s", DefaultRegion, err.Error())) - region = DefaultRegion - } - stsSession, err := session.NewSessionWithOptions(session.Options{ - Config: aws.Config{ - Credentials: creds, - Region: ®ion, - EndpointResolver: endpoints.ResolverFunc(stsSigningResolver), - }, - }) - if err != nil { - return nil, err - } - - var params *sts.GetCallerIdentityInput - svc := sts.New(stsSession) - stsRequest, _ := svc.GetCallerIdentityRequest(params) - - // Inject the required auth header value, if supplied, and then sign the request including that header - if headerValue != "" { - stsRequest.HTTPRequest.Header.Add(iamServerIdHeader, headerValue) - } - stsRequest.Sign() - - // Now extract out the relevant parts of the request - headersJson, err := json.Marshal(stsRequest.HTTPRequest.Header) - if err != nil { - return nil, err - } - requestBody, err := ioutil.ReadAll(stsRequest.HTTPRequest.Body) - if err != nil { - return nil, err - } - loginData["iam_http_request_method"] = stsRequest.HTTPRequest.Method - loginData["iam_request_url"] = base64.StdEncoding.EncodeToString([]byte(stsRequest.HTTPRequest.URL.String())) - loginData["iam_request_headers"] = base64.StdEncoding.EncodeToString(headersJson) - loginData["iam_request_body"] = base64.StdEncoding.EncodeToString(requestBody) - - return loginData, nil -} - -// STS is a really weird service that used to only have global endpoints but now has regional endpoints as well. -// For backwards compatibility, even if you request a region other than us-east-1, it'll still sign for us-east-1. -// See, e.g., https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code -// So we have to shim in this EndpointResolver to force it to sign for the right region -func stsSigningResolver(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - defaultEndpoint, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...) - if err != nil { - return defaultEndpoint, err - } - defaultEndpoint.SigningRegion = region - return defaultEndpoint, nil -} - // FetchTokenContents allows the use of the content of a token in the // WebIdentityProvider, instead of the path to a token. Useful with a // serviceaccount token requested directly from the EKS/K8s API, for example. type FetchTokenContents []byte -var _ stscreds.TokenFetcher = (*FetchTokenContents)(nil) +var _ stscreds.IdentityTokenRetriever = (*FetchTokenContents)(nil) -func (f FetchTokenContents) FetchToken(_ aws.Context) ([]byte, error) { +func (f FetchTokenContents) GetIdentityToken() ([]byte, error) { return f, nil } diff --git a/awsutil/generate_credentials_test.go b/awsutil/generate_credentials_test.go new file mode 100644 index 0000000..2486799 --- /dev/null +++ b/awsutil/generate_credentials_test.go @@ -0,0 +1,440 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package awsutil + +import ( + "bytes" + "context" + "errors" + "os" + "path" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + stsTypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCredentialsConfig(t *testing.T) { + cases := []struct { + name string + opts []Option + expectedCfg *CredentialsConfig + expectedErr string + }{ + { + name: "session name without role arn", + opts: []Option{ + WithRoleSessionName("foobar"), + }, + expectedErr: "role session name specified without role ARN", + }, + { + name: "external id without role arn", + opts: []Option{ + WithRoleExternalId("foobar"), + }, + expectedErr: "role external ID specified without role ARN", + }, + { + name: "role tags without role arn", + opts: []Option{ + WithRoleTags(map[string]string{"foo": "bar"}), + }, + expectedErr: "role tags specified without role ARN", + }, + { + name: "web identity token file without role arn", + opts: []Option{ + WithWebIdentityTokenFile("foobar"), + }, + expectedErr: "web identity token file specified without role ARN", + }, + { + name: "web identity token without role arn", + opts: []Option{ + WithWebIdentityToken("foobar"), + }, + expectedErr: "web identity token specified without role ARN", + }, + { + name: "valid config", + opts: []Option{ + WithAccessKey("foo"), + WithSecretKey("bar"), + WithRoleSessionName("baz"), + WithRoleArn("foobar"), + WithRoleExternalId("foobaz"), + WithRoleTags(map[string]string{"foo": "bar"}), + WithRegion("barbaz"), + WithWebIdentityToken("bazfoo"), + WithWebIdentityTokenFile("barfoo"), + WithMaxRetries(aws.Int(3)), + }, + expectedCfg: &CredentialsConfig{ + AccessKey: "foo", + SecretKey: "bar", + RoleSessionName: "baz", + RoleARN: "foobar", + RoleExternalId: "foobaz", + RoleTags: map[string]string{"foo": "bar"}, + Region: "barbaz", + WebIdentityToken: "bazfoo", + WebIdentityTokenFile: "barfoo", + MaxRetries: aws.Int(3), + }, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + actualCfg, err := NewCredentialsConfig(tc.opts...) + if tc.expectedErr != "" { + require.Error(err) + require.EqualError(err, tc.expectedErr) + assert.Nil(actualCfg) + return + } + require.NoError(err) + assert.NotNil(actualCfg) + assert.Equal(tc.expectedCfg.AccessKey, actualCfg.AccessKey) + assert.Equal(tc.expectedCfg.SecretKey, actualCfg.SecretKey) + assert.Equal(tc.expectedCfg.RoleSessionName, actualCfg.RoleSessionName) + assert.Equal(tc.expectedCfg.RoleExternalId, actualCfg.RoleExternalId) + assert.Equal(tc.expectedCfg.RoleTags, actualCfg.RoleTags) + assert.Equal(tc.expectedCfg.Region, actualCfg.Region) + assert.Equal(tc.expectedCfg.WebIdentityToken, actualCfg.WebIdentityToken) + assert.Equal(tc.expectedCfg.WebIdentityTokenFile, actualCfg.WebIdentityTokenFile) + assert.Equal(tc.expectedCfg.MaxRetries, actualCfg.MaxRetries) + }) + } +} + +func TestRetrieveCreds(t *testing.T) { + cases := []struct { + name string + opts []Option + expectedCfg *CredentialsConfig + expectedErr string + }{ + { + name: "success", + opts: []Option{ + WithCredentialsProvider( + NewMockCredentialsProvider( + WithCredentials(aws.Credentials{ + AccessKeyID: "foo", + SecretAccessKey: "bar", + SessionToken: "baz", + }), + ), + ), + }, + }, + { + name: "error", + opts: []Option{ + WithCredentialsProvider( + NewMockCredentialsProvider( + WithError(errors.New("invalid credentials")), + ), + ), + }, + expectedErr: "failed to retrieve credentials from credential chain: invalid credentials", + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + cfg, err := NewCredentialsConfig() + require.NoError(err) + require.NotNil(cfg) + + awscfg, err := RetrieveCreds(context.Background(), "foo", "bar", "baz", nil, tc.opts...) + if tc.expectedErr != "" { + require.Error(err) + require.EqualError(err, tc.expectedErr) + assert.Nil(awscfg) + return + } + require.NoError(err) + assert.NotNil(awscfg) + + creds, err := awscfg.Credentials.Retrieve(context.Background()) + require.NoError(err) + assert.Equal("foo", creds.AccessKeyID) + assert.Equal("bar", creds.SecretAccessKey) + assert.Equal("baz", creds.SessionToken) + }) + } +} + +func TestGenerateCredentialChain(t *testing.T) { + cases := []struct { + name string + opts []Option + expectedErr string + }{ + { + name: "static cred missing access key", + opts: []Option{ + WithSecretKey("foo"), + }, + expectedErr: "static AWS client credentials haven't been properly configured (the access key or secret key were provided but not both)", + }, + { + name: "static cred missing secret key", + opts: []Option{ + WithAccessKey("foo"), + }, + expectedErr: "static AWS client credentials haven't been properly configured (the access key or secret key were provided but not both)", + }, + { + name: "valid static cred", + opts: []Option{ + WithAccessKey("foo"), + WithSecretKey("bar"), + }, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + + cfg, err := NewCredentialsConfig(tc.opts...) + require.NoError(err) + require.NotNil(cfg) + + awscfg, err := cfg.GenerateCredentialChain(context.Background()) + if tc.expectedErr != "" { + require.Error(err) + assert.ErrorContains(err, tc.expectedErr) + assert.Nil(awscfg) + return + } + require.NoError(err) + assert.NotNil(awscfg) + }) + } +} + +func TestGenerateAwsConfigOptions(t *testing.T) { + // create web identity token file for test + dir := t.TempDir() + webIdentityTokenFilePath := path.Join(dir, "webIdentityToken") + f, err := os.Create(webIdentityTokenFilePath) + require.NoError(t, err) + _, err = f.Write([]byte("hello world")) + require.NoError(t, err) + require.NoError(t, f.Close()) + + cases := []struct { + name string + cfg *CredentialsConfig + opts options + expectedLoadOptions config.LoadOptions + expectedWebIdentityRoleOptions *stscreds.WebIdentityRoleOptions + expectedAssumeRoleOptions *stscreds.AssumeRoleOptions + expectedStaticCredentials *aws.Credentials + }{ + { + name: "region", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithRegion("us-west-2"), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-west-2", + }, + }, + { + name: "default region", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig() + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + }, + }, + { + name: "max retries", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithMaxRetries(aws.Int(5)), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + RetryMaxAttempts: 5, + }, + }, + { + name: "shared credential profile", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig() + require.NoError(t, err) + credCfg.Profile = "foobar" + credCfg.Filename = "foobaz" + return credCfg + }(), + opts: options{ + withSharedCredentials: true, + }, + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + SharedConfigProfile: "foobar", + SharedCredentialsFiles: []string{"foobaz"}, + }, + }, + { + name: "web identity token file credential", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithRoleArn("foo"), + WithWebIdentityTokenFile(webIdentityTokenFilePath), + WithRoleSessionName("bar"), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + }, + expectedWebIdentityRoleOptions: &stscreds.WebIdentityRoleOptions{ + RoleARN: "foo", + RoleSessionName: "bar", + TokenRetriever: stscreds.IdentityTokenFile(webIdentityTokenFilePath), + }, + }, + { + name: "web identity token credential", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithRoleArn("foo"), + WithWebIdentityToken("hello_world"), + WithRoleSessionName("bar"), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + }, + expectedWebIdentityRoleOptions: &stscreds.WebIdentityRoleOptions{ + RoleARN: "foo", + RoleSessionName: "bar", + TokenRetriever: FetchTokenContents("hello_world"), + }, + }, + { + name: "assume role credential", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithRoleArn("foo"), + WithRoleSessionName("bar"), + WithRoleExternalId("baz"), + WithRoleTags(map[string]string{"foo": "bar"}), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + }, + expectedAssumeRoleOptions: &stscreds.AssumeRoleOptions{ + RoleARN: "foo", + RoleSessionName: "bar", + ExternalID: aws.String("baz"), + Tags: []stsTypes.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + }, + { + name: "static credential", + cfg: func() *CredentialsConfig { + credCfg, err := NewCredentialsConfig( + WithAccessKey("foo"), + WithSecretKey("bar"), + ) + require.NoError(t, err) + return credCfg + }(), + expectedLoadOptions: config.LoadOptions{ + Region: "us-east-1", + }, + expectedStaticCredentials: &aws.Credentials{ + AccessKeyID: "foo", + SecretAccessKey: "bar", + }, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + opts := tc.cfg.generateAwsConfigOptions(tc.opts) + cfgLoadOpts := config.LoadOptions{} + for _, f := range opts { + require.NoError(f(&cfgLoadOpts)) + } + assert.NotNil(cfgLoadOpts.HTTPClient) + assert.Equal(tc.expectedLoadOptions.Region, cfgLoadOpts.Region) + assert.Equal(tc.expectedLoadOptions.RetryMaxAttempts, cfgLoadOpts.RetryMaxAttempts) + assert.Equal(tc.expectedLoadOptions.SharedConfigProfile, cfgLoadOpts.SharedConfigProfile) + assert.Equal(tc.expectedLoadOptions.SharedCredentialsFiles, cfgLoadOpts.SharedCredentialsFiles) + + if tc.expectedWebIdentityRoleOptions != nil { + actualWebIdentityToken := stscreds.WebIdentityRoleOptions{} + cfgLoadOpts.WebIdentityRoleCredentialOptions(&actualWebIdentityToken) + assert.Equal(tc.expectedWebIdentityRoleOptions.RoleARN, actualWebIdentityToken.RoleARN) + assert.Equal(tc.expectedWebIdentityRoleOptions.RoleSessionName, actualWebIdentityToken.RoleSessionName) + assert.NotNil(actualWebIdentityToken.TokenRetriever) + expectedToken, err := tc.expectedWebIdentityRoleOptions.TokenRetriever.GetIdentityToken() + require.NoError(err) + actualToken, err := actualWebIdentityToken.TokenRetriever.GetIdentityToken() + require.NoError(err) + assert.True(bytes.Equal(expectedToken, actualToken)) + } + + if tc.expectedAssumeRoleOptions != nil { + actualAssumeRoleOptions := stscreds.AssumeRoleOptions{} + cfgLoadOpts.AssumeRoleCredentialOptions(&actualAssumeRoleOptions) + assert.Equal(tc.expectedAssumeRoleOptions.RoleARN, actualAssumeRoleOptions.RoleARN) + assert.Equal(tc.expectedAssumeRoleOptions.RoleSessionName, actualAssumeRoleOptions.RoleSessionName) + assert.Equal(tc.expectedAssumeRoleOptions.ExternalID, actualAssumeRoleOptions.ExternalID) + assert.Equal(tc.expectedAssumeRoleOptions.Tags, actualAssumeRoleOptions.Tags) + } + + if tc.expectedStaticCredentials != nil { + require.NotNil(cfgLoadOpts.Credentials) + actualCreds, err := cfgLoadOpts.Credentials.Retrieve(context.Background()) + require.NoError(err) + assert.Equal(tc.expectedStaticCredentials.AccessKeyID, actualCreds.AccessKeyID) + assert.Equal(tc.expectedStaticCredentials.SecretAccessKey, actualCreds.SecretAccessKey) + } + }) + } +} diff --git a/awsutil/go.mod b/awsutil/go.mod index 0f85228..1a813a0 100644 --- a/awsutil/go.mod +++ b/awsutil/go.mod @@ -1,15 +1,36 @@ -module github.com/hashicorp/go-secure-stdlib/awsutil +module github.com/hashicorp/go-secure-stdlib/awsutil/v2 -go 1.16 +go 1.20 require ( - github.com/aws/aws-sdk-go v1.34.0 - github.com/hashicorp/errwrap v1.1.0 + github.com/aws/aws-sdk-go-v2 v1.20.1 + github.com/aws/aws-sdk-go-v2/config v1.18.33 + github.com/aws/aws-sdk-go-v2/credentials v1.13.32 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.8 + github.com/aws/aws-sdk-go-v2/service/iam v1.22.2 + github.com/aws/aws-sdk-go-v2/service/sts v1.21.2 + github.com/aws/smithy-go v1.14.1 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-hclog v1.5.0 github.com/hashicorp/go-multierror v1.1.1 - github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.38 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.32 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.39 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.32 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.13.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect github.com/kr/pretty v0.3.0 // indirect - github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.8.2 + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.1.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/awsutil/go.sum b/awsutil/go.sum index 4ae6e05..bf893b1 100644 --- a/awsutil/go.sum +++ b/awsutil/go.sum @@ -1,12 +1,37 @@ -github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= -github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go-v2 v1.20.1 h1:rZBf5DWr7YGrnlTK4kgDQGn1ltqOg5orCYb/UhOFZkg= +github.com/aws/aws-sdk-go-v2 v1.20.1/go.mod h1:NU06lETsFm8fUC6ZjhgDpVBcGZTFQ6XM+LZWZxMI4ac= +github.com/aws/aws-sdk-go-v2/config v1.18.33 h1:JKcw5SFxFW/rpM4mOPjv0VQ11E2kxW13F3exWOy7VZU= +github.com/aws/aws-sdk-go-v2/config v1.18.33/go.mod h1:hXO/l9pgY3K5oZJldamP0pbZHdPqqk+4/maa7DSD3cA= +github.com/aws/aws-sdk-go-v2/credentials v1.13.32 h1:lIH1eKPcCY1ylR4B6PkBGRWMHO3aVenOKJHWiS4/G2w= +github.com/aws/aws-sdk-go-v2/credentials v1.13.32/go.mod h1:lL8U3v/Y79YRG69WlAho0OHIKUXCyFvSXaIvfo81sls= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.8 h1:DK/9C+UN/X+1+Wm8pqaDksQr2tSLzq+8X1/rI/ZxKEQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.8/go.mod h1:ce7BgLQfYr5hQFdy67oX2svto3ufGtm6oBvmsHScI1Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.38 h1:c8ed/T9T2K5I+h/JzmF5tpI46+OODQ74dzmdo+QnaMg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.38/go.mod h1:qggunOChCMu9ZF/UkAfhTz25+U2rLVb3ya0Ua6TTfCA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.32 h1:hNeAAymUY5gu11WrrmFb3CVIp9Dar9hbo44yzzcQpzA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.32/go.mod h1:0ZXSqrty4FtQ7p8TEuRde/SZm9X05KT18LAUlR40Ln0= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.39 h1:fc0ukRAiP1syoSGZYu+DaE+FulSYhTiJ8WpVu5jElU4= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.39/go.mod h1:WLAW8PT7+JhjZfLSWe7WEJaJu0GNo0cKc2Zyo003RBs= +github.com/aws/aws-sdk-go-v2/service/iam v1.22.2 h1:DPFxx/6Zwes/MiadlDteVqDKov7yQ5v9vuwfhZuJm1s= +github.com/aws/aws-sdk-go-v2/service/iam v1.22.2/go.mod h1:cQTMNdo/Z5t1DDRsUnx0a2j6cPnytMBidUYZw2zks28= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.32 h1:dGAseBFEYxth10V23b5e2mAS+tX7oVbfYHD6dnDdAsg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.32/go.mod h1:4jwAWKEkCR0anWk5+1RbfSg1R5Gzld7NLiuaq5bTR/Y= +github.com/aws/aws-sdk-go-v2/service/sso v1.13.2 h1:A2RlEMo4SJSwbNoUUgkxTAEMduAy/8wG3eB2b2lP4gY= +github.com/aws/aws-sdk-go-v2/service/sso v1.13.2/go.mod h1:ju+nNXUunfIFamXUIZQiICjnO/TPlOmWcYhZcSy7xaE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.2 h1:OJELEgyaT2kmaBGZ+myyZbTTLobfe3ox3FSh5eYK9Qs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.15.2/go.mod h1:ubDBBaDFs1GHijSOTi8ljppML15GLG0HxhILtbjNNYQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.21.2 h1:ympg1+Lnq33XLhcK/xTG4yZHPs1Oyxu+6DEWbl7qOzA= +github.com/aws/aws-sdk-go-v2/service/sts v1.21.2/go.mod h1:FQ/DQcOfESELfJi5ED+IPPAjI5xC6nxtSolVVB773jM= +github.com/aws/smithy-go v1.14.1 h1:EFKMUmH/iHMqLiwoEDx2rRjRQpI1YCn5jTysoaDujFs= +github.com/aws/smithy-go v1.14.1/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -16,10 +41,7 @@ github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+ github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= -github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -34,38 +56,25 @@ github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/awsutil/mocks.go b/awsutil/mocks.go index e8bc87d..3ce54cd 100644 --- a/awsutil/mocks.go +++ b/awsutil/mocks.go @@ -4,13 +4,58 @@ package awsutil import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" + awserr "github.com/aws/smithy-go" +) + +var ( + _ awserr.APIError = (*MockAWSErr)(nil) + _ aws.CredentialsProvider = (*MockCredentialsProvider)(nil) + _ IAMClient = (*MockIAM)(nil) + _ STSClient = (*MockSTS)(nil) ) +// MockAWSErr is used to mock API error types for tests +type MockAWSErr struct { + Code string + Message string + Fault awserr.ErrorFault +} + +// ErrorCode returns the error code +func (e *MockAWSErr) ErrorCode() string { + return e.Code +} + +// Error returns the error message +func (e *MockAWSErr) Error() string { + return e.Message +} + +// ErrorFault returns one of the following values: +// FaultClient, FaultServer, FaultUnknown +func (e *MockAWSErr) ErrorFault() awserr.ErrorFault { + return e.Fault +} + +// ErrorMessage returns the error message +func (e *MockAWSErr) ErrorMessage() string { + return e.Message +} + +// MockAWSThrottleErr returns a mocked aws error that mimics a throttling exception. +func MockAWSThrottleErr() error { + return &MockAWSErr{ + Code: "ThrottlingException", + Message: "Throttling Exception", + Fault: awserr.FaultServer, + } +} + // MockOptionErr provides a mock option error for use with testing. func MockOptionErr(withErr error) Option { return func(_ *options) error { @@ -18,9 +63,53 @@ func MockOptionErr(withErr error) Option { } } +// MockCredentialsProvider provides a way to mock the aws.CredentialsProvider +type MockCredentialsProvider struct { + aws.CredentialsProvider + + aws.Credentials + error +} + +// MockCredentialsProviderOption is a function for setting +// the various fields on a MockCredentialsProvider object. +type MockCredentialsProviderOption func(m *MockCredentialsProvider) + +// WithCredentials sets the output for the Retrieve method. +func WithCredentials(o aws.Credentials) MockCredentialsProviderOption { + return func(m *MockCredentialsProvider) { + m.Credentials = o + } +} + +// WithCredentials sets the output for the Retrieve method. +func WithError(o error) MockCredentialsProviderOption { + return func(m *MockCredentialsProvider) { + m.error = o + } +} + +// NewMockCredentialsProvider provides a factory function to +// use with the WithCredentialsProvider option. +func NewMockCredentialsProvider(opts ...MockCredentialsProviderOption) aws.CredentialsProvider { + m := new(MockCredentialsProvider) + for _, opt := range opts { + opt(m) + } + return m +} + +func (m *MockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + if m.error != nil { + return aws.Credentials{}, m.error + } + + return m.Credentials, nil +} + // MockIAM provides a way to mock the AWS IAM API. type MockIAM struct { - iamiface.IAMAPI + IAMClient CreateAccessKeyOutput *iam.CreateAccessKeyOutput CreateAccessKeyError error @@ -96,7 +185,7 @@ func WithGetUserError(e error) MockIAMOption { // NewMockIAM provides a factory function to use with the WithIAMAPIFunc // option. func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { - return func(_ *session.Session) (iamiface.IAMAPI, error) { + return func(_ *aws.Config) (IAMClient, error) { m := new(MockIAM) for _, opt := range opts { if err := opt(m); err != nil { @@ -108,7 +197,7 @@ func NewMockIAM(opts ...MockIAMOption) IAMAPIFunc { } } -func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) { +func (m *MockIAM) CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) { if m.CreateAccessKeyError != nil { return nil, m.CreateAccessKeyError } @@ -116,11 +205,11 @@ func (m *MockIAM) CreateAccessKey(*iam.CreateAccessKeyInput) (*iam.CreateAccessK return m.CreateAccessKeyOutput, nil } -func (m *MockIAM) DeleteAccessKey(*iam.DeleteAccessKeyInput) (*iam.DeleteAccessKeyOutput, error) { +func (m *MockIAM) DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) { return &iam.DeleteAccessKeyOutput{}, m.DeleteAccessKeyError } -func (m *MockIAM) ListAccessKeys(*iam.ListAccessKeysInput) (*iam.ListAccessKeysOutput, error) { +func (m *MockIAM) ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error) { if m.ListAccessKeysError != nil { return nil, m.ListAccessKeysError } @@ -128,7 +217,7 @@ func (m *MockIAM) ListAccessKeys(*iam.ListAccessKeysInput) (*iam.ListAccessKeysO return m.ListAccessKeysOutput, nil } -func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { +func (m *MockIAM) GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error) { if m.GetUserError != nil { return nil, m.GetUserError } @@ -138,16 +227,35 @@ func (m *MockIAM) GetUser(*iam.GetUserInput) (*iam.GetUserOutput, error) { // MockSTS provides a way to mock the AWS STS API. type MockSTS struct { - stsiface.STSAPI + STSClient GetCallerIdentityOutput *sts.GetCallerIdentityOutput GetCallerIdentityError error + + AssumeRoleOutput *sts.AssumeRoleOutput + AssumeRoleError error } // MockSTSOption is a function for setting the various fields on a MockSTS // object. type MockSTSOption func(m *MockSTS) error +// WithAssumeRoleOutput sets the output for the AssumeRole method. +func WithAssumeRoleOutput(o *sts.AssumeRoleOutput) MockSTSOption { + return func(m *MockSTS) error { + m.AssumeRoleOutput = o + return nil + } +} + +// WithAssumeRoleError sets the error output for the AssumeRole method. +func WithAssumeRoleError(e error) MockSTSOption { + return func(m *MockSTS) error { + m.AssumeRoleError = e + return nil + } +} + // WithGetCallerIdentityOutput sets the output for the GetCallerIdentity // method. func WithGetCallerIdentityOutput(o *sts.GetCallerIdentityOutput) MockSTSOption { @@ -172,7 +280,7 @@ func WithGetCallerIdentityError(e error) MockSTSOption { // If withGetCallerIdentityError is supplied, calls to GetCallerIdentity will // return the supplied error. Otherwise, a basic mock API output is returned. func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { - return func(_ *session.Session) (stsiface.STSAPI, error) { + return func(_ *aws.Config) (STSClient, error) { m := new(MockSTS) for _, opt := range opts { if err := opt(m); err != nil { @@ -184,10 +292,18 @@ func NewMockSTS(opts ...MockSTSOption) STSAPIFunc { } } -func (m *MockSTS) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { +func (m *MockSTS) GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { if m.GetCallerIdentityError != nil { return nil, m.GetCallerIdentityError } return m.GetCallerIdentityOutput, nil } + +func (m *MockSTS) AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + if m.AssumeRoleError != nil { + return nil, m.AssumeRoleError + } + + return m.AssumeRoleOutput, nil +} diff --git a/awsutil/mocks_test.go b/awsutil/mocks_test.go index 0cc8be7..8ea1da5 100644 --- a/awsutil/mocks_test.go +++ b/awsutil/mocks_test.go @@ -4,16 +4,77 @@ package awsutil import ( + "context" "errors" "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/sts" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + stsTypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMockCredentialsProvider(t *testing.T) { + cases := []struct { + name string + opts []MockCredentialsProviderOption + expectedCredentials aws.Credentials + expectedError string + }{ + { + name: "with credentials", + opts: []MockCredentialsProviderOption{ + WithCredentials(aws.Credentials{ + AccessKeyID: "foobar", + SecretAccessKey: "barbaz", + }), + }, + expectedCredentials: aws.Credentials{ + AccessKeyID: "foobar", + SecretAccessKey: "barbaz", + }, + }, + { + name: "with error", + opts: []MockCredentialsProviderOption{ + WithError(errors.New("credential provider error test")), + }, + expectedError: "credential provider error test", + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + actualCredentialsProvider := NewMockCredentialsProvider(tc.opts...) + require.NotNil(actualCredentialsProvider) + + actualCredentials, err := actualCredentialsProvider.Retrieve(context.TODO()) + if tc.expectedError != "" { + assert.NotNil(actualCredentials) + assert.Empty(actualCredentials.AccessKeyID) + assert.Empty(actualCredentials.SecretAccessKey) + assert.Empty(actualCredentials.SessionToken) + assert.Empty(actualCredentials.Source) + assert.Empty(actualCredentials.Expires) + assert.Empty(actualCredentials.CanExpire) + assert.Error(err) + assert.EqualError(err, tc.expectedError) + return + } + assert.NoError(err) + assert.NotNil(actualCredentials) + assert.Equal(tc.expectedCredentials, actualCredentials) + }) + } +} + func TestMockIAM(t *testing.T) { cases := []struct { name string @@ -30,14 +91,14 @@ func TestMockIAM(t *testing.T) { name: "CreateAccessKeyOutput", opts: []MockIAMOption{WithCreateAccessKeyOutput( &iam.CreateAccessKeyOutput{ - AccessKey: &iam.AccessKey{ + AccessKey: &iamTypes.AccessKey{ AccessKeyId: aws.String("foobar"), SecretAccessKey: aws.String("bazqux"), }, }, )}, expectedCreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ - AccessKey: &iam.AccessKey{ + AccessKey: &iamTypes.AccessKey{ AccessKeyId: aws.String("foobar"), SecretAccessKey: aws.String("bazqux"), }, @@ -52,20 +113,20 @@ func TestMockIAM(t *testing.T) { name: "ListAccessKeysOutput", opts: []MockIAMOption{WithListAccessKeysOutput( &iam.ListAccessKeysOutput{ - AccessKeyMetadata: []*iam.AccessKeyMetadata{ + AccessKeyMetadata: []iamTypes.AccessKeyMetadata{ { AccessKeyId: aws.String("foobar"), - Status: aws.String("bazqux"), + Status: iamTypes.StatusTypeActive, UserName: aws.String("janedoe"), }, }, }, )}, expectedListAccessKeysOutput: &iam.ListAccessKeysOutput{ - AccessKeyMetadata: []*iam.AccessKeyMetadata{ + AccessKeyMetadata: []iamTypes.AccessKeyMetadata{ { AccessKeyId: aws.String("foobar"), - Status: aws.String("bazqux"), + Status: iamTypes.StatusTypeActive, UserName: aws.String("janedoe"), }, }, @@ -85,7 +146,7 @@ func TestMockIAM(t *testing.T) { name: "GetUserOutput", opts: []MockIAMOption{WithGetUserOutput( &iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), UserName: aws.String("JohnDoe"), @@ -93,7 +154,7 @@ func TestMockIAM(t *testing.T) { }, )}, expectedGetUserOutput: &iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ Arn: aws.String("arn:aws:iam::123456789012:user/JohnDoe"), UserId: aws.String("AIDAJQABLZS4A3QDU576Q"), UserName: aws.String("JohnDoe"), @@ -116,9 +177,9 @@ func TestMockIAM(t *testing.T) { f := NewMockIAM(tc.opts...) m, err := f(nil) require.NoError(err) // Nothing returns an error right now - actualCreateAccessKeyOutput, actualCreateAccessKeyError := m.CreateAccessKey(nil) - _, actualDeleteAccessKeyError := m.DeleteAccessKey(nil) - actualGetUserOutput, actualGetUserError := m.GetUser(nil) + actualCreateAccessKeyOutput, actualCreateAccessKeyError := m.CreateAccessKey(context.TODO(), nil) + _, actualDeleteAccessKeyError := m.DeleteAccessKey(context.TODO(), nil) + actualGetUserOutput, actualGetUserError := m.GetUser(context.TODO(), nil) assert.Equal(tc.expectedCreateAccessKeyOutput, actualCreateAccessKeyOutput) assert.Equal(tc.expectedCreateAccessKeyError, actualCreateAccessKeyError) assert.Equal(tc.expectedDeleteAccessKeyError, actualDeleteAccessKeyError) @@ -134,6 +195,8 @@ func TestMockSTS(t *testing.T) { opts []MockSTSOption expectedGetCallerIdentityOutput *sts.GetCallerIdentityOutput expectedGetCallerIdentityError error + expectedAssumeRoleOutput *sts.AssumeRoleOutput + expectedAssumeRoleError error }{ { name: "GetCallerIdentityOutput", @@ -155,6 +218,42 @@ func TestMockSTS(t *testing.T) { opts: []MockSTSOption{WithGetCallerIdentityError(errors.New("testerr"))}, expectedGetCallerIdentityError: errors.New("testerr"), }, + { + name: "AssumeRoleOutput", + opts: []MockSTSOption{WithAssumeRoleOutput( + &sts.AssumeRoleOutput{ + AssumedRoleUser: &stsTypes.AssumedRoleUser{ + Arn: aws.String("arn:aws:sts::123456789012:assumed-role/example"), + AssumedRoleId: aws.String("example"), + }, + Credentials: &stsTypes.Credentials{ + AccessKeyId: aws.String("foobar"), + Expiration: &time.Time{}, + SecretAccessKey: aws.String("bazqux"), + SessionToken: aws.String("bizbuz"), + }, + PackedPolicySize: aws.Int32(0), + }, + )}, + expectedAssumeRoleOutput: &sts.AssumeRoleOutput{ + AssumedRoleUser: &stsTypes.AssumedRoleUser{ + Arn: aws.String("arn:aws:sts::123456789012:assumed-role/example"), + AssumedRoleId: aws.String("example"), + }, + Credentials: &stsTypes.Credentials{ + AccessKeyId: aws.String("foobar"), + Expiration: &time.Time{}, + SecretAccessKey: aws.String("bazqux"), + SessionToken: aws.String("bizbuz"), + }, + PackedPolicySize: aws.Int32(0), + }, + }, + { + name: "AssumeRoleError", + opts: []MockSTSOption{WithAssumeRoleError(errors.New("testerr"))}, + expectedAssumeRoleError: errors.New("testerr"), + }, } for _, tc := range cases { @@ -166,9 +265,12 @@ func TestMockSTS(t *testing.T) { f := NewMockSTS(tc.opts...) m, err := f(nil) require.NoError(err) // Nothing returns an error right now - actualGetCallerIdentityOutput, actualGetCallerIdentityError := m.GetCallerIdentity(nil) + actualGetCallerIdentityOutput, actualGetCallerIdentityError := m.GetCallerIdentity(context.TODO(), nil) assert.Equal(tc.expectedGetCallerIdentityOutput, actualGetCallerIdentityOutput) assert.Equal(tc.expectedGetCallerIdentityError, actualGetCallerIdentityError) + actualAssumeRoleOutput, actualAssumeRoleError := m.AssumeRole(context.TODO(), nil) + assert.Equal(tc.expectedAssumeRoleOutput, actualAssumeRoleOutput) + assert.Equal(tc.expectedAssumeRoleError, actualAssumeRoleError) }) } } diff --git a/awsutil/options.go b/awsutil/options.go index d3c7fef..b2795fb 100644 --- a/awsutil/options.go +++ b/awsutil/options.go @@ -4,11 +4,12 @@ package awsutil import ( - "fmt" "net/http" "time" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/hashicorp/go-hclog" ) @@ -31,36 +32,32 @@ type Option func(*options) error // options = how options are represented type options struct { - withEnvironmentCredentials bool - withSharedCredentials bool - withAwsSession *session.Session - withClientType string - withUsername string - withAccessKey string - withSecretKey string - withLogger hclog.Logger - withStsEndpoint string - withIamEndpoint string - withMaxRetries *int - withRegion string - withRoleArn string - withRoleSessionName string - withRoleExternalId string - withRoleTags map[string]string - withWebIdentityTokenFile string - withWebIdentityToken string - withSkipWebIdentityValidity bool - withHttpClient *http.Client - withValidityCheckTimeout time.Duration - withIAMAPIFunc IAMAPIFunc - withSTSAPIFunc STSAPIFunc + withSharedCredentials bool + withAwsConfig *aws.Config + withUsername string + withAccessKey string + withSecretKey string + withLogger hclog.Logger + withStsEndpointResolver sts.EndpointResolverV2 + withIamEndpointResolver iam.EndpointResolverV2 + withMaxRetries *int + withRegion string + withRoleArn string + withRoleSessionName string + withRoleExternalId string + withRoleTags map[string]string + withWebIdentityTokenFile string + withWebIdentityToken string + withHttpClient *http.Client + withValidityCheckTimeout time.Duration + withIAMAPIFunc IAMAPIFunc + withSTSAPIFunc STSAPIFunc + withCredentialsProvider aws.CredentialsProvider } func getDefaultOptions() options { return options{ - withEnvironmentCredentials: true, - withSharedCredentials: true, - withClientType: "iam", + withSharedCredentials: true, } } @@ -124,24 +121,6 @@ func WithWebIdentityToken(with string) Option { } } -// WithSkipWebIdentityValidity allows controlling whether the validity check is -// skipped for the web identity provider -func WithSkipWebIdentityValidity(with bool) Option { - return func(o *options) error { - o.withSkipWebIdentityValidity = with - return nil - } -} - -// WithEnvironmentCredentials allows controlling whether environment credentials -// are used -func WithEnvironmentCredentials(with bool) Option { - return func(o *options) error { - o.withEnvironmentCredentials = with - return nil - } -} - // WithSharedCredentials allows controlling whether shared credentials are used func WithSharedCredentials(with bool) Option { return func(o *options) error { @@ -150,23 +129,10 @@ func WithSharedCredentials(with bool) Option { } } -// WithAwsSession allows controlling the session passed into the client -func WithAwsSession(with *session.Session) Option { +// WithAwsConfig allows controlling the configuration passed into the client +func WithAwsConfig(with *aws.Config) Option { return func(o *options) error { - o.withAwsSession = with - return nil - } -} - -// WithClientType allows choosing the client type to use -func WithClientType(with string) Option { - return func(o *options) error { - switch with { - case "iam", "sts": - default: - return fmt.Errorf("unsupported client type %q", with) - } - o.withClientType = with + o.withAwsConfig = with return nil } } @@ -195,18 +161,18 @@ func WithSecretKey(with string) Option { } } -// WithStsEndpoint allows passing a custom STS endpoint -func WithStsEndpoint(with string) Option { +// WithStsEndpointResolver allows passing a custom STS endpoint resolver +func WithStsEndpointResolver(with sts.EndpointResolverV2) Option { return func(o *options) error { - o.withStsEndpoint = with + o.withStsEndpointResolver = with return nil } } -// WithIamEndpoint allows passing a custom IAM endpoint -func WithIamEndpoint(with string) Option { +// WithIamEndppointResolver allows passing a custom IAM endpoint resolver +func WithIamEndpointResolver(with iam.EndpointResolverV2) Option { return func(o *options) error { - o.withIamEndpoint = with + o.withIamEndpointResolver = with return nil } } @@ -269,3 +235,12 @@ func WithSTSAPIFunc(with STSAPIFunc) Option { return nil } } + +// WithCredentialsProvider allows passing in a CredentialsProvider interface +// constructor for mocking the AWS Credential Provider. +func WithCredentialsProvider(with aws.CredentialsProvider) Option { + return func(o *options) error { + o.withCredentialsProvider = with + return nil + } +} diff --git a/awsutil/options_test.go b/awsutil/options_test.go index 98060df..4d08c66 100644 --- a/awsutil/options_test.go +++ b/awsutil/options_test.go @@ -8,8 +8,9 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,17 +20,8 @@ func Test_GetOpts(t *testing.T) { t.Parallel() t.Run("default", func(t *testing.T) { testOpts := getDefaultOptions() - assert.Equal(t, true, testOpts.withEnvironmentCredentials) assert.Equal(t, true, testOpts.withSharedCredentials) - assert.Nil(t, testOpts.withAwsSession) - assert.Equal(t, "iam", testOpts.withClientType) - }) - t.Run("withEnvironmentCredentials", func(t *testing.T) { - opts, err := getOpts(WithEnvironmentCredentials(false)) - require.NoError(t, err) - testOpts := getDefaultOptions() - testOpts.withEnvironmentCredentials = false - assert.Equal(t, opts, testOpts) + assert.Nil(t, testOpts.withAwsConfig) }) t.Run("withSharedCredentials", func(t *testing.T) { opts, err := getOpts(WithSharedCredentials(false)) @@ -38,12 +30,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withSharedCredentials = false assert.Equal(t, opts, testOpts) }) - t.Run("withAwsSession", func(t *testing.T) { - sess := new(session.Session) - opts, err := getOpts(WithAwsSession(sess)) + t.Run("withAwsConfig", func(t *testing.T) { + cfg := new(aws.Config) + opts, err := getOpts(WithAwsConfig(cfg)) require.NoError(t, err) testOpts := getDefaultOptions() - testOpts.withAwsSession = sess + testOpts.withAwsConfig = cfg assert.Equal(t, opts, testOpts) }) t.Run("withUsername", func(t *testing.T) { @@ -53,15 +45,6 @@ func Test_GetOpts(t *testing.T) { testOpts.withUsername = "foobar" assert.Equal(t, opts, testOpts) }) - t.Run("withClientType", func(t *testing.T) { - _, err := getOpts(WithClientType("foobar")) - require.Error(t, err) - opts, err := getOpts(WithClientType("sts")) - require.NoError(t, err) - testOpts := getDefaultOptions() - testOpts.withClientType = "sts" - assert.Equal(t, opts, testOpts) - }) t.Run("withAccessKey", func(t *testing.T) { opts, err := getOpts(WithAccessKey("foobar")) require.NoError(t, err) @@ -77,17 +60,19 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("withStsEndpoint", func(t *testing.T) { - opts, err := getOpts(WithStsEndpoint("foobar")) + resolver := sts.NewDefaultEndpointResolverV2() + opts, err := getOpts(WithStsEndpointResolver(resolver)) require.NoError(t, err) testOpts := getDefaultOptions() - testOpts.withStsEndpoint = "foobar" + testOpts.withStsEndpointResolver = resolver assert.Equal(t, opts, testOpts) }) t.Run("withIamEndpoint", func(t *testing.T) { - opts, err := getOpts(WithIamEndpoint("foobar")) + resolver := iam.NewDefaultEndpointResolverV2() + opts, err := getOpts(WithIamEndpointResolver(resolver)) require.NoError(t, err) testOpts := getDefaultOptions() - testOpts.withIamEndpoint = "foobar" + testOpts.withIamEndpointResolver = resolver assert.Equal(t, opts, testOpts) }) t.Run("withLogger", func(t *testing.T) { @@ -177,11 +162,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withWebIdentityToken = "foo" assert.Equal(t, opts, testOpts) }) - t.Run("WithSkipWebIdentityValidity", func(t *testing.T) { - opts, err := getOpts(WithSkipWebIdentityValidity(true)) + t.Run("WithCredentialsProvider", func(t *testing.T) { + credProvider := &MockCredentialsProvider{} + opts, err := getOpts(WithCredentialsProvider(credProvider)) require.NoError(t, err) testOpts := getDefaultOptions() - testOpts.withSkipWebIdentityValidity = true + testOpts.withCredentialsProvider = credProvider assert.Equal(t, opts, testOpts) }) } diff --git a/awsutil/region.go b/awsutil/region.go index d6145b5..a89d3fc 100644 --- a/awsutil/region.go +++ b/awsutil/region.go @@ -4,13 +4,11 @@ package awsutil import ( - "net/http" - "time" + "context" + "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/hashicorp/errwrap" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" ) // "us-east-1 is used because it's where AWS first provides support for new features, @@ -40,38 +38,27 @@ Our chosen approach is: This approach should be used in future updates to this logic. */ -func GetRegion(configuredRegion string) (string, error) { +func GetRegion(ctx context.Context, configuredRegion string) (string, error) { if configuredRegion != "" { return configuredRegion, nil } - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - }) + cfg, err := config.LoadDefaultConfig(ctx) if err != nil { - return "", errwrap.Wrapf("got error when starting session: {{err}}", err) + return "", fmt.Errorf("got error when loading default configuration: %w", err) } - - region := aws.StringValue(sess.Config.Region) - if region != "" { - return region, nil - } - - metadata := ec2metadata.New(sess, &aws.Config{ - Endpoint: ec2Endpoint, - EC2MetadataDisableTimeoutOverride: aws.Bool(true), - HTTPClient: &http.Client{ - Timeout: time.Second, - }, - }) - if !metadata.Available() { - return DefaultRegion, nil + if cfg.Region != "" { + return cfg.Region, nil } - region, err = metadata.Region() + client := imds.NewFromConfig(cfg) + resp, err := client.GetRegion(ctx, &imds.GetRegionInput{}) if err != nil { - return "", errwrap.Wrapf("unable to retrieve region from instance metadata: {{err}}", err) + return "", fmt.Errorf("unable to retrieve region from instance metadata: %w", err) + } + if resp.Region != "" { + return resp.Region, nil } - return region, nil + return DefaultRegion, nil } diff --git a/awsutil/region_test.go b/awsutil/region_test.go index ea76f2f..83e92f5 100644 --- a/awsutil/region_test.go +++ b/awsutil/region_test.go @@ -4,15 +4,15 @@ package awsutil import ( + "context" "fmt" - "io/ioutil" "net/http" "net/http/httptest" "os" "os/user" "testing" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/aws" ) const testConfigFile = `[default] @@ -39,7 +39,7 @@ func TestGetRegion_UserConfigPreferredFirst(t *testing.T) { cleanupMetadata := setInstanceMetadata(t, unexpectedTestRegion) defer cleanupMetadata() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -60,7 +60,7 @@ func TestGetRegion_EnvVarsPreferredSecond(t *testing.T) { cleanupMetadata := setInstanceMetadata(t, unexpectedTestRegion) defer cleanupMetadata() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -87,7 +87,7 @@ func TestGetRegion_ConfigFilesPreferredThird(t *testing.T) { cleanupMetadata := setInstanceMetadata(t, unexpectedTestRegion) defer cleanupMetadata() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestGetRegion_ConfigFileUnfound(t *testing.T) { } }() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -141,7 +141,7 @@ func TestGetRegion_EC2InstanceMetadataPreferredFourth(t *testing.T) { cleanupMetadata := setInstanceMetadata(t, expectedTestRegion) defer cleanupMetadata() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -163,7 +163,7 @@ func TestGetRegion_DefaultsToDefaultRegionWhenRegionUnavailable(t *testing.T) { cleanupFile := setConfigFileRegion(t, "") defer cleanupFile() - result, err := GetRegion(configuredRegion) + result, err := GetRegion(context.Background(), configuredRegion) if err != nil { t.Fatal(err) } @@ -209,7 +209,7 @@ func setConfigFileRegion(t *testing.T, region string) (cleanup func()) { pathToAWSDir := usr.HomeDir + "/.aws" pathToConfig := pathToAWSDir + "/config" - preExistingConfig, err := ioutil.ReadFile(pathToConfig) + preExistingConfig, err := os.ReadFile(pathToConfig) if err != nil { // File simply doesn't exist. if err := os.Mkdir(pathToAWSDir, os.ModeDir); err != nil { @@ -222,13 +222,13 @@ func setConfigFileRegion(t *testing.T, region string) (cleanup func()) { }) } else { cleanupFuncs = append(cleanupFuncs, func() { - if err := ioutil.WriteFile(pathToConfig, preExistingConfig, 0o644); err != nil { + if err := os.WriteFile(pathToConfig, preExistingConfig, 0o644); err != nil { t.Fatal(err) } }) } fileBody := fmt.Sprintf(testConfigFile, region) - if err := ioutil.WriteFile(pathToConfig, []byte(fileBody), 0o644); err != nil { + if err := os.WriteFile(pathToConfig, []byte(fileBody), 0o644); err != nil { t.Fatal(err) } diff --git a/awsutil/rotate.go b/awsutil/rotate.go index c58248a..302c89a 100644 --- a/awsutil/rotate.go +++ b/awsutil/rotate.go @@ -9,10 +9,9 @@ import ( "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" ) // RotateKeys takes the access key and secret key from this credentials config @@ -25,14 +24,14 @@ import ( // try to delete the new one to clean up, although it's unlikely that will work // if the old one could not be deleted. // -// Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// Supported options: WithSharedCredentials, WithAwsConfig +// WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, // WithSTSAPIFunc // // Note that WithValidityCheckTimeout here, when non-zero, controls the // WithValidityCheckTimeout option on access key creation. See CreateAccessKey // for more details. -func (c *CredentialsConfig) RotateKeys(opt ...Option) error { +func (c *CredentialsConfig) RotateKeys(ctx context.Context, opt ...Option) error { if c.AccessKey == "" || c.SecretKey == "" { return errors.New("cannot rotate credentials when either access_key or secret_key is empty") } @@ -42,21 +41,21 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error { return fmt.Errorf("error reading options in RotateKeys: %w", err) } - sess := opts.withAwsSession - if sess == nil { - sess, err = c.GetSession(opt...) + cfg := opts.withAwsConfig + if cfg == nil { + cfg, err = c.GenerateCredentialChain(ctx, opt...) if err != nil { - return fmt.Errorf("error calling GetSession: %w", err) + return fmt.Errorf("error calling GenerateCredentialChain: %w", err) } } - sessOpt := append(opt, WithAwsSession(sess)) - createAccessKeyRes, err := c.CreateAccessKey(sessOpt...) + opt = append(opt, WithAwsConfig(cfg)) + createAccessKeyRes, err := c.CreateAccessKey(ctx, opt...) if err != nil { return fmt.Errorf("error calling CreateAccessKey: %w", err) } - err = c.DeleteAccessKey(c.AccessKey, append(sessOpt, WithUsername(*createAccessKeyRes.AccessKey.UserName))...) + err = c.DeleteAccessKey(ctx, c.AccessKey, append(opt, WithUsername(*createAccessKeyRes.AccessKey.UserName))...) if err != nil { return fmt.Errorf("error deleting old access key: %w", err) } @@ -69,68 +68,73 @@ func (c *CredentialsConfig) RotateKeys(opt ...Option) error { // CreateAccessKey creates a new access/secret key pair. // -// Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, +// Supported options: WithSharedCredentials, WithAwsConfig, +// WithUsername, WithValidityCheckTimeout, WithIAMAPIFunc, // WithSTSAPIFunc // // When WithValidityCheckTimeout is non-zero, it specifies a timeout to wait on // the created credentials to be valid and ready for use. -func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKeyOutput, error) { +func (c *CredentialsConfig) CreateAccessKey(ctx context.Context, opt ...Option) (*iam.CreateAccessKeyOutput, error) { opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("error reading options in CreateAccessKey: %w", err) } - client, err := c.IAMClient(opt...) + client, err := c.IAMClient(ctx, opt...) if err != nil { return nil, fmt.Errorf("error loading IAM client: %w", err) } var getUserInput iam.GetUserInput if opts.withUsername != "" { - getUserInput.SetUserName(opts.withUsername) + getUserInput.UserName = aws.String(opts.withUsername) } // otherwise, empty input means get current user - getUserRes, err := client.GetUser(&getUserInput) + getUserRes, err := client.GetUser(ctx, &getUserInput) if err != nil { - return nil, fmt.Errorf("error calling aws.GetUser: %w", err) + return nil, fmt.Errorf("error calling iam.GetUser: %w", err) } if getUserRes == nil { - return nil, fmt.Errorf("nil response from aws.GetUser") + return nil, fmt.Errorf("nil response from iam.GetUser") } if getUserRes.User == nil { - return nil, fmt.Errorf("nil user returned from aws.GetUser") + return nil, fmt.Errorf("nil user returned from iam.GetUser") } if getUserRes.User.UserName == nil { - return nil, fmt.Errorf("nil UserName returned from aws.GetUser") + return nil, fmt.Errorf("nil UserName returned from iam.GetUser") } createAccessKeyInput := iam.CreateAccessKeyInput{ UserName: getUserRes.User.UserName, } - createAccessKeyRes, err := client.CreateAccessKey(&createAccessKeyInput) + createAccessKeyRes, err := client.CreateAccessKey(ctx, &createAccessKeyInput) if err != nil { - return nil, fmt.Errorf("error calling aws.CreateAccessKey: %w", err) + return nil, fmt.Errorf("error calling iam.CreateAccessKey: %w", err) } if createAccessKeyRes == nil { - return nil, fmt.Errorf("nil response from aws.CreateAccessKey") + return nil, fmt.Errorf("nil response from iam.CreateAccessKey") } if createAccessKeyRes.AccessKey == nil { - return nil, fmt.Errorf("nil access key in response from aws.CreateAccessKey") + return nil, fmt.Errorf("nil access key in response from iam.CreateAccessKey") } if createAccessKeyRes.AccessKey.AccessKeyId == nil || createAccessKeyRes.AccessKey.SecretAccessKey == nil { - return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from aws.CreateAccessKey") + return nil, fmt.Errorf("nil AccessKeyId or SecretAccessKey returned from iam.CreateAccessKey") } // Check the credentials to make sure they are usable. We only do // this if withValidityCheckTimeout is non-zero to ensue that we don't // immediately fail due to eventual consistency. if opts.withValidityCheckTimeout != 0 { - newC := &CredentialsConfig{ - AccessKey: *createAccessKeyRes.AccessKey.AccessKeyId, - SecretKey: *createAccessKeyRes.AccessKey.SecretAccessKey, + newStaticCreds, err := NewCredentialsConfig( + WithAccessKey(*createAccessKeyRes.AccessKey.AccessKeyId), + WithSecretKey(*createAccessKeyRes.AccessKey.SecretAccessKey), + WithRegion(c.Region), + ) + if err != nil { + return nil, fmt.Errorf("failed to create credential config with new static credential: %w", err) } - if _, err := newC.GetCallerIdentity( + if _, err := newStaticCreds.GetCallerIdentity( + ctx, WithValidityCheckTimeout(opts.withValidityCheckTimeout), WithSTSAPIFunc(opts.withSTSAPIFunc), ); err != nil { @@ -143,15 +147,14 @@ func (c *CredentialsConfig) CreateAccessKey(opt ...Option) (*iam.CreateAccessKey // DeleteAccessKey deletes an access key. // -// Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithUserName, WithIAMAPIFunc -func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) error { +// Supported options: WithSharedCredentials, WithAwsConfig, WithUserName, WithIAMAPIFunc +func (c *CredentialsConfig) DeleteAccessKey(ctx context.Context, accessKeyId string, opt ...Option) error { opts, err := getOpts(opt...) if err != nil { return fmt.Errorf("error reading options in RotateKeys: %w", err) } - client, err := c.IAMClient(opt...) + client, err := c.IAMClient(ctx, opt...) if err != nil { return fmt.Errorf("error loading IAM client: %w", err) } @@ -160,10 +163,10 @@ func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) e AccessKeyId: aws.String(accessKeyId), } if opts.withUsername != "" { - deleteAccessKeyInput.SetUserName(opts.withUsername) + deleteAccessKeyInput.UserName = aws.String(opts.withUsername) } - _, err = client.DeleteAccessKey(&deleteAccessKeyInput) + _, err = client.DeleteAccessKey(ctx, &deleteAccessKeyInput) if err != nil { return fmt.Errorf("error deleting old access key: %w", err) } @@ -171,77 +174,31 @@ func (c *CredentialsConfig) DeleteAccessKey(accessKeyId string, opt ...Option) e return nil } -// GetSession returns an AWS session configured according to the various values -// in the CredentialsConfig object. This can be passed into iam.New or sts.New -// as appropriate. -// -// Supported options: WithEnvironmentCredentials, WithSharedCredentials, -// WithAwsSession, WithClientType -func (c *CredentialsConfig) GetSession(opt ...Option) (*session.Session, error) { - opts, err := getOpts(opt...) - if err != nil { - return nil, fmt.Errorf("error reading options in GetSession: %w", err) - } - - creds, err := c.GenerateCredentialChain(opt...) - if err != nil { - return nil, err - } - - var endpoint string - switch opts.withClientType { - case "sts": - endpoint = c.STSEndpoint - case "iam": - endpoint = c.IAMEndpoint - default: - return nil, fmt.Errorf("unknown client type %q in GetSession", opts.withClientType) - } - - awsConfig := &aws.Config{ - Credentials: creds, - Region: aws.String(c.Region), - Endpoint: aws.String(endpoint), - HTTPClient: c.HTTPClient, - MaxRetries: c.MaxRetries, - } - - sess, err := session.NewSession(awsConfig) - if err != nil { - return nil, fmt.Errorf("error getting new session: %w", err) - } - - return sess, nil -} - // GetCallerIdentity runs sts.GetCallerIdentity for the current set // credentials. This can be used to check that credentials are valid, // in addition to checking details about the effective logged in // account and user ID. // -// Supported options: WithEnvironmentCredentials, -// WithSharedCredentials, WithAwsSession, WithValidityCheckTimeout -func (c *CredentialsConfig) GetCallerIdentity(opt ...Option) (*sts.GetCallerIdentityOutput, error) { +// Supported options: WithSharedCredentials, WithAwsConfig, WithValidityCheckTimeout +func (c *CredentialsConfig) GetCallerIdentity(ctx context.Context, opt ...Option) (*sts.GetCallerIdentityOutput, error) { opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("error reading options in GetCallerIdentity: %w", err) } - client, err := c.STSClient(opt...) + client, err := c.STSClient(ctx, opt...) if err != nil { return nil, fmt.Errorf("error loading STS client: %w", err) } delay := time.Second - timeoutCtx, cancel := context.WithTimeout(context.Background(), opts.withValidityCheckTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, opts.withValidityCheckTimeout) defer cancel() for { - cid, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + cid, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err == nil { return cid, nil } - - // TODO: can add a context here for external cancellation in the future select { case <-time.After(delay): // pass diff --git a/awsutil/rotate_test.go b/awsutil/rotate_test.go index 2de873a..e9a5370 100644 --- a/awsutil/rotate_test.go +++ b/awsutil/rotate_test.go @@ -4,16 +4,19 @@ package awsutil import ( + "context" "errors" + "fmt" "os" "strings" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + awserr "github.com/aws/smithy-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -40,14 +43,14 @@ func TestRotation(t *testing.T) { } // Create an initial key - out, err := credsConfig.CreateAccessKey(WithUsername(username), WithValidityCheckTimeout(testRotationWaitTimeout)) + out, err := credsConfig.CreateAccessKey(context.Background(), WithUsername(username), WithValidityCheckTimeout(testRotationWaitTimeout)) require.NoError(err) require.NotNil(out) cleanupKey := out.AccessKey.AccessKeyId defer func() { - assert.NoError(credsConfig.DeleteAccessKey(*cleanupKey, WithUsername(username))) + assert.NoError(credsConfig.DeleteAccessKey(context.Background(), *cleanupKey, WithUsername(username))) }() // Run rotation @@ -57,7 +60,7 @@ func TestRotation(t *testing.T) { WithSecretKey(secretKey), ) require.NoError(err) - require.NoError(c.RotateKeys(WithValidityCheckTimeout(testRotationWaitTimeout))) + require.NoError(c.RotateKeys(context.Background(), WithValidityCheckTimeout(testRotationWaitTimeout))) assert.NotEqual(accessKey, c.AccessKey) assert.NotEqual(secretKey, c.SecretKey) cleanupKey = &c.AccessKey @@ -77,14 +80,14 @@ func TestCallerIdentity(t *testing.T) { SessionToken: sessionToken, } - cid, err := c.GetCallerIdentity() + cid, err := c.GetCallerIdentity(context.Background()) require.NoError(err) assert.NotEmpty(cid.Account) assert.NotEmpty(cid.Arn) assert.NotEmpty(cid.UserId) } -func TestCallerIdentityWithSession(t *testing.T) { +func TestCallerIdentityWithConfig(t *testing.T) { require, assert := require.New(t), assert.New(t) key, secretKey, sessionToken := os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("AWS_SESSION_TOKEN") @@ -98,11 +101,11 @@ func TestCallerIdentityWithSession(t *testing.T) { SessionToken: sessionToken, } - sess, err := c.GetSession() + cfg, err := c.GenerateCredentialChain(context.Background()) require.NoError(err) - require.NotNil(sess) + require.NotNil(cfg) - cid, err := c.GetCallerIdentity(WithAwsSession(sess)) + cid, err := c.GetCallerIdentity(context.Background(), WithAwsConfig(cfg)) require.NoError(err) assert.NotEmpty(cid.Account) assert.NotEmpty(cid.Arn) @@ -117,9 +120,12 @@ func TestCallerIdentityErrorNoTimeout(t *testing.T) { SecretKey: "badagain", } - _, err := c.GetCallerIdentity() + _, err := c.GetCallerIdentity(context.Background()) require.NotNil(err) - require.Implements((*awserr.Error)(nil), err) + fmt.Printf("\nTEST: %v\n", err) + + var oe *awserr.OperationError + require.True(errors.As(err, &oe)) } func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { @@ -130,12 +136,31 @@ func TestCallerIdentityErrorWithValidityCheckTimeout(t *testing.T) { SecretKey: "badagain", } - _, err := c.GetCallerIdentity(WithValidityCheckTimeout(time.Second * 10)) + _, err := c.GetCallerIdentity(context.Background(), WithValidityCheckTimeout(time.Second*10)) require.NotNil(err) require.True(strings.HasPrefix(err.Error(), "timeout after 10s waiting for success")) err = errors.Unwrap(err) require.NotNil(err) - require.Implements((*awserr.Error)(nil), err) + var oe *awserr.OperationError + require.True(errors.As(err, &oe)) +} + +func TestCallerIdentityErrorWithAPIThrottleException(t *testing.T) { + require := require.New(t) + + c := &CredentialsConfig{ + AccessKey: "bad", + SecretKey: "badagain", + } + + _, err := c.GetCallerIdentity(context.Background(), WithSTSAPIFunc( + NewMockSTS( + WithGetCallerIdentityError(MockAWSThrottleErr()), + ), + )) + require.NotNil(err) + var ae awserr.APIError + require.True(errors.As(err, &ae)) } func TestCallerIdentityWithSTSMockError(t *testing.T) { @@ -144,7 +169,7 @@ func TestCallerIdentityWithSTSMockError(t *testing.T) { expectedErr := errors.New("this is the expected error") c, err := NewCredentialsConfig() require.NoError(err) - _, err = c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityError(expectedErr)))) + _, err = c.GetCallerIdentity(context.Background(), WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityError(expectedErr)))) require.EqualError(err, expectedErr.Error()) } @@ -159,7 +184,7 @@ func TestCallerIdentityWithSTSMockNoErorr(t *testing.T) { c, err := NewCredentialsConfig() require.NoError(err) - out, err := c.GetCallerIdentity(WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityOutput(expectedOut)))) + out, err := c.GetCallerIdentity(context.Background(), WithSTSAPIFunc(NewMockSTS(WithGetCallerIdentityOutput(expectedOut)))) require.NoError(err) require.Equal(expectedOut, out) } @@ -171,7 +196,7 @@ func TestDeleteAccessKeyWithIAMMock(t *testing.T) { expectedErr := "error deleting old access key: this is the expected error" c, err := NewCredentialsConfig() require.NoError(err) - err = c.DeleteAccessKey("foobar", WithIAMAPIFunc(NewMockIAM(WithDeleteAccessKeyError(mockErr)))) + err = c.DeleteAccessKey(context.Background(), "foobar", WithIAMAPIFunc(NewMockIAM(WithDeleteAccessKeyError(mockErr)))) require.EqualError(err, expectedErr) } @@ -179,10 +204,10 @@ func TestCreateAccessKeyWithIAMMockGetUserError(t *testing.T) { require := require.New(t) mockErr := errors.New("this is the expected error") - expectedErr := "error calling aws.GetUser: this is the expected error" + expectedErr := "error calling iam.GetUser: this is the expected error" c, err := NewCredentialsConfig() require.NoError(err) - _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM(WithGetUserError(mockErr)))) + _, err = c.CreateAccessKey(context.Background(), WithIAMAPIFunc(NewMockIAM(WithGetUserError(mockErr)))) require.EqualError(err, expectedErr) } @@ -190,12 +215,12 @@ func TestCreateAccessKeyWithIAMMockCreateAccessKeyError(t *testing.T) { require := require.New(t) mockErr := errors.New("this is the expected error") - expectedErr := "error calling aws.CreateAccessKey: this is the expected error" + expectedErr := "error calling iam.CreateAccessKey: this is the expected error" c, err := NewCredentialsConfig() require.NoError(err) - _, err = c.CreateAccessKey(WithIAMAPIFunc(NewMockIAM( + _, err = c.CreateAccessKey(context.Background(), WithIAMAPIFunc(NewMockIAM( WithGetUserOutput(&iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ UserName: aws.String("foobar"), }, }), @@ -212,15 +237,16 @@ func TestCreateAccessKeyWithIAMAndSTSMockGetCallerIdentityError(t *testing.T) { c, err := NewCredentialsConfig() require.NoError(err) _, err = c.CreateAccessKey( + context.Background(), WithValidityCheckTimeout(time.Nanosecond), WithIAMAPIFunc(NewMockIAM( WithGetUserOutput(&iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ UserName: aws.String("foobar"), }, }), WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ - AccessKey: &iam.AccessKey{ + AccessKey: &iamTypes.AccessKey{ AccessKeyId: aws.String("foobar"), SecretAccessKey: aws.String("bazqux"), }, @@ -236,14 +262,15 @@ func TestCreateAccessKeyWithIAMAndSTSMockGetCallerIdentityError(t *testing.T) { func TestCreateAccessKeyNilResponse(t *testing.T) { require := require.New(t) - expectedErr := "nil response from aws.CreateAccessKey" + expectedErr := "nil response from iam.CreateAccessKey" c, err := NewCredentialsConfig() require.NoError(err) _, err = c.CreateAccessKey( + context.Background(), WithValidityCheckTimeout(time.Nanosecond), WithIAMAPIFunc(NewMockIAM( WithGetUserOutput(&iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ UserName: aws.String("foobar"), }, }), @@ -264,18 +291,18 @@ func TestRotateKeysWithMocks(t *testing.T) { { name: "CreateAccessKey IAM error", mockIAMOpts: []MockIAMOption{WithGetUserError(mockErr)}, - requireErr: "error calling CreateAccessKey: error calling aws.GetUser: this is the expected error", + requireErr: "error calling CreateAccessKey: error calling iam.GetUser: this is the expected error", }, { name: "CreateAccessKey STS error", mockIAMOpts: []MockIAMOption{ WithGetUserOutput(&iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ UserName: aws.String("foobar"), }, }), WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ - AccessKey: &iam.AccessKey{ + AccessKey: &iamTypes.AccessKey{ AccessKeyId: aws.String("foobar"), SecretAccessKey: aws.String("bazqux"), }, @@ -288,12 +315,12 @@ func TestRotateKeysWithMocks(t *testing.T) { name: "DeleteAccessKey IAM error", mockIAMOpts: []MockIAMOption{ WithGetUserOutput(&iam.GetUserOutput{ - User: &iam.User{ + User: &iamTypes.User{ UserName: aws.String("foobar"), }, }), WithCreateAccessKeyOutput(&iam.CreateAccessKeyOutput{ - AccessKey: &iam.AccessKey{ + AccessKey: &iamTypes.AccessKey{ AccessKeyId: aws.String("foobar"), SecretAccessKey: aws.String("bazqux"), UserName: aws.String("foouser"), @@ -323,6 +350,7 @@ func TestRotateKeysWithMocks(t *testing.T) { ) require.NoError(err) err = c.RotateKeys( + context.Background(), WithIAMAPIFunc(NewMockIAM(tc.mockIAMOpts...)), WithSTSAPIFunc(NewMockSTS(tc.mockSTSOpts...)), WithValidityCheckTimeout(time.Nanosecond),