diff --git a/sdk/auth/cached_authorizer.go b/sdk/auth/cached_authorizer.go index 7ec4dd76cc9..0de357d0c57 100644 --- a/sdk/auth/cached_authorizer.go +++ b/sdk/auth/cached_authorizer.go @@ -27,10 +27,10 @@ type CachedAuthorizer struct { // Token returns the current token if it's still valid, else will acquire a new token func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth2.Token, error) { c.mutex.RLock() - valid := c.token != nil && c.token.Valid() + dueForRenewal := tokenDueForRenewal(c.token) c.mutex.RUnlock() - if !valid { + if dueForRenewal { c.mutex.Lock() defer c.mutex.Unlock() var err error @@ -46,16 +46,15 @@ func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth // AuxiliaryTokens returns additional tokens for auxiliary tenant IDs, for use in multi-tenant scenarios func (c *CachedAuthorizer) AuxiliaryTokens(ctx context.Context, req *http.Request) ([]*oauth2.Token, error) { c.mutex.RLock() - var valid bool + var dueForRenewal bool for _, token := range c.auxTokens { - valid = token != nil && token.Valid() - if !valid { + if dueForRenewal = tokenDueForRenewal(token); dueForRenewal { break } } c.mutex.RUnlock() - if !valid { + if !dueForRenewal { c.mutex.Lock() defer c.mutex.Unlock() var err error diff --git a/sdk/auth/client_credentials.go b/sdk/auth/client_credentials.go index a3c5eae80f7..e022630d341 100644 --- a/sdk/auth/client_credentials.go +++ b/sdk/auth/client_credentials.go @@ -343,6 +343,7 @@ func clientCredentialsToken(ctx context.Context, endpoint string, params *url.Va AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, } + var secs time.Duration if exp, ok := tokenRes.ExpiresIn.(string); ok && exp != "" { if v, err := strconv.Atoi(exp); err == nil { diff --git a/sdk/auth/token.go b/sdk/auth/token.go new file mode 100644 index 00000000000..af2d91680ca --- /dev/null +++ b/sdk/auth/token.go @@ -0,0 +1,44 @@ +package auth + +import ( + "golang.org/x/oauth2" + "time" + + "github.com/hashicorp/go-azure-sdk/sdk/claims" +) + +const tokenExpiryDelta = 10 * time.Minute + +// tokenExpiresSoon returns true if the token expires within 10 minutes, or if more than 50% of its validity period has elapsed (if this can be determined), whichever is later +func tokenDueForRenewal(token *oauth2.Token) bool { + if token == nil { + return true + } + + // Some tokens may never expire + if token.Expiry.IsZero() { + return false + } + + expiry := token.Expiry.Round(0) + delta := tokenExpiryDelta + now := time.Now() + expiresWithinTenMinutes := expiry.Add(-delta).Before(now) + + // Try to parse the token claims to retrieve the issuedAt time + if claims, err := claims.ParseClaims(token); err != nil { + if claims.IssuedAt > 0 { + issued := time.Unix(claims.IssuedAt, 0) + validity := expiry.Sub(issued) + + // If the validity period is less than double the expiry delta, then instead + // determine whether >50% of the validity period has elapsed + if validity < delta*2 { + halfValidityHasElapsed := issued.Add(validity / 2).Before(now) + return halfValidityHasElapsed + } + } + } + + return expiresWithinTenMinutes +} diff --git a/sdk/claims/claims.go b/sdk/claims/claims.go index ae02e5fc5a3..85a0242a44a 100644 --- a/sdk/claims/claims.go +++ b/sdk/claims/claims.go @@ -15,6 +15,7 @@ import ( // Claims is used to unmarshall the claims from a JWT issued by the Microsoft Identity Platform. type Claims struct { Audience string `json:"aud"` + IssuedAt int64 `json:"iat"` Issuer string `json:"iss"` IdentityProvider string `json:"idp"` ObjectId string `json:"oid"`