diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index c18eb0773766..03f531c994f3 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -5,6 +5,7 @@ package runtime import ( "context" + "encoding/base64" "fmt" "net/http" "strings" @@ -45,11 +46,28 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTok p.scopes = make([]string, len(opts.Scopes)) copy(p.scopes, opts.Scopes) p.btp = azruntime.NewBearerTokenPolicy(cred, opts.Scopes, &azpolicy.BearerTokenOptions{ - AuthorizationHandler: azpolicy.AuthorizationHandler{OnRequest: p.onRequest}, + AuthorizationHandler: azpolicy.AuthorizationHandler{ + OnChallenge: p.onChallenge, + OnRequest: p.onRequest, + }, }) return p } +func (b *BearerTokenPolicy) onChallenge(req *azpolicy.Request, res *http.Response, authNZ func(azpolicy.TokenRequestOptions) error) error { + challenge := res.Header.Get(shared.HeaderWWWAuthenticate) + claims, err := parseChallenge(challenge) + if err != nil { + // the challenge contains claims we can't parse + return err + } else if claims != "" { + // request a new token having the specified claims, send the request again + return authNZ(azpolicy.TokenRequestOptions{Claims: claims, Scopes: b.scopes}) + } + // auth challenge didn't include claims, so this is a simple authorization failure + return azruntime.NewResponseError(res) +} + // onRequest authorizes requests with one or more bearer tokens func (b *BearerTokenPolicy) onRequest(req *azpolicy.Request, authNZ func(azpolicy.TokenRequestOptions) error) error { // authorize the request with a token for the primary tenant @@ -79,3 +97,31 @@ func (b *BearerTokenPolicy) onRequest(req *azpolicy.Request, authNZ func(azpolic func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) { return b.btp.Do(req) } + +// parseChallenge parses claims from an authentication challenge issued by ARM so a client can request a token +// that will satisfy conditional access policies. It returns a non-nil error when the given value contains +// claims it can't parse. If the value contains no claims, it returns an empty string and a nil error. +func parseChallenge(wwwAuthenticate string) (string, error) { + claims := "" + var err error + for _, param := range strings.Split(wwwAuthenticate, ",") { + if _, after, found := strings.Cut(param, "claims="); found { + if claims != "" { + // The header contains multiple challenges, at least two of which specify claims. The specs allow this + // but it's unclear what a client should do in this case and there's as yet no concrete example of it. + err = fmt.Errorf("found multiple claims challenges in %q", wwwAuthenticate) + break + } + // trim stuff that would get an error from RawURLEncoding; claims may or may not be padded + claims = strings.Trim(after, `\"=`) + // we don't return this error because it's something unhelpful like "illegal base64 data at input byte 42" + if b, decErr := base64.RawURLEncoding.DecodeString(claims); decErr == nil { + claims = string(b) + } else { + err = fmt.Errorf("failed to parse claims from %q", wwwAuthenticate) + break + } + } + } + return claims, err +} diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index 52fa5223e3c9..eefb2db99cea 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -204,7 +204,6 @@ func TestAuxiliaryTenants(t *testing.T) { } func TestBearerTokenPolicyChallengeParsing(t *testing.T) { - t.Skip("unskip this test after adding back CAE support") for _, test := range []struct { challenge, desc, expectedClaims string err error @@ -263,10 +262,9 @@ func TestBearerTokenPolicyChallengeParsing(t *testing.T) { cred := mockCredential{ getTokenImpl: func(ctx context.Context, actual azpolicy.TokenRequestOptions) (azcore.AccessToken, error) { calls += 1 - // TODO: uncomment after restoring TokenRequestOptions.Claims - // if calls == 2 && test.expectedClaims != "" { - // require.Equal(t, test.expectedClaims, actual.Claims) - // } + if calls == 2 && test.expectedClaims != "" { + require.Equal(t, test.expectedClaims, actual.Claims) + } return azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour).UTC()}, nil }, } diff --git a/sdk/azcore/internal/exported/exported.go b/sdk/azcore/internal/exported/exported.go index 19861cde4165..2407f8e3e8bd 100644 --- a/sdk/azcore/internal/exported/exported.go +++ b/sdk/azcore/internal/exported/exported.go @@ -51,6 +51,10 @@ type AccessToken struct { // TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. // Exported as policy.TokenRequestOptions. type TokenRequestOptions struct { + // Claims are any additional claims required for the token to satisfy a conditional access policy, such as a + // service may return in a claims challenge following an authorization failure. If a service returned the + // claims value base64 encoded, it must be decoded before setting this field. + Claims string // Scopes contains the list of permission scopes required for the token. Scopes []string }