Skip to content

Commit

Permalink
feat(awsutil-v2): implement awsutil for aws-sdk-go-v2 (#83)
Browse files Browse the repository at this point in the history
This major version release utilizes the latest version
of the aws-sdk-go-v2. The following behavioral changes
are included in this major version release:

- Custom endpoint resolvers are attached to the STS and
IAM clients, not to the credentials. This is apart of the
aws-sdk-go-v2 EndpointResolverV2 feature.
- withStsEndpoint is no longer a string type, but a
sts.EndpointResolverV2 type. This option was relabeled
to withStsEndpointResolver.
- withIamEndpoint is no longer a string type, but a
iam.EndpointResolverV2 type. This option was relabeled
to withIamEndpointResolver.
- By default, aws credential configurations will load values
from environment variables. The user provided options will
overload the default values.
- The ability to mock out the underlying credential provider
for unit testing.

Changed behaviors from awsutil v1 includes the following:

- Replaced aws errors with aws smithy-go errors
- No longer able to utilize the aws default remote credential
provider
- The function GenerateCredentialChain returns a aws.Config,
which contains the credential provider.
  • Loading branch information
ddebko authored Sep 21, 2023
1 parent 7a5e901 commit 22f76f8
Show file tree
Hide file tree
Showing 17 changed files with 1,155 additions and 639 deletions.
37 changes: 37 additions & 0 deletions awsutil/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# AWSUTIL - Go library for generating aws credentials

*NOTE*: This is version 2 of the library. The `v0` branch contains version 0,
which may be needed for legacy applications or while transitioning to version 2.

## Usage

Following is an example usage of generating AWS credentials with static user credentials

```go

// AWS access keys for an IAM user can be used as your AWS credentials.
// This is an example of an access key and secret key
var accessKey = "AKIAIOSFODNN7EXAMPLE"
var secretKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"

// Access key IDs beginning with AKIA are long-term access keys. A long-term
// access key should be supplied when generating static credentials.
config, err := awsutil.NewCredentialsConfig(
awsutil.WithAccessKey(accessKey),
awsutil.WithSecretKey(secretKey),
)
if err != nil {
return err
}

s3Client := s3.NewFromConfig(config)

```

## Contributing to v0

To push a bug fix or feature for awsutil `v0`, branch out from the [awsutil/v0](https://github.com/hashicorp/go-secure-stdlib/tree/awsutil/v0) branch.
Commit the code changes you want to this new branch and open a PR. Make sure the PR
is configured so that the base branch is set to `awsutil/v0` and not `main`. Once the PR
is reviewed, feel free to merge it into the `awsutil/v0` branch. When creating a new
release, validate that the `Target` branch is `awsutil/v0` and the tag is `awsutil/v0.x.x`.
80 changes: 44 additions & 36 deletions awsutil/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,98 @@
package awsutil

import (
"errors"
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

// IAMAPIFunc is a factory function for returning an IAM interface,
// useful for supplying mock interfaces for testing IAM. The session
// is passed into the function in the same way as done with the
// standard iam.New() constructor.
type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error)
// useful for supplying mock interfaces for testing IAM.
type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error)

// IAMClient represents an iam.Client
type IAMClient interface {
CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error)
DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error)
ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error)
GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error)
}

// STSAPIFunc is a factory function for returning a STS interface,
// useful for supplying mock interfaces for testing STS. The session
// is passed into the function in the same way as done with the
// standard sts.New() constructor.
type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error)
// useful for supplying mock interfaces for testing STS.
type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error)

// STSClient represents an sts.Client
type STSClient interface {
AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error)
GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}

// IAMClient returns an IAM client.
//
// Supported options: WithSession, WithIAMAPIFunc.
// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpointResolver.
//
// If WithIAMAPIFunc is supplied, the included function is used as
// the IAM client constructor instead. This can be used for Mocking
// the IAM API.
func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) {
func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withIAMAPIFunc != nil {
return opts.withIAMAPIFunc(sess)
return opts.withIAMAPIFunc(cfg)
}

client := iam.New(sess)
if client == nil {
return nil, errors.New("could not obtain iam client from session")
var iamOpts []func(*iam.Options)
if c.IAMEndpointResolver != nil {
iamOpts = append(iamOpts, iam.WithEndpointResolverV2(c.IAMEndpointResolver))
}

return client, nil
return iam.NewFromConfig(*cfg, iamOpts...), nil
}

// STSClient returns a STS client.
//
// Supported options: WithSession, WithSTSAPIFunc.
// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpointResolver.
//
// If WithSTSAPIFunc is supplied, the included function is used as
// the STS client constructor instead. This can be used for Mocking
// the STS API.
func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) {
func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error reading options: %w", err)
}

sess := opts.withAwsSession
if sess == nil {
sess, err = c.GetSession(opt...)
cfg := opts.withAwsConfig
if cfg == nil {
cfg, err = c.GenerateCredentialChain(ctx, opt...)
if err != nil {
return nil, fmt.Errorf("error calling GetSession: %w", err)
return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err)
}
}

if opts.withSTSAPIFunc != nil {
return opts.withSTSAPIFunc(sess)
return opts.withSTSAPIFunc(cfg)
}

client := sts.New(sess)
if client == nil {
return nil, errors.New("could not obtain sts client from session")
var stsOpts []func(*sts.Options)
if c.STSEndpointResolver != nil {
stsOpts = append(stsOpts, sts.WithEndpointResolverV2(c.STSEndpointResolver))
}

return client, nil
return sts.NewFromConfig(*cfg, stsOpts...), nil
}
45 changes: 13 additions & 32 deletions awsutil/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@
package awsutil

