Skip to content

Commit

Permalink
Fix federated user ID parsing
Browse files Browse the repository at this point in the history
* refactor the parsing logic to make it easier
  • Loading branch information
nckturner committed Nov 28, 2023
1 parent 60c370e commit 9fa86f7
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 40 deletions.
48 changes: 32 additions & 16 deletions pkg/arn/arn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,35 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
)

type PrincipalType int

const (
// Supported principals
NONE PrincipalType = iota
ROLE
USER
ROOT
FEDERATED_USER
ASSUMED_ROLE
)

// Canonicalize validates IAM resources are appropriate for the authenticator
// and converts STS assumed roles into the IAM role resource.
//
// Supported IAM resources are:
// * AWS account: arn:aws:iam::123456789012:root
// * IAM user: arn:aws:iam::123456789012:user/Bob
// * IAM role: arn:aws:iam::123456789012:role/S3Access
// * IAM Assumed role: arn:aws:sts::123456789012:assumed-role/Accounting-Role/Mary (converted to IAM role)
// * Federated user: arn:aws:sts::123456789012:federated-user/Bob
func Canonicalize(arn string) (string, error) {
// - AWS root user: arn:aws:iam::123456789012:root
// - IAM user: arn:aws:iam::123456789012:user/Bob
// - IAM role: arn:aws:iam::123456789012:role/S3Access
// - IAM Assumed role: arn:aws:sts::123456789012:assumed-role/Accounting-Role/Mary (converted to IAM role)
// - Federated user: arn:aws:sts::123456789012:federated-user/Bob
func Canonicalize(arn string) (PrincipalType, string, error) {
parsed, err := awsarn.Parse(arn)
if err != nil {
return "", fmt.Errorf("arn '%s' is invalid: '%v'", arn, err)
return NONE, "", fmt.Errorf("arn '%s' is invalid: '%v'", arn, err)
}

if err := checkPartition(parsed.Partition); err != nil {
return "", fmt.Errorf("arn '%s' does not have a recognized partition", arn)
return NONE, "", fmt.Errorf("arn '%s' does not have a recognized partition", arn)
}

parts := strings.Split(parsed.Resource, "/")
Expand All @@ -34,27 +46,31 @@ func Canonicalize(arn string) (string, error) {
case "sts":
switch resource {
case "federated-user":
return arn, nil
return FEDERATED_USER, arn, nil
case "assumed-role":
if len(parts) < 3 {
return "", fmt.Errorf("assumed-role arn '%s' does not have a role", arn)
return NONE, "", fmt.Errorf("assumed-role arn '%s' does not have a role", arn)
}
// IAM ARNs can contain paths, part[0] is resource, parts[len(parts)] is the SessionName.
role := strings.Join(parts[1:len(parts)-1], "/")
return fmt.Sprintf("arn:%s:iam::%s:role/%s", parsed.Partition, parsed.AccountID, role), nil
return ASSUMED_ROLE, fmt.Sprintf("arn:%s:iam::%s:role/%s", parsed.Partition, parsed.AccountID, role), nil
default:
return "", fmt.Errorf("unrecognized resource %s for service sts", parsed.Resource)
return NONE, "", fmt.Errorf("unrecognized resource %s for service sts", parsed.Resource)
}
case "iam":
switch resource {
case "role", "user", "root":
return arn, nil
case "role":
return ROLE, arn, nil
case "user":
return USER, arn, nil
case "root":
return ROOT, arn, nil
default:
return "", fmt.Errorf("unrecognized resource %s for service iam", parsed.Resource)
return NONE, "", fmt.Errorf("unrecognized resource %s for service iam", parsed.Resource)
}
}

return "", fmt.Errorf("service %s in arn %s is not a valid service for identities", parsed.Service, arn)
return NONE, "", fmt.Errorf("service %s in arn %s is not a valid service for identities", parsed.Service, arn)
}

func checkPartition(partition string) error {
Expand Down
2 changes: 1 addition & 1 deletion pkg/arn/arn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var arnTests = []struct {

func TestUserARN(t *testing.T) {
for _, tc := range arnTests {
actual, err := Canonicalize(tc.arn)
_, actual, err := Canonicalize(tc.arn)
if err != nil && tc.err == nil || err == nil && tc.err != nil {
t.Errorf("Canoncialize(%s) expected err: %v, actual err: %v", tc.arn, tc.err, err)
continue
Expand Down
2 changes: 1 addition & 1 deletion pkg/mapper/crd/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (c *Controller) syncHandler(key string) (err error) {
if iamIdentityMapping.Spec.ARN != "" {
iamIdentityMappingCopy := iamIdentityMapping.DeepCopy()

canonicalizedARN, err := arn.Canonicalize(strings.ToLower(iamIdentityMapping.Spec.ARN))
_, canonicalizedARN, err := arn.Canonicalize(strings.ToLower(iamIdentityMapping.Spec.ARN))
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/mapper/dynamicfile/dynamicfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ func (ms *DynamicFileMapStore) saveMap(
ms.awsAccounts = make(map[string]interface{})

for _, user := range userMappings {
key, _ := arn.Canonicalize(strings.ToLower(user.UserARN))
_, key, _ := arn.Canonicalize(strings.ToLower(user.UserARN))
if ms.userIDStrict {
key = user.UserId
}
ms.users[key] = user
}
for _, role := range roleMappings {
key, _ := arn.Canonicalize(strings.ToLower(role.RoleARN))
_, key, _ := arn.Canonicalize(strings.ToLower(role.RoleARN))
if ms.userIDStrict {
key = role.UserId
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/mapper/file/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package file

import (
"fmt"
"sigs.k8s.io/aws-iam-authenticator/pkg/token"
"strings"

"sigs.k8s.io/aws-iam-authenticator/pkg/token"

"sigs.k8s.io/aws-iam-authenticator/pkg/arn"
"sigs.k8s.io/aws-iam-authenticator/pkg/config"
"sigs.k8s.io/aws-iam-authenticator/pkg/mapper"
Expand Down Expand Up @@ -32,7 +33,7 @@ func NewFileMapper(cfg config.Config) (*FileMapper, error) {
return nil, err
}
if m.RoleARN != "" {
canonicalizedARN, err := arn.Canonicalize(m.RoleARN)
_, canonicalizedARN, err := arn.Canonicalize(m.RoleARN)
if err != nil {
return nil, err
}
Expand All @@ -47,7 +48,7 @@ func NewFileMapper(cfg config.Config) (*FileMapper, error) {
}
var key string
if m.UserARN != "" {
canonicalizedARN, err := arn.Canonicalize(strings.ToLower(m.UserARN))
_, canonicalizedARN, err := arn.Canonicalize(strings.ToLower(m.UserARN))
if err != nil {
return nil, fmt.Errorf("error canonicalizing ARN: %v", err)
}
Expand Down
43 changes: 28 additions & 15 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,29 +600,42 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
return nil, NewSTSError(err.Error())
}

// parse the response into an Identity
id := &Identity{
ARN: callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.Arn,
AccountID: callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.Account,
AccessKeyID: accessKeyID,
}
id.CanonicalARN, err = arn.Canonicalize(id.ARN)
return getIdentityFromSTSResponse(id, callerIdentity)
}

func getIdentityFromSTSResponse(id *Identity, wrapper getCallerIdentityWrapper) (*Identity, error) {
var err error
result := wrapper.GetCallerIdentityResponse.GetCallerIdentityResult

id.ARN = result.Arn
id.AccountID = result.Account

var principalType arn.PrincipalType
principalType, id.CanonicalARN, err = arn.Canonicalize(id.ARN)
if err != nil {
return nil, NewSTSError(err.Error())
}

// The user ID is either UserID:SessionName (for assumed roles) or just
// UserID (for IAM User principals).
userIDParts := strings.Split(callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID, ":")
if len(userIDParts) == 2 {
id.UserID = userIDParts[0]
id.SessionName = userIDParts[1]
} else if len(userIDParts) == 1 {
id.UserID = userIDParts[0]
// The user ID is one of:
// 1. UserID:SessionName (for assumed roles)
// 2. UserID (for IAM User principals).
// 3. AWSAccount:CallerSpecifiedName (for federated users)
// We want the entire UserID for federated users because otherwise,
// its just the account ID and is indistinguishable from the UserID
// of the root user.
if principalType == arn.FEDERATED_USER || principalType == arn.USER || principalType == arn.ROOT {
id.UserID = result.UserID
} else {
return nil, STSError{fmt.Sprintf(
"malformed UserID %q",
callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)}
userIDParts := strings.Split(result.UserID, ":")
if len(userIDParts) == 2 {
id.UserID = userIDParts[0]
id.SessionName = userIDParts[1]
} else {
return nil, NewSTSError(fmt.Sprintf("malformed UserID %q", result.UserID))
}
}

return id, nil
Expand Down
106 changes: 104 additions & 2 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/pkg/apis/clientauthentication"
Expand Down Expand Up @@ -278,7 +279,7 @@ func TestVerifyInvalidCanonicalARNError(t *testing.T) {
}

func TestVerifyInvalidUserIDError(t *testing.T) {
_, err := newVerifier("aws", 200, jsonResponse("arn:aws:iam::123456789012:user/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken)
_, err := newVerifier("aws", 200, jsonResponse("arn:aws:iam::123456789012:role/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken)
errorContains(t, err, "malformed UserID")
assertSTSError(t, err)
}
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestVerifyNoSession(t *testing.T) {
}

func TestVerifySessionName(t *testing.T) {
arn := "arn:aws:iam::123456789012:user/Alice"
arn := "arn:aws:iam::123456789012:role/Alice"
account := "123456789012"
userID := "Alice"
session := "session-name"
Expand Down Expand Up @@ -413,3 +414,104 @@ func TestFormatJson(t *testing.T) {
})
}
}

func TestGetIdentityFromSTSResponse(t *testing.T) {
var (
accessKeyID = "AKIAVVVVVVVVVVVAGAVA"
defaultID = Identity{
AccessKeyID: accessKeyID,
}
defaultAccount = "123456789012"
rootUserARN = "arn:aws:iam::123456789012:root"
userARN = "arn:aws:iam::123456789012:user/Alice"
userID = "AIDAIYCCCMMMMMMMMGGDA"
fedUserID = "123456789012:Alice"
fedUserARN = "arn:aws:sts::123456789012:federated-user/Alice"
roleARN = "arn:aws:iam::123456789012:role/Alice"
roleID = "AROAZZCCCNNNNNNNNFFFA"
)

cases := []struct {
name string
inputID Identity
inputResponse getCallerIdentityWrapper
expectedErr bool
want Identity
}{
{
name: "Root User",
inputID: defaultID,
inputResponse: response(defaultAccount, defaultAccount, rootUserARN),
expectedErr: false,
want: Identity{
ARN: rootUserARN,
CanonicalARN: rootUserARN,
AccountID: defaultAccount,
UserID: defaultAccount,
AccessKeyID: accessKeyID,
},
},
{
name: "User",
inputID: defaultID,
inputResponse: response(defaultAccount, userID, userARN),
expectedErr: false,
want: Identity{
ARN: userARN,
CanonicalARN: userARN,
AccountID: defaultAccount,
UserID: userID,
AccessKeyID: accessKeyID,
},
},
{
name: "Role",
inputID: defaultID,
inputResponse: response(defaultAccount, roleID, roleARN),
expectedErr: false,
want: Identity{
ARN: roleARN,
CanonicalARN: roleARN,
AccountID: defaultAccount,
UserID: roleID,
AccessKeyID: accessKeyID,
},
},
{
name: "Federated User",
inputID: defaultID,
inputResponse: response(defaultAccount, fedUserID, fedUserARN),
expectedErr: false,
want: Identity{
ARN: fedUserARN,
CanonicalARN: fedUserARN,
AccountID: defaultAccount,
UserID: fedUserID,
AccessKeyID: accessKeyID,
},
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {

if got, err := getIdentityFromSTSResponse(&c.inputID, c.inputResponse); err == nil {
if c.expectedErr {
t.Errorf("expected err to be nil but was %s", err)
}

if diff := cmp.Diff(c.want, *got); diff != "" {
t.Errorf("getIdentityFromSTSResponse() mismatch (-want +got):\n%s", diff)
}
}
})
}
}

func response(account, userID, arn string) getCallerIdentityWrapper {
wrapper := getCallerIdentityWrapper{}
wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.Account = account
wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.Arn = arn
wrapper.GetCallerIdentityResponse.GetCallerIdentityResult.UserID = userID
wrapper.GetCallerIdentityResponse.ResponseMetadata.RequestID = "id1234"
return wrapper
}

0 comments on commit 9fa86f7

Please sign in to comment.