diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 3967ec2138bd..88e2945e5226 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -23,11 +23,10 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { - s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.ctx, shared.TokenRequestOptions{ - Scopes: s.p.options.Scopes, - TenantID: s.tenant, +func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) { + tk, err := state.p.cred.GetToken(state.ctx, shared.TokenRequestOptions{ + Scopes: state.p.options.Scopes, + TenantID: state.tenant, }) if err != nil { return nil, time.Time{}, err @@ -38,9 +37,9 @@ func acquire(state interface{}) (newResource interface{}, newExpiration time.Tim // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource + mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] // auxResources are additional resources that are required for cross-tenant applications - auxResources map[string]*shared.ExpiringResource + auxResources map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] // the following fields are read-only cred shared.TokenCredential options armpolicy.BearerTokenOptions @@ -59,7 +58,7 @@ func NewBearerTokenPolicy(cred shared.TokenCredential, opts *armpolicy.BearerTok mainResource: shared.NewExpiringResource(acquire), } if len(opts.AuxiliaryTenants) > 0 { - p.auxResources = map[string]*shared.ExpiringResource{} + p.auxResources = map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]{} } for _, t := range opts.AuxiliaryTenants { p.auxResources[t] = shared.NewExpiringResource(acquire) @@ -78,9 +77,7 @@ func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) { if err != nil { return nil, err } - if token, ok := tk.(*shared.AccessToken); ok { - req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) - } + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token) auxTokens := []string{} for tenant, er := range b.auxResources { as.tenant = tenant @@ -88,7 +85,7 @@ func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) { if err != nil { return nil, err } - auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.(*shared.AccessToken).Token)) + auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.Token)) } if len(auxTokens) > 0 { req.Raw().Header.Set(shared.HeaderAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) diff --git a/sdk/azcore/internal/shared/expiring_resource.go b/sdk/azcore/internal/shared/expiring_resource.go index 5ff6b3d1d676..f6eae3f5178d 100644 --- a/sdk/azcore/internal/shared/expiring_resource.go +++ b/sdk/azcore/internal/shared/expiring_resource.go @@ -12,10 +12,10 @@ import ( ) // AcquireResource abstracts a method for refreshing an expiring resource. -type AcquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) +type AcquireResource[T, U any] func(state U) (newResource T, newExpiration time.Time, err error) // ExpiringResource is a temporal resource (usually a credential) that requires periodic refreshing. -type ExpiringResource struct { +type ExpiringResource[T, U any] struct { // cond is used to synchronize access to the shared resource embodied by the remaining fields cond *sync.Cond @@ -23,7 +23,7 @@ type ExpiringResource struct { acquiring bool // resource contains the value of the shared resource - resource interface{} + resource T // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired expiration time.Time @@ -32,17 +32,17 @@ type ExpiringResource struct { lastAttempt time.Time // acquireResource is the callback function that actually acquires the resource - acquireResource AcquireResource + acquireResource AcquireResource[T, U] } // NewExpiringResource creates a new ExpiringResource that uses the specified AcquireResource for refreshing. -func NewExpiringResource(ar AcquireResource) *ExpiringResource { - return &ExpiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +func NewExpiringResource[T, U any](ar AcquireResource[T, U]) *ExpiringResource[T, U] { + return &ExpiringResource[T, U]{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} } // GetResource returns the underlying resource. // If the resource is fresh, no refresh is performed. -func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) { +func (er *ExpiringResource[T, U]) GetResource(state U) (T, error) { // If the resource is expiring within this time window, update it eagerly. // This allows other threads/goroutines to keep running by using the not-yet-expired // resource value while one thread/goroutine updates the resource. @@ -87,7 +87,7 @@ func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) if acquire { // This thread/goroutine has been selected to acquire/update the resource var expiration time.Time - var newValue interface{} + var newValue T er.lastAttempt = now newValue, expiration, err = er.acquireResource(state) diff --git a/sdk/azcore/internal/shared/expiring_resource_test.go b/sdk/azcore/internal/shared/expiring_resource_test.go index 02e0783c7281..075cb1baffee 100644 --- a/sdk/azcore/internal/shared/expiring_resource_test.go +++ b/sdk/azcore/internal/shared/expiring_resource_test.go @@ -15,15 +15,14 @@ import ( ) func TestNewExpiringResource(t *testing.T) { - er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { - s := state.(string) - switch s { + er := NewExpiringResource(func(state string) (newResource string, newExpiration time.Time, err error) { + switch state { case "initial": return "updated", time.Now().Add(-time.Minute), nil case "updated": return "refreshed", time.Now().Add(1 * time.Hour), nil default: - t.Fatalf("unexpected state %s", s) + t.Fatalf("unexpected state %s", state) return "", time.Time{}, errors.New("unexpected") } }) @@ -42,7 +41,7 @@ func TestExpiringResourceError(t *testing.T) { expectedState := "expected state" expectedError := "expected error" calls := 0 - er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + er := NewExpiringResource(func(state string) (newResource string, newExpiration time.Time, err error) { calls += 1 if calls == 1 { return expectedState, time.Now().Add(time.Minute), nil diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 187642d7fbfe..75a23f035322 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -14,7 +14,7 @@ import ( // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource + mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] // the following fields are read-only cred shared.TokenCredential scopes []string @@ -27,9 +27,8 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { - s := state.(acquiringResourceState) - tk, err := s.p.cred.GetToken(s.req.Raw().Context(), shared.TokenRequestOptions{Scopes: s.p.scopes}) +func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) { + tk, err := state.p.cred.GetToken(state.req.Raw().Context(), shared.TokenRequestOptions{Scopes: state.p.scopes}) if err != nil { return nil, time.Time{}, err } @@ -58,8 +57,6 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) { if err != nil { return nil, err } - if token, ok := tk.(*shared.AccessToken); ok { - req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token) - } + req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token) return req.Next() }