Skip to content

Commit

Permalink
Merge pull request #685 from stanvit/session-tags
Browse files Browse the repository at this point in the history
STS AssumeRole Session Tags implementation
  • Loading branch information
mtibben authored Mar 18, 2021
2 parents 351fe9c + a3e7912 commit 4270224
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 13 deletions.
19 changes: 19 additions & 0 deletions USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ include_profile = root
role_arn=arn:aws:iam::123456789:role/administrators
```

#### `session_tags` and `transitive_session_tags`

It is possible to set [session tags](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_session-tags.html) when `AssumeRole` is used. Two custom config variables could be defined for that: `session_tags` and `transitive_session_tags`. The former defines a comma separated key=value list of tags and the latter is a comma separated list of tags that should be persited during role chaining:

```ini
[profile root]
region=eu-west-1

[profile order-dev]
source_profile = root
role_arn=arn:aws:iam::123456789:role/developers
session_tags = key1=value1,key2=value2,key3=value3
transitive_session_tags = key1,key2
```


### Environment variables

Expand Down Expand Up @@ -132,6 +147,10 @@ To override session durations (used in `exec` and `login`):

Note that the session durations above expect a unit after the number (e.g. 12h or 43200s).

To override or set session tagging (used in `exec`):
* `AWS_ROLE_TAGS`: Comma separated key-value list of tags passed with the `AssumeRole` call, overrides `session_tags` profile config variable
* `AWS_TRANSITIVE_TAGS`: Comma separated list of transitive tags passed with the `AssumeRole` call, overrides `transitive_session_tags` profile config variable


## Backends

Expand Down
29 changes: 23 additions & 6 deletions vault/assumeroleprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (

// AssumeRoleProvider retrieves temporary credentials from STS using AssumeRole
type AssumeRoleProvider struct {
StsClient *sts.STS
RoleARN string
RoleSessionName string
ExternalID string
Duration time.Duration
ExpiryWindow time.Duration
StsClient *sts.STS
RoleARN string
RoleSessionName string
ExternalID string
Duration time.Duration
ExpiryWindow time.Duration
Tags map[string]string
TransitiveTagKeys []string
Mfa
credentials.Expiry
}
Expand Down Expand Up @@ -67,6 +69,21 @@ func (p *AssumeRoleProvider) assumeRole() (*sts.Credentials, error) {
}
}

if len(p.Tags) > 0 {
input.Tags = make([]*sts.Tag, 0)
for key, value := range p.Tags {
tag := &sts.Tag{
Key: aws.String(key),
Value: aws.String(value),
}
input.Tags = append(input.Tags, tag)
}
}

if len(p.TransitiveTagKeys) > 0 {
input.TransitiveTagKeys = aws.StringSlice(p.TransitiveTagKeys)
}

log.Printf("Using STS endpoint %s", p.StsClient.Endpoint)

resp, err := p.StsClient.AssumeRole(input)
Expand Down
55 changes: 54 additions & 1 deletion vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ type ProfileSection struct {
WebIdentityTokenFile string `ini:"web_identity_token_file,omitempty"`
WebIdentityTokenProcess string `ini:"web_identity_token_process,omitempty"`
STSRegionalEndpoints string `ini:"sts_regional_endpoints,omitempty"`
SessionTags string `ini:"session_tags,omitempty"`
TransitiveSessionTags string `ini:"transitive_session_tags,omitempty"`
}

func (s ProfileSection) IsEmpty() bool {
Expand Down Expand Up @@ -326,6 +328,15 @@ func (cl *ConfigLoader) populateFromConfigFile(config *Config, profileName strin
if config.STSRegionalEndpoints == "" {
config.STSRegionalEndpoints = psection.STSRegionalEndpoints
}
if sessionTags := psection.SessionTags; sessionTags != "" && config.SessionTags == nil {
err := config.SetSessionTags(sessionTags)
if err != nil {
return fmt.Errorf("Failed to parse session_tags profile setting: %s", err)
}
}
if transitiveSessionTags := psection.TransitiveSessionTags; transitiveSessionTags != "" && config.TransitiveSessionTags == nil {
config.SetTransitiveSessionTags(transitiveSessionTags)
}

if psection.ParentProfile != "" {
fmt.Fprint(os.Stderr, "Warning: parent_profile is deprecated, please use include_profile instead in your AWS config\n")
Expand Down Expand Up @@ -406,7 +417,7 @@ func (cl *ConfigLoader) populateFromEnv(profile *Config) {
}
}

// AWS_ROLE_ARN and AWS_ROLE_SESSION_NAME only apply to the target profile
// AWS_ROLE_ARN, AWS_ROLE_SESSION_NAME, AWS_SESSION_TAGS and AWS_TRANSITIVE_TAGS only apply to the target profile
if profile.ProfileName == cl.ActiveProfile {
if roleARN := os.Getenv("AWS_ROLE_ARN"); roleARN != "" && profile.RoleARN == "" {
log.Printf("Using role_arn %q from AWS_ROLE_ARN", roleARN)
Expand All @@ -417,6 +428,19 @@ func (cl *ConfigLoader) populateFromEnv(profile *Config) {
log.Printf("Using role_session_name %q from AWS_ROLE_SESSION_NAME", roleSessionName)
profile.RoleSessionName = roleSessionName
}

if sessionTags := os.Getenv("AWS_SESSION_TAGS"); sessionTags != "" && profile.SessionTags == nil {
err := profile.SetSessionTags(sessionTags)
if err != nil {
log.Fatalf("Failed to parse AWS_SESSION_TAGS environment variable: %s", err)
}
log.Printf("Using session_tags %v from AWS_SESSION_TAGS", profile.SessionTags)
}

if transitiveSessionTags := os.Getenv("AWS_TRANSITIVE_TAGS"); transitiveSessionTags != "" && profile.TransitiveSessionTags == nil {
profile.SetTransitiveSessionTags(transitiveSessionTags)
log.Printf("Using transitive_session_tags %v from AWS_TRANSITIVE_TAGS", profile.TransitiveSessionTags)
}
}
}

Expand Down Expand Up @@ -511,6 +535,35 @@ type Config struct {

// SSORoleName specifies the AWS SSO Role name to target.
SSORoleName string

// SessionTags specifies assumed role Session Tags
SessionTags map[string]string

// TransitiveSessionTags specifies assumed role Transitive Session Tags keys
TransitiveSessionTags []string
}

// SetSessionTags parses a comma separated key=vaue string and sets Config.SessionTags map
func (c *Config) SetSessionTags(s string) error {
c.SessionTags = make(map[string]string)
for _, tag := range strings.Split(s, ",") {
kvPair := strings.SplitN(tag, "=", 2)
if len(kvPair) != 2 {
return errors.New("session tags string must be <key1>=<value1>,[<key2>=<value2>[,...]]")
}
c.SessionTags[strings.TrimSpace(kvPair[0])] = strings.TrimSpace(kvPair[1])
}

return nil
}

// SetTransitiveSessionTags parses a comma separated string and sets Config.TransitiveSessionTags
func (c *Config) SetTransitiveSessionTags(s string) {
for _, tag := range strings.Split(s, ",") {
if tag = strings.TrimSpace(tag); tag != "" {
c.TransitiveSessionTags = append(c.TransitiveSessionTags, tag)
}
}
}

func (c *Config) IsChained() bool {
Expand Down
201 changes: 201 additions & 0 deletions vault/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,204 @@ source_profile=root
t.Fatalf("Expected '%s', got '%s'", expectedSourceProfileName, config.SourceProfileName)
}
}

func TestSetSessionTags(t *testing.T) {
var testCases = []struct {
stringValue string
expected map[string]string
ok bool
}{
{"tag1=value1", map[string]string{"tag1": "value1"}, true},
{
"tag2=value2,tag3=value3,tag4=value4",
map[string]string{"tag2": "value2", "tag3": "value3", "tag4": "value4"},
true,
},
{" tagA = valueA , tagB = valueB , tagC = valueC ",
map[string]string{"tagA": "valueA", "tagB": "valueB", "tagC": "valueC"},
true,
},
{"", nil, false},
{"tag1=value1,", nil, false},
{"tagA=valueA,tagB", nil, false},
{"tagOne,tagTwo=valueTwo", nil, false},
{"tagI=valueI,tagII,tagIII=valueIII", nil, false},
}

for _, tc := range testCases {
config := vault.Config{}
err := config.SetSessionTags(tc.stringValue)
if tc.ok {
if err != nil {
t.Fatalf("Unsexpected parsing error: %s", err)
}
if !reflect.DeepEqual(tc.expected, config.SessionTags) {
t.Fatalf("Expected SessionTags: %+v, got %+v", tc.expected, config.SessionTags)
}
} else {
if err == nil {
t.Fatalf("Expected an error parsing %#v, but got none", tc.stringValue)
}
}
}
}

func TestSetTransitiveSessionTags(t *testing.T) {
var testCases = []struct {
stringValue string
expected []string
}{
{"tag1", []string{"tag1"}},
{"tag2,tag3,tag4", []string{"tag2", "tag3", "tag4"}},
{" tagA , tagB , tagC ", []string{"tagA", "tagB", "tagC"}},
{"tag1,", []string{"tag1"}},
{",tagA", []string{"tagA"}},
{"", nil},
{",", nil},
}

for _, tc := range testCases {
config := vault.Config{}
config.SetTransitiveSessionTags(tc.stringValue)
if !reflect.DeepEqual(tc.expected, config.TransitiveSessionTags) {
t.Fatalf("Expected TransitiveSessionTags: %+v, got %+v", tc.expected, config.TransitiveSessionTags)
}
}
}

func TestSessionTaggingFromIni(t *testing.T) {
os.Unsetenv("AWS_SESSION_TAGS")
os.Unsetenv("AWS_TRANSITIVE_TAGS")
f := newConfigFile(t, []byte(`
[profile tagged]
session_tags = tag1 = value1 , tag2=value2 ,tag3=value3
transitive_session_tags = tagOne ,tagTwo,tagThree
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"}
config, err := configLoader.LoadFromProfile("tagged")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
expectedSessionTags := map[string]string{
"tag1": "value1",
"tag2": "value2",
"tag3": "value3",
}
if !reflect.DeepEqual(expectedSessionTags, config.SessionTags) {
t.Fatalf("Expected session_tags: %+v, got %+v", expectedSessionTags, config.SessionTags)
}

