Skip to content

Commit

Permalink
Merge pull request #46 from nathandines/assume-role
Browse files Browse the repository at this point in the history
Added support for assuming roles
  • Loading branch information
nathandines authored Apr 22, 2018
2 parents eab6bfb + f1c1a9a commit 8535bce
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 35 deletions.
7 changes: 7 additions & 0 deletions commands/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/service/cloudformation"
forge "github.com/nathandines/forge/forgelib"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -61,6 +62,12 @@ var deployCmd = &cobra.Command{
stack.StackPolicyBody = string(stackPolicyBody)
}

if assumeRoleArn != "" {
if err := forge.AssumeRole(assumeRoleArn); err != nil {
log.Fatal(err)
}
}

// Populate Stack ID
// Deliberately ignore errors here
stack.GetStackInfo()
Expand Down
7 changes: 7 additions & 0 deletions commands/destroy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@ import (
"time"

"github.com/aws/aws-sdk-go/service/cloudformation"
forge "github.com/nathandines/forge/forgelib"
"github.com/spf13/cobra"
)

var destroyCmd = &cobra.Command{
Use: "destroy",
Short: "Destroy a CloudFormation Stack",
Run: func(cmd *cobra.Command, args []string) {
if assumeRoleArn != "" {
if err := forge.AssumeRole(assumeRoleArn); err != nil {
log.Fatal(err)
}
}

// Populate Stack ID
if err := stack.GetStackInfo(); err != nil {
log.Fatal(err)
Expand Down
11 changes: 9 additions & 2 deletions commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

var stack = forge.Stack{}
var stackInProgressRegexp = regexp.MustCompile("^.*_IN_PROGRESS$")
var assumeRoleArn string

var rootCmd = &cobra.Command{
Use: "forge",
Expand All @@ -25,7 +26,7 @@ friendly for continuous delivery environments.
GitHub: https://github.com/nathandines/forge
`,
Version: "v0.1.1",
Version: "v0.1.2",
}

func init() {
Expand All @@ -40,7 +41,13 @@ func init() {
&stack.CfnRoleName,
"cfn-role-name",
"",
"Name of IAM role in the destination account for CloudFormation to assume",
"Name of IAM role in the destination account for the CloudFormation service to assume",
)
rootCmd.PersistentFlags().StringVar(
&assumeRoleArn,
"assume-role-arn",
"",
"Name of IAM role to assume BEFORE making requests to CloudFormation",
)
}

Expand Down
74 changes: 74 additions & 0 deletions forgelib/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package forgelib

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudformation"
"github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"strings"
)

var originalSession *session.Session
var cfnClient cloudformationiface.CloudFormationAPI // CloudFormation service
var stsClient stsiface.STSAPI // STS Service

func init() {
originalSession = session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
setupClients(originalSession)
}

func setupClients(sess *session.Session) {
cfnClient = cloudformation.New(sess)
stsClient = sts.New(sess)
}

// AssumeRole will change your credentials for Forge to those of an assumed role
// as specific by the ARN specified in the arguments to AssumeRole
func AssumeRole(roleArn string) error {
roleSessionName, err := getRoleSessionName()
if err != nil {
return err
}
assumeOut, err := stsClient.AssumeRole(&sts.AssumeRoleInput{
RoleSessionName: aws.String(roleSessionName),
RoleArn: aws.String(roleArn),
})
if err != nil {
return err
}
session, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{
Credentials: credentials.NewStaticCredentials(
*assumeOut.Credentials.AccessKeyId,
*assumeOut.Credentials.SecretAccessKey,
*assumeOut.Credentials.SessionToken,
),
},
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return err
}
setupClients(session)
return nil
}

// UnassumeAllRoles will change your credentials back to their original state
// after using AssumeRole
func UnassumeAllRoles() {
setupClients(originalSession)
}

func getRoleSessionName() (string, error) {
callerIdentity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
roleSessionNameSlice := strings.Split(*callerIdentity.Arn, "/")
return roleSessionNameSlice[len(roleSessionNameSlice)-1], nil
}
22 changes: 22 additions & 0 deletions forgelib/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package forgelib

import "testing"

func TestGetRoleSessionName(t *testing.T) {
oldSTSClient := stsClient
defer func() { stsClient = oldSTSClient }()
stsClient = mockSTS{
callerArn: "arn:aws:iam::111111111111:user/nathan",
accountID: "111111111111",
}
expectedName := "nathan"

roleSessionName, err := getRoleSessionName()
if err != nil {
t.Fatalf("unexpected error, %v", err)
}

if e, g := expectedName, roleSessionName; e != g {
t.Errorf("expected \"%s\", got \"%s\"", e, g)
}
}
18 changes: 10 additions & 8 deletions forgelib/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,23 @@ func (s *Stack) Deploy() (output DeployOut, err error) {
}
}

var roleARN string
var roleARN *string
if s.CfnRoleName != "" {
roleARN, err = roleARNFromName(s.CfnRoleName)
roleARNString, err := roleARNFromName(s.CfnRoleName)
if err != nil {
return output, err
}
roleARN = &roleARNString
}

var inputStackPolicy string
var inputStackPolicy *string
if s.StackPolicyBody != "" {
jsonStackPolicy, err := yaml.YAMLToJSON([]byte(s.StackPolicyBody))
if err != nil {
return output, err
}
inputStackPolicy = string(jsonStackPolicy)
jsonStackPolicyString := string(jsonStackPolicy)
inputStackPolicy = &jsonStackPolicyString
}

if s.StackInfo == nil {
Expand All @@ -94,8 +96,8 @@ func (s *Stack) Deploy() (output DeployOut, err error) {
Capabilities: validationResult.Capabilities,
Tags: tags,
Parameters: inputParams,
RoleARN: &roleARN,
StackPolicyBody: &inputStackPolicy,
RoleARN: roleARN,
StackPolicyBody: inputStackPolicy,
},
)
if err != nil {
Expand All @@ -110,8 +112,8 @@ func (s *Stack) Deploy() (output DeployOut, err error) {
Capabilities: validationResult.Capabilities,
Tags: tags,
Parameters: inputParams,
RoleARN: &roleARN,
StackPolicyBody: &inputStackPolicy,
RoleARN: roleARN,
StackPolicyBody: inputStackPolicy,
},
)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions forgelib/destroy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ func (s *Stack) Destroy() (err error) {
return errorNoStackID
}

var roleARN string
var roleARN *string
if s.CfnRoleName != "" {
roleARN, err = roleARNFromName(s.CfnRoleName)
roleARNString, err := roleARNFromName(s.CfnRoleName)
if err != nil {
return err
}
roleARN = &roleARNString
}

// Delete stack by Stack ID. This removes the risk of deleting a stack with
Expand All @@ -25,7 +26,7 @@ func (s *Stack) Destroy() (err error) {
_, err = cfnClient.DeleteStack(
&cloudformation.DeleteStackInput{
StackName: &s.StackID,
RoleARN: &roleARN,
RoleARN: roleARN,
},
)
return
Expand Down
19 changes: 1 addition & 18 deletions forgelib/forge.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
package forgelib

import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudformation"
"github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)
import "github.com/aws/aws-sdk-go/service/cloudformation"

// Stack represents the attributes of a stack deployment, including the AWS
// parameters, and local resources which represent what needs to be deployed
Expand All @@ -22,17 +16,6 @@ type Stack struct {
TemplateBody string
}

var cfnClient cloudformationiface.CloudFormationAPI // CloudFormation service
var stsClient stsiface.STSAPI // STS Service

func init() {
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
cfnClient = cloudformation.New(sess)
stsClient = sts.New(sess)
}

// GetStackInfo populates the StackInfo for this object from the existing stack
// found in the environment
func (s *Stack) GetStackInfo() (err error) {
Expand Down
4 changes: 2 additions & 2 deletions forgelib/mock_cfn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ REQUIRED_PARAMETERS:
}
*m.stacks = append(*m.stacks, thisStack)

if *input.StackPolicyBody != "" {
if input.StackPolicyBody != nil {
(*m.stackPolicies)[m.newStackID] = *input.StackPolicyBody
}

Expand Down Expand Up @@ -168,7 +168,7 @@ REQUIRED_PARAMETERS:
(*m.stacks)[i].Tags = input.Tags
(*m.stacks)[i].Parameters = input.Parameters

if *input.StackPolicyBody != "" {
if input.StackPolicyBody != nil {
(*m.stackPolicies)[*(*m.stacks)[i].StackId] = *input.StackPolicyBody
}

Expand Down
14 changes: 12 additions & 2 deletions forgelib/mock_sts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@ import (
)

type mockSTS struct {
accountID string
accountID string
callerArn string
fakeCredentials *sts.Credentials
stsiface.STSAPI
}

func (m mockSTS) GetCallerIdentity(*sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) {
output := sts.GetCallerIdentityOutput{Account: aws.String(m.accountID)}
output := sts.GetCallerIdentityOutput{
Account: aws.String(m.accountID),
Arn: aws.String(m.callerArn),
}
return &output, nil
}

func (m mockSTS) AssumeRole(*sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
output := sts.AssumeRoleOutput{Credentials: m.fakeCredentials}
return &output, nil
}

0 comments on commit 8535bce

Please sign in to comment.