Skip to content

Commit

Permalink
provider/aws: Add assume_role block to provider
Browse files Browse the repository at this point in the history
This replaces the previous `role_arn` with a block which looks like
this:

```
provider "aws" {
        // secret key, access key etc

	assume_role {
	        role_arn = "<Role ARN>"
		session_name = "<Session Name>"
		external_id = "<External ID>"
	}
}
```

We also modify the configuration structure and read the values from the
block if present into those values and adjust the call to AssumeRole to
include the SessionName and ExternalID based on the values set in the
configuration block.

Finally we clean up the tests and add in missing error checks, and clean
up the error handling logic in the Auth helper functions.
  • Loading branch information
jen20 committed Sep 3, 2016
1 parent d444d12 commit 305f114
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 201 deletions.
86 changes: 56 additions & 30 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aws

import (
"errors"
"fmt"
"log"
"os"
Expand All @@ -18,7 +19,6 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-multierror"
)

func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
Expand Down Expand Up @@ -77,7 +77,7 @@ func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (
}

if len(outRoles.Roles) < 1 {
return "", fmt.Errorf("Failed getting account ID via 'iam:ListRoles': No roles available")
return "", errors.New("Failed getting account ID via 'iam:ListRoles': No roles available")
}

return parseAccountIdFromArn(*outRoles.Roles[0].Arn)
Expand All @@ -95,8 +95,6 @@ func parseAccountIdFromArn(arn string) (string, error) {
// environment in the case that they're not explicitly specified
// in the Terraform configuration.
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
var errs []error

// build a chain provider, lazy-evaulated by aws-sdk
providers := []awsCredentials.Provider{
&awsCredentials.StaticProvider{Value: awsCredentials.Value{
Expand Down Expand Up @@ -130,7 +128,7 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
providers = append(providers, &ec2rolecreds.EC2RoleProvider{
Client: metadataClient,
})
log.Printf("[INFO] AWS EC2 instance detected via default metadata" +
log.Print("[INFO] AWS EC2 instance detected via default metadata" +
" API endpoint, EC2RoleProvider added to the auth chain")
} else {
if usedEndpoint == "" {
Expand All @@ -141,40 +139,68 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
}
}

if c.RoleArn != "" {
log.Printf("[INFO] attempting to assume role %s", c.RoleArn)
// This is the "normal" flow (i.e. not assuming a role)
if c.AssumeRoleARN == "" {
return awsCredentials.NewChainCredentials(providers), nil
}

// Otherwise we need to construct and STS client with the main credentials, and verify
// that we can assume the defined role.
log.Printf("[INFO] Attempting to AssumeRole %s (SessionName: %q, ExternalId: %q)",
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID)

creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
errs = append(errs, fmt.Errorf(`No valid credential sources found for AWS Provider.
creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`))
} else {
errs = append(errs, fmt.Errorf("Error loading credentials for AWS Provider: %s", err))
}
return nil, &multierror.Error{Errors: errs}
providing credentials for the AWS Provider`)
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle),
awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
S3ForcePathStyle: aws.Bool(c.S3ForcePathStyle),
}

stsclient := sts.New(session.New(awsConfig))
assumeRoleProvider := &stscreds.AssumeRoleProvider{
Client: stsclient,
RoleARN: c.AssumeRoleARN,
}
if c.AssumeRoleSessionName != "" {
assumeRoleProvider.RoleSessionName = c.AssumeRoleSessionName
}
if c.AssumeRoleExternalID != "" {
assumeRoleProvider.ExternalID = aws.String(c.AssumeRoleExternalID)
}

providers = []awsCredentials.Provider{assumeRoleProvider}

assumeRoleCreds := awsCredentials.NewChainCredentials(providers)
_, err = assumeRoleCreds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, fmt.Errorf("The role %q cannot be assumed.\n\n" +
" There are a number of possible causes of this - the most common are:\n" +
" * The credentials used in order to assume the role are invalid\n" +
" * The credentials do not have appropriate permission to assume the role\n" +
" * The role ARN is not valid",
c.AssumeRoleARN)
}

stsclient := sts.New(session.New(awsConfig))
providers = []awsCredentials.Provider{&stscreds.AssumeRoleProvider{
Client: stsclient,
RoleARN: c.RoleArn,
}}
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

return awsCredentials.NewChainCredentials(providers), nil
return assumeRoleCreds, nil
}

func setOptionalEndpoint(cfg *aws.Config) string {
Expand Down
78 changes: 46 additions & 32 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
defer awsTs()

iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
Expand All @@ -72,7 +72,7 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
Expand All @@ -94,11 +94,11 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
Expand All @@ -119,15 +119,15 @@ func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
Expand All @@ -148,11 +148,11 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {

func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{400, iamResponse_GetUser_federatedFailure, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
Expand All @@ -173,11 +173,11 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {

func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
Expand Down Expand Up @@ -221,17 +221,17 @@ func TestAWSGetCredentials_shouldError(t *testing.T) {
c, err := GetCredentials(&cfg)
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
t.Fatal("Expected NoCredentialProviders error")
}
}
_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatalf("Expected NoCredentialProviders error")
t.Fatal("Expected NoCredentialProviders error")
}
}
if err == nil {
t.Fatalf("Expected an error with empty env, keys, and IAM in AWS Config")
t.Fatal("Expected an error with empty env, keys, and IAM in AWS Config")
}
}

Expand All @@ -257,16 +257,18 @@ func TestAWSGetCredentials_shouldBeStatic(t *testing.T) {
}

creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}

if v.AccessKeyID != c.Key {
t.Fatalf("AccessKeyID mismatch, expected: (%s), got (%s)", c.Key, v.AccessKeyID)
}
Expand Down Expand Up @@ -295,12 +297,13 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) {
cfg := Config{}

creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down Expand Up @@ -346,12 +349,13 @@ func TestAWSGetCredentials_shouldIgnoreIAM(t *testing.T) {
}

creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down Expand Up @@ -379,6 +383,10 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
Expand All @@ -404,6 +412,9 @@ func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
if err != nil {
t.Fatalf("Getting static credentials w/ invalid EC2 endpoint failed: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

if v.ProviderName != "StaticProvider" {
t.Fatalf("Expected provider name to be %q, %q given", "StaticProvider", v.ProviderName)
Expand All @@ -426,12 +437,13 @@ func TestAWSGetCredentials_shouldCatchEC2RoleProvider(t *testing.T) {
defer ts()

creds, err := GetCredentials(&Config{})
if creds == nil {
t.Fatalf("Expected an EC2Role creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected an EC2Role creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Expected no error when getting creds: %s", err)
Expand Down Expand Up @@ -475,12 +487,13 @@ func TestAWSGetCredentials_shouldBeShared(t *testing.T) {
}

creds, err := GetCredentials(&Config{Profile: "myprofile", CredsFilename: file.Name()})
if creds == nil {
t.Fatalf("Expected a provider chain to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a provider chain to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand All @@ -505,12 +518,13 @@ func TestAWSGetCredentials_shouldBeENV(t *testing.T) {

cfg := Config{}
creds, err := GetCredentials(&cfg)
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatalf("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
Expand Down
Loading

0 comments on commit 305f114

Please sign in to comment.