Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate expiring resource to generics #16974

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions sdk/azcore/arm/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ type acquiringResourceState struct {

// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
s := state.(acquiringResourceState)
tk, err := s.p.cred.GetToken(s.ctx, shared.TokenRequestOptions{
Scopes: s.p.options.Scopes,
TenantID: s.tenant,
func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) {
tk, err := state.p.cred.GetToken(state.ctx, shared.TokenRequestOptions{
Scopes: state.p.options.Scopes,
TenantID: state.tenant,
})
if err != nil {
return nil, time.Time{}, err
Expand All @@ -38,9 +37,9 @@ func acquire(state interface{}) (newResource interface{}, newExpiration time.Tim
// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
type BearerTokenPolicy struct {
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *shared.ExpiringResource
mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]
// auxResources are additional resources that are required for cross-tenant applications
auxResources map[string]*shared.ExpiringResource
auxResources map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]
// the following fields are read-only
cred shared.TokenCredential
options armpolicy.BearerTokenOptions
Expand All @@ -59,7 +58,7 @@ func NewBearerTokenPolicy(cred shared.TokenCredential, opts *armpolicy.BearerTok
mainResource: shared.NewExpiringResource(acquire),
}
if len(opts.AuxiliaryTenants) > 0 {
p.auxResources = map[string]*shared.ExpiringResource{}
p.auxResources = map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]{}
}
for _, t := range opts.AuxiliaryTenants {
p.auxResources[t] = shared.NewExpiringResource(acquire)
Expand All @@ -78,17 +77,15 @@ func (b *BearerTokenPolicy) Do(req *azpolicy.Request) (*http.Response, error) {
if err != nil {
return nil, err
}
if token, ok := tk.(*shared.AccessToken); ok {
req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token)
}
req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token)
auxTokens := []string{}
for tenant, er := range b.auxResources {
as.tenant = tenant
auxTk, err := er.GetResource(as)
if err != nil {
return nil, err
}
auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.(*shared.AccessToken).Token))
auxTokens = append(auxTokens, fmt.Sprintf("%s%s", shared.BearerTokenPrefix, auxTk.Token))
}
if len(auxTokens) > 0 {
req.Raw().Header.Set(shared.HeaderAuxiliaryAuthorization, strings.Join(auxTokens, ", "))
Expand Down
16 changes: 8 additions & 8 deletions sdk/azcore/internal/shared/expiring_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ import (
)

// AcquireResource abstracts a method for refreshing an expiring resource.
type AcquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error)
type AcquireResource[T, U any] func(state U) (newResource T, newExpiration time.Time, err error)

// ExpiringResource is a temporal resource (usually a credential) that requires periodic refreshing.
type ExpiringResource struct {
type ExpiringResource[T, U any] struct {
// cond is used to synchronize access to the shared resource embodied by the remaining fields
cond *sync.Cond

// acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource
acquiring bool

// resource contains the value of the shared resource
resource interface{}
resource T

// expiration indicates when the shared resource expires; it is 0 if the resource was never acquired
expiration time.Time
Expand All @@ -32,17 +32,17 @@ type ExpiringResource struct {
lastAttempt time.Time

// acquireResource is the callback function that actually acquires the resource
acquireResource AcquireResource
acquireResource AcquireResource[T, U]
}

// NewExpiringResource creates a new ExpiringResource that uses the specified AcquireResource for refreshing.
func NewExpiringResource(ar AcquireResource) *ExpiringResource {
return &ExpiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
func NewExpiringResource[T, U any](ar AcquireResource[T, U]) *ExpiringResource[T, U] {
return &ExpiringResource[T, U]{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
}

// GetResource returns the underlying resource.
// If the resource is fresh, no refresh is performed.
func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error) {
func (er *ExpiringResource[T, U]) GetResource(state U) (T, error) {
// If the resource is expiring within this time window, update it eagerly.
// This allows other threads/goroutines to keep running by using the not-yet-expired
// resource value while one thread/goroutine updates the resource.
Expand Down Expand Up @@ -87,7 +87,7 @@ func (er *ExpiringResource) GetResource(state interface{}) (interface{}, error)
if acquire {
// This thread/goroutine has been selected to acquire/update the resource
var expiration time.Time
var newValue interface{}
var newValue T
er.lastAttempt = now
newValue, expiration, err = er.acquireResource(state)

Expand Down
9 changes: 4 additions & 5 deletions sdk/azcore/internal/shared/expiring_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ import (
)

func TestNewExpiringResource(t *testing.T) {
er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
s := state.(string)
switch s {
er := NewExpiringResource(func(state string) (newResource string, newExpiration time.Time, err error) {
switch state {
case "initial":
return "updated", time.Now().Add(-time.Minute), nil
case "updated":
return "refreshed", time.Now().Add(1 * time.Hour), nil
default:
t.Fatalf("unexpected state %s", s)
t.Fatalf("unexpected state %s", state)
return "", time.Time{}, errors.New("unexpected")
}
})
Expand All @@ -42,7 +41,7 @@ func TestExpiringResourceError(t *testing.T) {
expectedState := "expected state"
expectedError := "expected error"
calls := 0
er := NewExpiringResource(func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
er := NewExpiringResource(func(state string) (newResource string, newExpiration time.Time, err error) {
calls += 1
if calls == 1 {
return expectedState, time.Now().Add(time.Minute), nil
Expand Down
11 changes: 4 additions & 7 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
// BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential.
type BearerTokenPolicy struct {
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *shared.ExpiringResource
mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]
// the following fields are read-only
cred shared.TokenCredential
scopes []string
Expand All @@ -27,9 +27,8 @@ type acquiringResourceState struct {

// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
s := state.(acquiringResourceState)
tk, err := s.p.cred.GetToken(s.req.Raw().Context(), shared.TokenRequestOptions{Scopes: s.p.scopes})
func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) {
tk, err := state.p.cred.GetToken(state.req.Raw().Context(), shared.TokenRequestOptions{Scopes: state.p.scopes})
if err != nil {
return nil, time.Time{}, err
}
Expand Down Expand Up @@ -58,8 +57,6 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
if err != nil {
return nil, err
}
if token, ok := tk.(*shared.AccessToken); ok {
req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+token.Token)
}
req.Raw().Header.Set(shared.HeaderAuthorization, shared.BearerTokenPrefix+tk.Token)
return req.Next()
}