Skip to content

Commit

Permalink
Merge pull request #362 from hashicorp/bugfix/token-expiry-race-condi…
Browse files Browse the repository at this point in the history
…tion

auth: Use a custom function to determine access token expiry
  • Loading branch information
manicminer authored Mar 10, 2023
2 parents 45368c7 + 86de601 commit 5cab2c3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 6 deletions.
11 changes: 5 additions & 6 deletions sdk/auth/cached_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ type CachedAuthorizer struct {
// Token returns the current token if it's still valid, else will acquire a new token
func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth2.Token, error) {
c.mutex.RLock()
valid := c.token != nil && c.token.Valid()
dueForRenewal := tokenDueForRenewal(c.token)
c.mutex.RUnlock()

if !valid {
if dueForRenewal {
c.mutex.Lock()
defer c.mutex.Unlock()
var err error
Expand All @@ -46,16 +46,15 @@ func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth
// AuxiliaryTokens returns additional tokens for auxiliary tenant IDs, for use in multi-tenant scenarios
func (c *CachedAuthorizer) AuxiliaryTokens(ctx context.Context, req *http.Request) ([]*oauth2.Token, error) {
c.mutex.RLock()
var valid bool
var dueForRenewal bool
for _, token := range c.auxTokens {
valid = token != nil && token.Valid()
if !valid {
if dueForRenewal = tokenDueForRenewal(token); dueForRenewal {
break
}
}
c.mutex.RUnlock()

if !valid {
if !dueForRenewal {
c.mutex.Lock()
defer c.mutex.Unlock()
var err error
Expand Down
1 change: 1 addition & 0 deletions sdk/auth/client_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ func clientCredentialsToken(ctx context.Context, endpoint string, params *url.Va
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
}

var secs time.Duration
if exp, ok := tokenRes.ExpiresIn.(string); ok && exp != "" {
if v, err := strconv.Atoi(exp); err == nil {
Expand Down
44 changes: 44 additions & 0 deletions sdk/auth/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package auth

import (
"golang.org/x/oauth2"
"time"

"github.com/hashicorp/go-azure-sdk/sdk/claims"
)

const tokenExpiryDelta = 10 * time.Minute

// tokenExpiresSoon returns true if the token expires within 10 minutes, or if more than 50% of its validity period has elapsed (if this can be determined), whichever is later
func tokenDueForRenewal(token *oauth2.Token) bool {
if token == nil {
return true
}

// Some tokens may never expire
if token.Expiry.IsZero() {
return false
}

expiry := token.Expiry.Round(0)
delta := tokenExpiryDelta
now := time.Now()
expiresWithinTenMinutes := expiry.Add(-delta).Before(now)

// Try to parse the token claims to retrieve the issuedAt time
if claims, err := claims.ParseClaims(token); err != nil {
if claims.IssuedAt > 0 {
issued := time.Unix(claims.IssuedAt, 0)
validity := expiry.Sub(issued)

// If the validity period is less than double the expiry delta, then instead
// determine whether >50% of the validity period has elapsed
if validity < delta*2 {
halfValidityHasElapsed := issued.Add(validity / 2).Before(now)
return halfValidityHasElapsed
}
}
}

return expiresWithinTenMinutes
}
1 change: 1 addition & 0 deletions sdk/claims/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
// Claims is used to unmarshall the claims from a JWT issued by the Microsoft Identity Platform.
type Claims struct {
Audience string `json:"aud"`
IssuedAt int64 `json:"iat"`
Issuer string `json:"iss"`
IdentityProvider string `json:"idp"`
ObjectId string `json:"oid"`
Expand Down

0 comments on commit 5cab2c3

Please sign in to comment.