expectedTransitiveSessionTags := []string{"tagOne", "tagTwo", "tagThree"}
if !reflect.DeepEqual(expectedTransitiveSessionTags, config.TransitiveSessionTags) {
t.Fatalf("Expected transitive_session_tags: %+v, got %+v", expectedTransitiveSessionTags, config.TransitiveSessionTags)
}
}

func TestSessionTaggingFromEnvironment(t *testing.T) {
os.Setenv("AWS_SESSION_TAGS", " tagA = val1 , tagB=val2 ,tagC=val3")
os.Setenv("AWS_TRANSITIVE_TAGS", " tagD ,tagE")
defer os.Unsetenv("AWS_SESSION_TAGS")
defer os.Unsetenv("AWS_TRANSITIVE_TAGS")

f := newConfigFile(t, []byte(`
[profile tagged]
session_tags = tag1 = value1 , tag2=value2 ,tag3=value3
transitive_session_tags = tagOne ,tagTwo,tagThree
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "tagged"}
config, err := configLoader.LoadFromProfile("tagged")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}
expectedSessionTags := map[string]string{
"tagA": "val1",
"tagB": "val2",
"tagC": "val3",
}
if !reflect.DeepEqual(expectedSessionTags, config.SessionTags) {
t.Fatalf("Expected session_tags: %+v, got %+v", expectedSessionTags, config.SessionTags)
}

expectedTransitiveSessionTags := []string{"tagD", "tagE"}
if !reflect.DeepEqual(expectedTransitiveSessionTags, config.TransitiveSessionTags) {
t.Fatalf("Expected transitive_session_tags: %+v, got %+v", expectedTransitiveSessionTags, config.TransitiveSessionTags)
}
}

func TestSessionTaggingFromEnvironmentChainedRoles(t *testing.T) {
os.Setenv("AWS_SESSION_TAGS", "tagI=valI")
os.Setenv("AWS_TRANSITIVE_TAGS", " tagII")
defer os.Unsetenv("AWS_SESSION_TAGS")
defer os.Unsetenv("AWS_TRANSITIVE_TAGS")

f := newConfigFile(t, []byte(`
[profile base]
[profile interim]
session_tags=tag1=value1
transitive_session_tags=tag2
source_profile = base
[profile target]
session_tags=tagA=valueA
transitive_session_tags=tagB
source_profile = interim
`))
defer os.Remove(f)

configFile, err := vault.LoadConfig(f)
if err != nil {
t.Fatal(err)
}
configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "target"}
config, err := configLoader.LoadFromProfile("target")
if err != nil {
t.Fatalf("Should have found a profile: %v", err)
}

// Testing target profile, should have values populated from environment variables
expectedSessionTags := map[string]string{"tagI": "valI"}
if !reflect.DeepEqual(expectedSessionTags, config.SessionTags) {
t.Fatalf("Expected session_tags: %+v, got %+v", expectedSessionTags, config.SessionTags)
}

expectedTransitiveSessionTags := []string{"tagII"}
if !reflect.DeepEqual(expectedTransitiveSessionTags, config.TransitiveSessionTags) {
t.Fatalf("Expected transitive_session_tags: %+v, got %+v", expectedTransitiveSessionTags, config.TransitiveSessionTags)
}

// Testing interim profile, parameters should come from the config, not environment
interimConfig := config.SourceProfile
expectedSessionTags = map[string]string{"tag1": "value1"}
if !reflect.DeepEqual(expectedSessionTags, interimConfig.SessionTags) {
t.Fatalf("Expected session_tags: %+v, got %+v", expectedSessionTags, interimConfig.SessionTags)
}

expectedTransitiveSessionTags = []string{"tag2"}
if !reflect.DeepEqual(expectedTransitiveSessionTags, interimConfig.TransitiveSessionTags) {
t.Fatalf("Expected transitive_session_tags: %+v, got %+v", expectedTransitiveSessionTags, interimConfig.TransitiveSessionTags)
}

// Testing base profile, should have empty parameters
baseConfig := interimConfig.SourceProfile
if len(baseConfig.SessionTags) > 0 {
t.Fatalf("Expected session_tags to be empty, got %+v", baseConfig.SessionTags)
}

expectedTransitiveSessionTags = []string{}
if len(baseConfig.TransitiveSessionTags) > 0 {
t.Fatalf("Expected transitive_session_tags to be empty, got %+v", baseConfig.TransitiveSessionTags)
}
}
14 changes: 8 additions & 6 deletions vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,14 @@ func NewAssumeRoleProvider(creds *credentials.Credentials, k keyring.Keyring, co
}

p := &AssumeRoleProvider{
StsClient: sts.New(sess),
RoleARN: config.RoleARN,
RoleSessionName: config.RoleSessionName,
ExternalID: config.ExternalID,
Duration: config.AssumeRoleDuration,
ExpiryWindow: defaultExpirationWindow,
StsClient: sts.New(sess),
RoleARN: config.RoleARN,
RoleSessionName: config.RoleSessionName,
ExternalID: config.ExternalID,
Duration: config.AssumeRoleDuration,
ExpiryWindow: defaultExpirationWindow,
Tags: config.SessionTags,
TransitiveTagKeys: config.TransitiveSessionTags,
Mfa: Mfa{
MfaSerial: config.MfaSerial,
MfaToken: config.MfaToken,
Expand Down

0 comments on commit 4270224

Please sign in to comment.