Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests for PR 8991 #9021

Closed
wants to merge 11 commits into from
11 changes: 6 additions & 5 deletions builtin/credential/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/awsutil"
"github.com/hashicorp/vault/sdk/helper/consts"
Expand Down Expand Up @@ -57,13 +58,13 @@ type backend struct {
// This avoids the overhead of creating a client object for every login request.
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
EC2ClientsMap map[string]map[string]*ec2.EC2
EC2ClientsMap map[string]map[string]ec2iface.EC2API

// Map to hold the IAM client objects indexed by region and STS role.
// This avoids the overhead of creating a client object for every login request.
// When the credentials are modified or deleted, all the cached client objects
// will be flushed. The empty STS role signifies the master account
IAMClientsMap map[string]map[string]*iam.IAM
IAMClientsMap map[string]map[string]iamiface.IAMAPI

// Map to associate a partition to a random region in that partition. Users of
// this don't care what region in the partition they use, but there is some client
Expand Down Expand Up @@ -97,8 +98,8 @@ func Backend(_ *logical.BackendConfig) (*backend, error) {
// Setting the periodic func to be run once in an hour.
// If there is a real need, this can be made configurable.
tidyCooldownPeriod: time.Hour,
EC2ClientsMap: make(map[string]map[string]*ec2.EC2),
IAMClientsMap: make(map[string]map[string]*iam.IAM),
EC2ClientsMap: make(map[string]map[string]ec2iface.EC2API),
IAMClientsMap: make(map[string]map[string]iamiface.IAMAPI),
iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour),
tidyBlacklistCASGuard: new(uint32),
tidyWhitelistCASGuard: new(uint32),
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
return
}

stsService := sts.New(awsSession)
stsService := newSTSClient(awsSession)
stsInputParams := &sts.GetCallerIdentityInput{}

testIdentity, err := stsService.GetCallerIdentity(stsInputParams)
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func GenerateLoginData(creds *credentials.Credentials, headerValue, configuredRe
}

var params *sts.GetCallerIdentityInput
svc := sts.New(stsSession)
svc := newSTSClient(stsSession)
stsRequest, _ := svc.GetCallerIdentityRequest(params)

// Inject the required auth header value, if supplied, and then sign the request including that header
Expand Down
56 changes: 49 additions & 7 deletions builtin/credential/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,26 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/sdk/helper/awsutil"
"github.com/hashicorp/vault/sdk/logical"
)

var (
// These variables are intended to be set by tests. If set, the given
// client will override the AWS client, allowing client responses to
// be mocked out.
mockEC2Client ec2iface.EC2API = nil
mockIAMClient iamiface.IAMAPI = nil
mockSTSClient stsiface.STSAPI = nil
)

// getRawClientConfig creates a aws-sdk-go config, which is used to create client
// that can interact with AWS API. This builds credentials in the following
// order of preference:
Expand Down Expand Up @@ -115,7 +127,7 @@ func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region
if err != nil {
return nil, err
}
client := sts.New(sess)
client := newSTSClient(sess)
if client == nil {
return nil, errwrap.Wrapf("could not obtain sts client: {{err}}", err)
}
Expand Down Expand Up @@ -192,7 +204,7 @@ func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, acco
}

// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (ec2iface.EC2API, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -231,12 +243,12 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
if err != nil {
return nil, err
}
client := ec2.New(sess)
client := newEC2Client(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain ec2 client")
}
if _, ok := b.EC2ClientsMap[region]; !ok {
b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
b.EC2ClientsMap[region] = map[string]ec2iface.EC2API{stsRole: client}
} else {
b.EC2ClientsMap[region][stsRole] = client
}
Expand All @@ -245,7 +257,7 @@ func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, acco
}

// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (iamiface.IAMAPI, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -291,14 +303,44 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, acco
if err != nil {
return nil, err
}
client := iam.New(sess)
client := newIAMClient(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
}
if _, ok := b.IAMClientsMap[region]; !ok {
b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
b.IAMClientsMap[region] = map[string]iamiface.IAMAPI{stsRole: client}
} else {
b.IAMClientsMap[region][stsRole] = client
}
return b.IAMClientsMap[region][stsRole], nil
}

// newEC2Client should be used instead of using ec2.New()
// directly because it allows us to mock out the EC2 client
// as needed for testing.
func newEC2Client(sess *session.Session) ec2iface.EC2API {
if mockEC2Client != nil {
return mockEC2Client
}
return ec2.New(sess)
}

// newIAMClient should be used instead of using iam.New()
// directly because it allows us to mock out the IAM client
// as needed for testing.
func newIAMClient(sess *session.Session) iamiface.IAMAPI {
if mockIAMClient != nil {
return mockIAMClient
}
return iam.New(sess)
}

// newSTSClient should be used instead of using sts.New()
// directly because it allows us to mock out the STS client
// as needed for testing.
func newSTSClient(sess *session.Session) stsiface.STSAPI {
if mockSTSClient != nil {
return mockSTSClient
}
return sts.New(sess)
}
3 changes: 2 additions & 1 deletion builtin/credential/aws/path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/fullsailor/pkcs7"
"github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp"
Expand Down Expand Up @@ -131,7 +132,7 @@ needs to be supplied along with 'identity' parameter.`,

// instanceIamRoleARN fetches the IAM role ARN associated with the given
// instance profile name
func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName string) (string, error) {
func (b *backend) instanceIamRoleARN(iamClient iamiface.IAMAPI, instanceProfileName string) (string, error) {
if iamClient == nil {
return "", fmt.Errorf("nil iamClient")
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/credential/aws/path_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func defaultLoginData() (map[string]interface{}, error) {
return nil, fmt.Errorf("failed to create session: %s", err)
}

stsService := sts.New(awsSession)
stsService := newSTSClient(awsSession)
stsInputParams := &sts.GetCallerIdentityInput{}
stsRequestValid, _ := stsService.GetCallerIdentityRequest(stsInputParams)
stsRequestValid.HTTPRequest.Header.Add(iamServerIdHeader, testVaultHeaderValue)
Expand Down
Loading