diff --git a/commands/deploy.go b/commands/deploy.go index 3a47db6..6421998 100644 --- a/commands/deploy.go +++ b/commands/deploy.go @@ -8,6 +8,7 @@ import ( "time" "github.com/aws/aws-sdk-go/service/cloudformation" + forge "github.com/nathandines/forge/forgelib" "github.com/spf13/cobra" ) @@ -61,6 +62,12 @@ var deployCmd = &cobra.Command{ stack.StackPolicyBody = string(stackPolicyBody) } + if assumeRoleArn != "" { + if err := forge.AssumeRole(assumeRoleArn); err != nil { + log.Fatal(err) + } + } + // Populate Stack ID // Deliberately ignore errors here stack.GetStackInfo() diff --git a/commands/destroy.go b/commands/destroy.go index 86e92a6..7d457f4 100644 --- a/commands/destroy.go +++ b/commands/destroy.go @@ -6,6 +6,7 @@ import ( "time" "github.com/aws/aws-sdk-go/service/cloudformation" + forge "github.com/nathandines/forge/forgelib" "github.com/spf13/cobra" ) @@ -13,6 +14,12 @@ var destroyCmd = &cobra.Command{ Use: "destroy", Short: "Destroy a CloudFormation Stack", Run: func(cmd *cobra.Command, args []string) { + if assumeRoleArn != "" { + if err := forge.AssumeRole(assumeRoleArn); err != nil { + log.Fatal(err) + } + } + // Populate Stack ID if err := stack.GetStackInfo(); err != nil { log.Fatal(err) diff --git a/commands/root.go b/commands/root.go index 7605b28..dfa9755 100644 --- a/commands/root.go +++ b/commands/root.go @@ -15,6 +15,7 @@ import ( var stack = forge.Stack{} var stackInProgressRegexp = regexp.MustCompile("^.*_IN_PROGRESS$") +var assumeRoleArn string var rootCmd = &cobra.Command{ Use: "forge", @@ -25,7 +26,7 @@ friendly for continuous delivery environments. GitHub: https://github.com/nathandines/forge `, - Version: "v0.1.1", + Version: "v0.1.2", } func init() { @@ -40,7 +41,13 @@ func init() { &stack.CfnRoleName, "cfn-role-name", "", - "Name of IAM role in the destination account for CloudFormation to assume", + "Name of IAM role in the destination account for the CloudFormation service to assume", + ) + rootCmd.PersistentFlags().StringVar( + &assumeRoleArn, + "assume-role-arn", + "", + "Name of IAM role to assume BEFORE making requests to CloudFormation", ) } diff --git a/forgelib/auth.go b/forgelib/auth.go new file mode 100644 index 0000000..9b08a64 --- /dev/null +++ b/forgelib/auth.go @@ -0,0 +1,74 @@ +package forgelib + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" + "strings" +) + +var originalSession *session.Session +var cfnClient cloudformationiface.CloudFormationAPI // CloudFormation service +var stsClient stsiface.STSAPI // STS Service + +func init() { + originalSession = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + setupClients(originalSession) +} + +func setupClients(sess *session.Session) { + cfnClient = cloudformation.New(sess) + stsClient = sts.New(sess) +} + +// AssumeRole will change your credentials for Forge to those of an assumed role +// as specific by the ARN specified in the arguments to AssumeRole +func AssumeRole(roleArn string) error { + roleSessionName, err := getRoleSessionName() + if err != nil { + return err + } + assumeOut, err := stsClient.AssumeRole(&sts.AssumeRoleInput{ + RoleSessionName: aws.String(roleSessionName), + RoleArn: aws.String(roleArn), + }) + if err != nil { + return err + } + session, err := session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Credentials: credentials.NewStaticCredentials( + *assumeOut.Credentials.AccessKeyId, + *assumeOut.Credentials.SecretAccessKey, + *assumeOut.Credentials.SessionToken, + ), + }, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return err + } + setupClients(session) + return nil +} + +// UnassumeAllRoles will change your credentials back to their original state +// after using AssumeRole +func UnassumeAllRoles() { + setupClients(originalSession) +} + +func getRoleSessionName() (string, error) { + callerIdentity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + return "", err + } + roleSessionNameSlice := strings.Split(*callerIdentity.Arn, "/") + return roleSessionNameSlice[len(roleSessionNameSlice)-1], nil +} diff --git a/forgelib/auth_test.go b/forgelib/auth_test.go new file mode 100644 index 0000000..40cd376 --- /dev/null +++ b/forgelib/auth_test.go @@ -0,0 +1,22 @@ +package forgelib + +import "testing" + +func TestGetRoleSessionName(t *testing.T) { + oldSTSClient := stsClient + defer func() { stsClient = oldSTSClient }() + stsClient = mockSTS{ + callerArn: "arn:aws:iam::111111111111:user/nathan", + accountID: "111111111111", + } + expectedName := "nathan" + + roleSessionName, err := getRoleSessionName() + if err != nil { + t.Fatalf("unexpected error, %v", err) + } + + if e, g := expectedName, roleSessionName; e != g { + t.Errorf("expected \"%s\", got \"%s\"", e, g) + } +} diff --git a/forgelib/deploy.go b/forgelib/deploy.go index 6781fc5..6c81675 100644 --- a/forgelib/deploy.go +++ b/forgelib/deploy.go @@ -68,21 +68,23 @@ func (s *Stack) Deploy() (output DeployOut, err error) { } } - var roleARN string + var roleARN *string if s.CfnRoleName != "" { - roleARN, err = roleARNFromName(s.CfnRoleName) + roleARNString, err := roleARNFromName(s.CfnRoleName) if err != nil { return output, err } + roleARN = &roleARNString } - var inputStackPolicy string + var inputStackPolicy *string if s.StackPolicyBody != "" { jsonStackPolicy, err := yaml.YAMLToJSON([]byte(s.StackPolicyBody)) if err != nil { return output, err } - inputStackPolicy = string(jsonStackPolicy) + jsonStackPolicyString := string(jsonStackPolicy) + inputStackPolicy = &jsonStackPolicyString } if s.StackInfo == nil { @@ -94,8 +96,8 @@ func (s *Stack) Deploy() (output DeployOut, err error) { Capabilities: validationResult.Capabilities, Tags: tags, Parameters: inputParams, - RoleARN: &roleARN, - StackPolicyBody: &inputStackPolicy, + RoleARN: roleARN, + StackPolicyBody: inputStackPolicy, }, ) if err != nil { @@ -110,8 +112,8 @@ func (s *Stack) Deploy() (output DeployOut, err error) { Capabilities: validationResult.Capabilities, Tags: tags, Parameters: inputParams, - RoleARN: &roleARN, - StackPolicyBody: &inputStackPolicy, + RoleARN: roleARN, + StackPolicyBody: inputStackPolicy, }, ) if err != nil { diff --git a/forgelib/destroy.go b/forgelib/destroy.go index 5f7c4fb..b6cac3e 100644 --- a/forgelib/destroy.go +++ b/forgelib/destroy.go @@ -10,12 +10,13 @@ func (s *Stack) Destroy() (err error) { return errorNoStackID } - var roleARN string + var roleARN *string if s.CfnRoleName != "" { - roleARN, err = roleARNFromName(s.CfnRoleName) + roleARNString, err := roleARNFromName(s.CfnRoleName) if err != nil { return err } + roleARN = &roleARNString } // Delete stack by Stack ID. This removes the risk of deleting a stack with @@ -25,7 +26,7 @@ func (s *Stack) Destroy() (err error) { _, err = cfnClient.DeleteStack( &cloudformation.DeleteStackInput{ StackName: &s.StackID, - RoleARN: &roleARN, + RoleARN: roleARN, }, ) return diff --git a/forgelib/forge.go b/forgelib/forge.go index 92893b8..4308b2f 100644 --- a/forgelib/forge.go +++ b/forgelib/forge.go @@ -1,12 +1,6 @@ package forgelib -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" -) +import "github.com/aws/aws-sdk-go/service/cloudformation" // Stack represents the attributes of a stack deployment, including the AWS // parameters, and local resources which represent what needs to be deployed @@ -22,17 +16,6 @@ type Stack struct { TemplateBody string } -var cfnClient cloudformationiface.CloudFormationAPI // CloudFormation service -var stsClient stsiface.STSAPI // STS Service - -func init() { - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - cfnClient = cloudformation.New(sess) - stsClient = sts.New(sess) -} - // GetStackInfo populates the StackInfo for this object from the existing stack // found in the environment func (s *Stack) GetStackInfo() (err error) { diff --git a/forgelib/mock_cfn_test.go b/forgelib/mock_cfn_test.go index 2d98f9a..81f6a42 100644 --- a/forgelib/mock_cfn_test.go +++ b/forgelib/mock_cfn_test.go @@ -112,7 +112,7 @@ REQUIRED_PARAMETERS: } *m.stacks = append(*m.stacks, thisStack) - if *input.StackPolicyBody != "" { + if input.StackPolicyBody != nil { (*m.stackPolicies)[m.newStackID] = *input.StackPolicyBody } @@ -168,7 +168,7 @@ REQUIRED_PARAMETERS: (*m.stacks)[i].Tags = input.Tags (*m.stacks)[i].Parameters = input.Parameters - if *input.StackPolicyBody != "" { + if input.StackPolicyBody != nil { (*m.stackPolicies)[*(*m.stacks)[i].StackId] = *input.StackPolicyBody } diff --git a/forgelib/mock_sts_test.go b/forgelib/mock_sts_test.go index 728adc7..966e63e 100644 --- a/forgelib/mock_sts_test.go +++ b/forgelib/mock_sts_test.go @@ -7,11 +7,21 @@ import ( ) type mockSTS struct { - accountID string + accountID string + callerArn string + fakeCredentials *sts.Credentials stsiface.STSAPI } func (m mockSTS) GetCallerIdentity(*sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { - output := sts.GetCallerIdentityOutput{Account: aws.String(m.accountID)} + output := sts.GetCallerIdentityOutput{ + Account: aws.String(m.accountID), + Arn: aws.String(m.callerArn), + } + return &output, nil +} + +func (m mockSTS) AssumeRole(*sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { + output := sts.AssumeRoleOutput{Credentials: m.fakeCredentials} return &output, nil }