Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth: Use a custom function to determine access token expiry #362

Merged
merged 1 commit into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions sdk/auth/cached_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ type CachedAuthorizer struct {
// Token returns the current token if it's still valid, else will acquire a new token
func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth2.Token, error) {
c.mutex.RLock()
valid := c.token != nil && c.token.Valid()
dueForRenewal := tokenDueForRenewal(c.token)
c.mutex.RUnlock()

if !valid {
if dueForRenewal {
c.mutex.Lock()
defer c.mutex.Unlock()
var err error
Expand All @@ -46,16 +46,15 @@ func (c *CachedAuthorizer) Token(ctx context.Context, req *http.Request) (*oauth
// AuxiliaryTokens returns additional tokens for auxiliary tenant IDs, for use in multi-tenant scenarios
func (c *CachedAuthorizer) AuxiliaryTokens(ctx context.Context, req *http.Request) ([]*oauth2.Token, error) {
c.mutex.RLock()
var valid bool
var dueForRenewal bool
for _, token := range c.auxTokens {
valid = token != nil && token.Valid()
if !valid {
if dueForRenewal = tokenDueForRenewal(token); dueForRenewal {
break
}
}
c.mutex.RUnlock()

if !valid {
if !dueForRenewal {
c.mutex.Lock()
defer c.mutex.Unlock()
var err error
Expand Down
1 change: 1 addition & 0 deletions sdk/auth/client_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ func clientCredentialsToken(ctx context.Context, endpoint string, params *url.Va
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
}

var secs time.Duration
if exp, ok := tokenRes.ExpiresIn.(string); ok && exp != "" {
if v, err := strconv.Atoi(exp); err == nil {
Expand Down
44 changes: 44 additions & 0 deletions sdk/auth/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package auth

import (
"golang.org/x/oauth2"
"time"

"github.com/hashicorp/go-azure-sdk/sdk/claims"
)

const tokenExpiryDelta = 10 * time.Minute

// tokenExpiresSoon returns true if the token expires within 10 minutes, or if more than 50% of its validity period has elapsed (if this can be determined), whichever is later
func tokenDueForRenewal(token *oauth2.Token) bool {
if token == nil {
return true
}

// Some tokens may never expire
if token.Expiry.IsZero() {
return false
}

expiry := token.Expiry.Round(0)
delta := tokenExpiryDelta
now := time.Now()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC the access token expiration time is in UTC, so do we want to explicitly convert these to UTC too?

Copy link
Contributor Author

@manicminer manicminer Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's in local time actually, we're computing the expiration time based on the expires_in field from the login response, which is a seconds value that we add to the local time.

Screenshot 2023-03-10 at 10 25 41

expiresWithinTenMinutes := expiry.Add(-delta).Before(now)

// Try to parse the token claims to retrieve the issuedAt time
if claims, err := claims.ParseClaims(token); err != nil {
if claims.IssuedAt > 0 {
issued := time.Unix(claims.IssuedAt, 0)
validity := expiry.Sub(issued)

// If the validity period is less than double the expiry delta, then instead
// determine whether >50% of the validity period has elapsed
if validity < delta*2 {
halfValidityHasElapsed := issued.Add(validity / 2).Before(now)
return halfValidityHasElapsed
}
}
}

return expiresWithinTenMinutes
}
1 change: 1 addition & 0 deletions sdk/claims/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
// Claims is used to unmarshall the claims from a JWT issued by the Microsoft Identity Platform.
type Claims struct {
Audience string `json:"aud"`
IssuedAt int64 `json:"iat"`
Issuer string `json:"iss"`
IdentityProvider string `json:"idp"`
ObjectId string `json:"oid"`
Expand Down