Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): Add non-blocking token refresh for compute MDS #10263

Merged
merged 18 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 123 additions & 12 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ const (
universeDomainDefault = "googleapis.com"
)

// tokenState represents different states for a [Token].
type tokenState int

const (
// fresh indicates that the [Token] is valid. It is not expired or close to
// expired, or the token has no expiry.
fresh tokenState = iota
// stale indicates that the [Token] is close to expired, and should be
// refreshed. The token can be used normally.
stale
// invalid indicates that the [Token] is expired or invalid. The token
// cannot be used for a normal operation.
invalid
)

var (
defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType}
Expand Down Expand Up @@ -81,7 +96,7 @@ type Token struct {

// IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not
// expired. A token is considered expired if [Token.Expiry] has passed or will
// pass in the next 10 seconds.
// pass in the next 225 seconds.
func (t *Token) IsValid() bool {
return t.isValidWithEarlyExpiry(defaultExpiryDelta)
}
Expand Down Expand Up @@ -210,11 +225,15 @@ func NewCredentials(opts *CredentialsOptions) *Credentials {
// CachedTokenProvider.
type CachedTokenProviderOptions struct {
// DisableAutoRefresh makes the TokenProvider always return the same token,
// even if it is expired.
// even if it is expired. The default is false. Optional.
DisableAutoRefresh bool
// ExpireEarly configures the amount of time before a token expires, that it
// should be refreshed. If unset, the default value is 10 seconds.
// should be refreshed. If unset, the default value is 3 minutes and 45
// seconds. Optional.
ExpireEarly time.Duration
// DisableAsyncRefresh configures a synchronous workflow that refreshes
// stale tokens while blocking. The default is false. Optional.
DisableAsyncRefresh bool
}

func (ctpo *CachedTokenProviderOptions) autoRefresh() bool {
Expand All @@ -231,33 +250,125 @@ func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration {
return ctpo.ExpireEarly
}

func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool {
if ctpo == nil {
return false
}
return ctpo.DisableAsyncRefresh
}

// NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned
// by the underlying provider. By default it will refresh tokens ten seconds
// before they expire, but this time can be configured with the optional
// options.
// by the underlying provider. By default it will refresh tokens asynchronously
// (non-blocking mode) within a window that starts 3 minutes and 45 seconds
// before they expire. The asynchronous (non-blocking) refresh can be changed to
// a synchronous (blocking) refresh using the
// CachedTokenProviderOptions.DisableAsyncRefresh option. The time-before-expiry
// duration can be configured using the CachedTokenProviderOptions.ExpireEarly
// option.
func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider {
if ctp, ok := tp.(*cachedTokenProvider); ok {
return ctp
}
return &cachedTokenProvider{
tp: tp,
autoRefresh: opts.autoRefresh(),
expireEarly: opts.expireEarly(),
tp: tp,
autoRefresh: opts.autoRefresh(),
expireEarly: opts.expireEarly(),
blockingRefresh: opts.blockingRefresh(),
}
}

type cachedTokenProvider struct {
tp TokenProvider
autoRefresh bool
expireEarly time.Duration
tp TokenProvider
autoRefresh bool
expireEarly time.Duration
blockingRefresh bool

mu sync.Mutex
cachedToken *Token
// isRefreshRunning ensures that the non-blocking refresh will only be
// attempted once, even if multiple callers enter the Token method.
isRefreshRunning bool
quartzmo marked this conversation as resolved.
Show resolved Hide resolved
// isRefreshErr ensures that the non-blocking refresh will only be attempted
// once per refresh window if an error is encountered.
isRefreshErr bool
}

func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) {
if c.blockingRefresh {
return c.tokenBlocking(ctx)
}
return c.tokenNonBlocking(ctx)
}

func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) {
switch c.tokenState() {
case fresh:
c.mu.Lock()
defer c.mu.Unlock()
return c.cachedToken, nil
case stale:
c.tokenAsync(ctx)
// Return the stale token immediately to not block customer requests to Cloud services.
c.mu.Lock()
defer c.mu.Unlock()
return c.cachedToken, nil
default: // invalid
return c.tokenBlocking(ctx)
}
}

// tokenState reports the token's validity.
func (c *cachedTokenProvider) tokenState() tokenState {
c.mu.Lock()
defer c.mu.Unlock()
t := c.cachedToken
if t == nil || t.Value == "" {
return invalid
} else if t.Expiry.IsZero() {
return fresh
} else if timeNow().After(t.Expiry.Round(0)) {
return invalid
} else if timeNow().After(t.Expiry.Round(0).Add(-c.expireEarly)) {
return stale
}
return fresh
}

// tokenAsync uses a bool to ensure that only one non-blocking token refresh
// happens at a time, even if multiple callers have entered this function
// concurrently. This avoids creating an arbitrary number of concurrent
// goroutines. Retries should be attempted and managed within the Token method.
// If the refresh attempt fails, no further attempts are made until the refresh
// window expires and the token enters the invalid state, at which point the
// blocking call to Token should likely return the same error on the main goroutine.
func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
fn := func() {
c.mu.Lock()
c.isRefreshRunning = true
c.mu.Unlock()
t, err := c.tp.Token(ctx)
c.mu.Lock()
defer c.mu.Unlock()
c.isRefreshRunning = false
if err != nil {
// Discard errors from the non-blocking refresh, but prevent further
// attempts.
c.isRefreshErr = true
return
}
c.cachedToken = t
}
c.mu.Lock()
defer c.mu.Unlock()
if !c.isRefreshRunning && !c.isRefreshErr {
go fn()
}
}

