Skip to content

Commit

Permalink
Add the type casting for audience
Browse files Browse the repository at this point in the history
  • Loading branch information
kovayur committed May 21, 2024
1 parent a40721e commit fe51ad9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
23 changes: 18 additions & 5 deletions pkg/auth/acs_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,25 @@ func (c *ACSClaims) GetSubject() (string, error) {
}

// GetAudience returns the audience claim of the token. It identifies the token consumer.
func (c *ACSClaims) GetAudience() (interface{}, error) {
if aud, ok := (*c)[audienceClaim]; ok {
return aud, nil
func (c *ACSClaims) GetAudience() ([]string, error) {
aud := make([]string, 0)
switch v := (*c)[audienceClaim].(type) {
case string:
aud = append(aud, v)
case []string:
aud = v
case []interface{}:
for _, a := range v {
vs, ok := a.(string)
if !ok {
return nil, fmt.Errorf("can't parse part of the audience claim: %q", a)
}
aud = append(aud, vs)
}
default:
return nil, fmt.Errorf("can't parse the audience claim: %q", v)
}

return "", fmt.Errorf("can't find %q attribute in claims", audienceClaim)
return aud, nil
}

// IsOrgAdmin ...
Expand Down
59 changes: 59 additions & 0 deletions pkg/auth/acs_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,62 @@ func TestACSClaims_IsOrgAdmin(t *testing.T) {
})
}
}

func TestACSClaims_Audience(t *testing.T) {
tests := map[string]struct {
claims ACSClaims
expectValues []string
expectError bool
}{
"should parse the audience claim as string": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: "test",
}),
expectValues: []string{"test"},
},
"should parse the audience claim as an array of strings": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: []string{"test1", "test2"},
}),
expectValues: []string{"test1", "test2"},
},
"should parse the audience claim as an array of interfaces": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: []interface{}{"test"},
}),
expectValues: []string{"test"},
},
"should return error if there's no claim": {
claims: ACSClaims(jwt.MapClaims{}),
expectError: true,
},
"should return empty slice if the claim is empty array": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: []string{},
}),
expectValues: []string{},
},
"should return empty slice if the claim is empty interface": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: []interface{}{},
}),
expectValues: []string{},
},
"should return error if can't parse the claim": {
claims: ACSClaims(jwt.MapClaims{
audienceClaim: 123,
}),
expectError: true,
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
audience, err := tt.claims.GetAudience()
assert.Equal(t, tt.expectError, err != nil)
if !tt.expectError {
assert.Equal(t, tt.expectValues, audience)
}
})
}
}
17 changes: 5 additions & 12 deletions pkg/auth/fleetshard_authz_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,18 @@ func checkAudience(allowedAudiences []string) mux.MiddlewareFunc {
return
}

audienceAccepted := false
for _, audience := range allowedAudiences {
if claims.VerifyAudience(audience, true) {
audienceAccepted = true
next.ServeHTTP(writer, request)
break
}
}

if !audienceAccepted {
audience, _ := claims.GetAudience()
glog.Infof("none of the audiences (%q) in the access token is not in the list of allowed values [%s]",
audience, strings.Join(allowedAudiences, ","))

shared.HandleError(request, writer, errors.NotFound(""))
return
}

next.ServeHTTP(writer, request)
audience, _ := claims.GetAudience()
glog.Infof("none of the audiences [%s] in the access token is not in the list of allowed values [%s]",
strings.Join(audience, ","), strings.Join(allowedAudiences, ","))

shared.HandleError(request, writer, errors.NotFound(""))
})
}
}
Expand Down

0 comments on commit fe51ad9

Please sign in to comment.