diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 83f1bf86e65e..54b3bb78d859 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -28,12 +28,12 @@ type acquiringResourceState struct { tenant string } -// acquire acquires or updates the resource; only one -// thread/goroutine at a time ever calls this function -func acquire(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) { +// acquireAuxToken acquires a token from an auxiliary tenant. Only one thread/goroutine at a time ever calls this function. +func acquireAuxToken(state acquiringResourceState) (newResource azcore.AccessToken, newExpiration time.Time, err error) { tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{ - Scopes: state.p.scopes, - TenantID: state.tenant, + EnableCAE: true, + Scopes: state.p.scopes, + TenantID: state.tenant, }) if err != nil { return azcore.AccessToken{}, time.Time{}, err @@ -59,7 +59,7 @@ func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTok p := &BearerTokenPolicy{cred: cred} p.auxResources = make(map[string]*temporal.Resource[azcore.AccessToken, acquiringResourceState], len(opts.AuxiliaryTenants)) for _, t := range opts.AuxiliaryTenants { - p.auxResources[t] = temporal.NewResource(acquire) + p.auxResources[t] = temporal.NewResource(acquireAuxToken) } p.scopes = make([]string, len(opts.Scopes)) copy(p.scopes, opts.Scopes) @@ -80,7 +80,7 @@ func (b *BearerTokenPolicy) onChallenge(req *azpolicy.Request, res *http.Respons return err } else if claims != "" { // request a new token having the specified claims, send the request again - return authNZ(azpolicy.TokenRequestOptions{Claims: claims, Scopes: b.scopes}) + return authNZ(azpolicy.TokenRequestOptions{Claims: claims, EnableCAE: true, Scopes: b.scopes}) } // auth challenge didn't include claims, so this is a simple authorization failure return azruntime.NewResponseError(res) @@ -89,7 +89,7 @@ func (b *BearerTokenPolicy) onChallenge(req *azpolicy.Request, res *http.Respons // onRequest authorizes requests with one or more bearer tokens func (b *BearerTokenPolicy) onRequest(req *azpolicy.Request, authNZ func(azpolicy.TokenRequestOptions) error) error { // authorize the request with a token for the primary tenant - err := authNZ(azpolicy.TokenRequestOptions{Scopes: b.scopes}) + err := authNZ(azpolicy.TokenRequestOptions{EnableCAE: true, Scopes: b.scopes}) if err != nil || len(b.auxResources) == 0 { return err } diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index 1ab06ae00d76..b062a50f3ef7 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -34,6 +34,9 @@ type mockCredential struct { } func (mc mockCredential) GetToken(ctx context.Context, options azpolicy.TokenRequestOptions) (azcore.AccessToken, error) { + if !options.EnableCAE { + return azcore.AccessToken{}, errors.New("ARM clients should set EnableCAE to true") + } if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) }