func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.isRefreshErr = false
if c.cachedToken.IsValid() || (!c.autoRefresh && !c.cachedToken.isEmpty()) {
return c.cachedToken, nil
}
Expand Down
46 changes: 45 additions & 1 deletion auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,13 +521,57 @@ func (tp *countingTestProvider) Token(ctx context.Context) (*Token, error) {
return tok, nil
}

func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) {
// Freeze now for consistent results.
now := time.Now()
timeNow = func() time.Time { return now }
defer func() { timeNow = time.Now }()
tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{
// EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale.
ExpireEarly: 2 * time.Second,
})
freshToken, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
}
if want := "1"; freshToken.Value != want {
quartzmo marked this conversation as resolved.
Show resolved Hide resolved
t.Errorf("got %q, want %q", freshToken.Value, want)
}
staleToken, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
}
if want := "1"; staleToken.Value != want {
t.Errorf("got %q, want %q", staleToken.Value, want)
}
// Allow time for async refresh.
time.Sleep(100 * time.Millisecond)
freshToken2, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
}
if want := "2"; freshToken2.Value != want {
t.Errorf("got %q, want %q", freshToken2.Value, want)
}
// Allow time for 2nd async refresh.
time.Sleep(100 * time.Millisecond)
freshToken3, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
}
if want := "3"; freshToken3.Value != want {
t.Errorf("got %q, want %q", freshToken3.Value, want)
}
}

func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
// Freeze now for consistent results.
now := time.Now()
timeNow = func() time.Time { return now }
defer func() { timeNow = time.Now }()
tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{
DisableAutoRefresh: true,
DisableAsyncRefresh: true,
DisableAutoRefresh: true,
// EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale.
ExpireEarly: 2 * time.Millisecond,
})
Expand Down
7 changes: 4 additions & 3 deletions auth/credentials/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ var (

// computeTokenProvider creates a [cloud.google.com/go/auth.TokenProvider] that
// uses the metadata service to retrieve tokens.
func computeTokenProvider(earlyExpiry time.Duration, scope ...string) auth.TokenProvider {
return auth.NewCachedTokenProvider(computeProvider{scopes: scope}, &auth.CachedTokenProviderOptions{
ExpireEarly: earlyExpiry,
func computeTokenProvider(opts *DetectOptions) auth.TokenProvider {
return auth.NewCachedTokenProvider(computeProvider{scopes: opts.Scopes}, &auth.CachedTokenProviderOptions{
ExpireEarly: opts.EarlyTokenRefresh,
DisableAsyncRefresh: opts.DisableAsyncRefresh,
})
}

Expand Down
9 changes: 7 additions & 2 deletions auth/credentials/compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ func TestComputeTokenProvider(t *testing.T) {
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "token_type": "bearer", "expires_in": 86400}`))
}))
t.Setenv(computeMetadataEnvVar, strings.TrimPrefix(ts.URL, "http://"))
tp := computeTokenProvider(0, scope)
tp := computeTokenProvider(&DetectOptions{
EarlyTokenRefresh: 0,
Scopes: []string{
scope,
},
})
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
Expand All @@ -46,6 +51,6 @@ func TestComputeTokenProvider(t *testing.T) {
t.Errorf("got %q, want %q", tok.Value, want)
}
if want := "bearer"; tok.Type != want {
t.Errorf("got %q, want %q", tok.Value, want)
t.Errorf("got %q, want %q", tok.Type, want)
}
}
9 changes: 7 additions & 2 deletions auth/credentials/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) {

if OnGCE() {
return auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: computeTokenProvider(opts.EarlyTokenRefresh, opts.Scopes...),
TokenProvider: computeTokenProvider(opts),
ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) {
return metadata.ProjectID()
}),
Expand All @@ -116,8 +116,13 @@ type DetectOptions struct {
// Optional.
Subject string
// EarlyTokenRefresh configures how early before a token expires that it
// should be refreshed.
// should be refreshed. Once the token’s time until expiration has entered
// this refresh window the token is considered valid but stale. If unset,
// the default value is 3 minutes and 45 seconds. Optional.
EarlyTokenRefresh time.Duration
// DisableAsyncRefresh configures a synchronous workflow that refreshes
// stale tokens while blocking. The default is false. Optional.
DisableAsyncRefresh bool
// AuthHandlerOptions configures an authorization handler and other options
// for 3LO flows. It is required, and only used, for client credential
// flows.
Expand Down
2 changes: 1 addition & 1 deletion auth/internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
// future. To make the test pass simply bump the int, but please also clone the
// relevant fields.
func TestCloneDetectOptions_FieldTest(t *testing.T) {
const WantNumberOfFields = 12
const WantNumberOfFields = 13
o := credentials.DetectOptions{}
got := reflect.TypeOf(o).NumField()
if got != WantNumberOfFields {
Expand Down
3 changes: 2 additions & 1 deletion auth/threelegged.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ type Options3LO struct {
// Optional.
Client *http.Client
// EarlyTokenExpiry is the time before the token expires that it should be
// refreshed. If not set the default value is 10 seconds. Optional.
// refreshed. If not set the default value is 3 minutes and 45 seconds.
// Optional.
EarlyTokenExpiry time.Duration

// AuthHandlerOpts provides a set of options for doing a
Expand Down
Loading