import (
"context"
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/stretchr/testify/require"
)

const testOptionErr = "test option error"
const testBadClientType = "badclienttype"

func testWithBadClientType(o *options) error {
o.withClientType = testBadClientType
return nil
}

func TestCredentialsConfigIAMClient(t *testing.T) {
cases := []struct {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual iamiface.IAMAPI)
require func(t *testing.T, actual IAMClient)
requireErr string
}{
{
Expand All @@ -37,17 +30,11 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock IAM session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithIAMAPIFunc(NewMockIAM())},
require: func(t *testing.T, actual iamiface.IAMAPI) {
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockIAM{}, actual)
Expand All @@ -57,10 +44,10 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual iamiface.IAMAPI) {
require: func(t *testing.T, actual IAMClient) {
t.Helper()
require := require.New(t)
require.IsType(&iam.IAM{}, actual)
require.IsType(&iam.Client{}, actual)
},
},
}
Expand All @@ -69,7 +56,7 @@ func TestCredentialsConfigIAMClient(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.IAMClient(tc.opts...)
actual, err := tc.credentialsConfig.IAMClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
Expand All @@ -86,7 +73,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
name string
credentialsConfig *CredentialsConfig
opts []Option
require func(t *testing.T, actual stsiface.STSAPI)
require func(t *testing.T, actual STSClient)
requireErr string
}{
{
Expand All @@ -95,17 +82,11 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
opts: []Option{MockOptionErr(errors.New(testOptionErr))},
requireErr: fmt.Sprintf("error reading options: %s", testOptionErr),
},
{
name: "session error",
credentialsConfig: &CredentialsConfig{},
opts: []Option{testWithBadClientType},
requireErr: fmt.Sprintf("error calling GetSession: unknown client type %q in GetSession", testBadClientType),
},
{
name: "with mock STS session",
credentialsConfig: &CredentialsConfig{},
opts: []Option{WithSTSAPIFunc(NewMockSTS())},
require: func(t *testing.T, actual stsiface.STSAPI) {
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.Equal(&MockSTS{}, actual)
Expand All @@ -115,10 +96,10 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
name: "no mock client",
credentialsConfig: &CredentialsConfig{},
opts: []Option{},
require: func(t *testing.T, actual stsiface.STSAPI) {
require: func(t *testing.T, actual STSClient) {
t.Helper()
require := require.New(t)
require.IsType(&sts.STS{}, actual)
require.IsType(&sts.Client{}, actual)
},
},
}
Expand All @@ -127,7 +108,7 @@ func TestCredentialsConfigSTSClient(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)
actual, err := tc.credentialsConfig.STSClient(tc.opts...)
actual, err := tc.credentialsConfig.STSClient(context.TODO(), tc.opts...)
if tc.requireErr != "" {
require.EqualError(err, tc.requireErr)
return
Expand Down
10 changes: 5 additions & 5 deletions awsutil/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package awsutil
import (
"errors"

awsRequest "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go-v2/aws/retry"
multierror "github.com/hashicorp/go-multierror"
)

Expand All @@ -15,10 +15,10 @@ var ErrUpstreamRateLimited = errors.New("upstream rate limited")
// CheckAWSError will examine an error and convert to a logical error if
// appropriate. If no appropriate error is found, return nil
func CheckAWSError(err error) error {
// IsErrorThrottle will check if the error returned is one that matches
// known request limiting errors:
// https://github.com/aws/aws-sdk-go/blob/488d634b5a699b9118ac2befb5135922b4a77210/aws/request/retryer.go#L35
if awsRequest.IsErrorThrottle(err) {
retryErr := retry.ThrottleErrorCode{
Codes: retry.DefaultThrottleErrorCodes,
}
if retryErr.IsErrorThrottle(err).Bool() {
return ErrUpstreamRateLimited
}
return nil
Expand Down
18 changes: 11 additions & 7 deletions awsutil/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
awserr "github.com/aws/smithy-go"
multierror "github.com/hashicorp/go-multierror"
)

Expand All @@ -23,12 +23,16 @@ func Test_CheckAWSError(t *testing.T) {
},
{
Name: "Upstream throttle error",
Err: awserr.New("Throttling", "", nil),
Err: MockAWSThrottleErr(),
Expected: ErrUpstreamRateLimited,
},
{
Name: "Upstream RequestLimitExceeded",
Err: awserr.New("RequestLimitExceeded", "Request rate limited", nil),
Name: "Upstream RequestLimitExceeded",
Err: &MockAWSErr{
Code: "RequestLimitExceeded",
Message: "Request rate limited",
Fault: awserr.FaultServer,
},
Expected: ErrUpstreamRateLimited,
},
}
Expand All @@ -50,7 +54,7 @@ func Test_CheckAWSError(t *testing.T) {
}

func Test_AppendRateLimitedError(t *testing.T) {
awsErr := awserr.New("Throttling", "", nil)
throttleErr := MockAWSThrottleErr()
testCases := []struct {
Name string
Err error
Expand All @@ -63,8 +67,8 @@ func Test_AppendRateLimitedError(t *testing.T) {
},
{
Name: "Upstream throttle error",
Err: awsErr,
Expected: multierror.Append(awsErr, ErrUpstreamRateLimited),
Err: throttleErr,
Expected: multierror.Append(throttleErr, ErrUpstreamRateLimited),
},
{
Name: "Nil",
Expand Down
Loading

0 comments on commit 22f76f8

Please sign in to comment.