From c2883bbcbb7d48abd8058598d1060091e322b663 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Wed, 22 May 2024 13:49:00 -0600 Subject: [PATCH 01/11] feat(auth): Add non-blocking token refresh for compute MDS --- auth/auth.go | 136 ++++++++++++++++++++-- auth/auth_test.go | 48 ++++++++ auth/credentials/compute.go | 7 +- auth/credentials/compute_test.go | 9 +- auth/credentials/detect.go | 9 +- auth/go.mod | 2 +- auth/internal/transport/transport_test.go | 2 +- auth/threelegged.go | 3 +- 8 files changed, 196 insertions(+), 20 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index d579e482e896..578a6aafb146 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -27,6 +27,8 @@ import ( "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/jwt" + + "golang.org/x/sync/singleflight" ) const ( @@ -41,9 +43,32 @@ const ( // so we give it 15 seconds to refresh it's cache before attempting to refresh a token. defaultExpiryDelta = 225 * time.Second + // nonBlockingRefreshKey is the singleflight uniqueness key for asynchronous + // refresh of a token. It can be any value, since there is only one + // cachedTokenProvider.cachedToken. + nonBlockingRefreshKey = "computeProvider" + // nonBlockingRefreshTimeout is the timeout for asynchronous refresh of a + // token. + nonBlockingRefreshTimeout = 30 * time.Second + universeDomainDefault = "googleapis.com" ) +// tokenState represents different states for a [Token]. +type tokenState int + +const ( + // fresh indicates that the [Token] is valid. It is not expired or close to + // expired, or the token has no expiry. + fresh tokenState = iota + // stale indicates that the [Token] is close to expired, and should be + // refreshed. The token can be used normally. + stale + // invalid indicates that the [Token] is expired or invalid. The token + // cannot be used for a normal operation. + invalid +) + var ( defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType} @@ -81,9 +106,9 @@ type Token struct { // IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not // expired. A token is considered expired if [Token.Expiry] has passed or will -// pass in the next 10 seconds. +// pass in the next 225 seconds. func (t *Token) IsValid() bool { - return t.isValidWithEarlyExpiry(defaultExpiryDelta) + return t.isValidWithEarlyExpiry(defaultExpiryDelta) // TODO(quartzmo): investigate why EarlyTokenRefresh, ExpireEarly isn't used here. Bug? } func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool { @@ -206,11 +231,15 @@ func NewCredentials(opts *CredentialsOptions) *Credentials { // CachedTokenProvider. type CachedTokenProviderOptions struct { // DisableAutoRefresh makes the TokenProvider always return the same token, - // even if it is expired. + // even if it is expired. The default is false. Optional. DisableAutoRefresh bool // ExpireEarly configures the amount of time before a token expires, that it - // should be refreshed. If unset, the default value is 10 seconds. + // should be refreshed. If unset, the default value is 3 minutes and 45 + // seconds. Optional. ExpireEarly time.Duration + // NonBlockingRefresh configures an asynchronous workflow that refreshes + // stale tokens without blocking. The default is false. Optional. + NonBlockingRefresh bool } func (ctpo *CachedTokenProviderOptions) autoRefresh() bool { @@ -227,6 +256,13 @@ func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { return ctpo.ExpireEarly } +func (ctpo *CachedTokenProviderOptions) nonBlockingRefresh() bool { + if ctpo == nil { + return false + } + return ctpo.NonBlockingRefresh +} + // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned // by the underlying provider. By default it will refresh tokens ten seconds // before they expire, but this time can be configured with the optional @@ -236,22 +272,102 @@ func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) return ctp } return &cachedTokenProvider{ - tp: tp, - autoRefresh: opts.autoRefresh(), - expireEarly: opts.expireEarly(), + tp: tp, + autoRefresh: opts.autoRefresh(), + expireEarly: opts.expireEarly(), + nonBlockingRefresh: opts.nonBlockingRefresh(), } } type cachedTokenProvider struct { - tp TokenProvider - autoRefresh bool - expireEarly time.Duration + tp TokenProvider + autoRefresh bool + expireEarly time.Duration + nonBlockingRefresh bool + // loadGroup ensures that the non-blocking refresh will only happen on one + // goroutine, even if multiple callers have entered the Token method. + loadGroup singleflight.Group mu sync.Mutex cachedToken *Token } func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) { + if c.nonBlockingRefresh { + return c.tokenNonBlocking(ctx) + } + return c.tokenBlocking(ctx) +} + +func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) { + switch c.tokenState() { + case fresh: + c.mu.Lock() + defer c.mu.Unlock() + return c.cachedToken, nil + case stale: + // Call singleflight's DoChan (via tokenAsync) but discard the returned + // chan. In order to return an err from tokenAsync, we would need to + // wait on chan and read its err. Instead, allow all requests during + // the refresh window to join serial attempts managed by singleflight. + // If all fail, the Expired case should return the same error. + c.tokenAsync() + // Return the stale token immediately to not block customer requests to Cloud services + c.mu.Lock() + defer c.mu.Unlock() + return c.cachedToken, nil + case invalid: + return c.tokenBlocking(ctx) + default: + panic("unreachable") + } +} + +// tokenState reports the token's validity. +func (c *cachedTokenProvider) tokenState() tokenState { + c.mu.Lock() + defer c.mu.Unlock() + t := c.cachedToken + if t == nil || t.Value == "" { + return invalid + } else if t.Expiry.IsZero() { + return fresh + } else if timeNow().After(t.Expiry.Round(0)) { + return invalid + } else if timeNow().After(t.Expiry.Round(0).Add(-c.expireEarly)) { + return stale + } + return fresh +} + +// tokenAsync uses singleflight to ensure that only one async token fetch +// happens at a time, even if multiple callers have entered this function +// concurrently. This avoids creating an arbitrary number of concurrent +// goroutines. +func (c *cachedTokenProvider) tokenAsync() <-chan singleflight.Result { + return c.loadGroup.DoChan(nonBlockingRefreshKey, func() (entry any, err error) { + + // Use a new context with timeout. This allows metadata.GetWithContext + // to retry regardless of the original request context. + refreshCtx, refreshCancel := context.WithTimeout(context.Background(), nonBlockingRefreshTimeout) + defer refreshCancel() + + t, err := c.tp.Token(refreshCtx) + if err != nil { + // In order to return this err to callers of the main goroutine, a + // call to tokenAsync would need to wait on the returned chan and + // read its err. Currently, it is ignored in tokenNonBlocking, above. + return nil, err + } + + c.mu.Lock() + defer c.mu.Unlock() + c.cachedToken = t + return t, nil + }) +} + +func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) { c.mu.Lock() defer c.mu.Unlock() if c.cachedToken.IsValid() || !c.autoRefresh { diff --git a/auth/auth_test.go b/auth/auth_test.go index 47f335187781..58334de553cb 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -506,3 +506,51 @@ func TestNew2LOTokenProvider_Validate(t *testing.T) { }) } } + +type countingTestProvider struct { + count int +} + +func (tp *countingTestProvider) Token(ctx context.Context) (*Token, error) { + tok := &Token{ + Value: fmt.Sprint(tp.count), + // Set expiry to q1 second from now. + Expiry: time.Now().Add(time.Second), + } + tp.count++ + return tok, nil +} +func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { + // Freeze now for consistent results. + now := time.Now() + timeNow = func() time.Time { return now } + defer func() { timeNow = time.Now }() + tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ + NonBlockingRefresh: true, + // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. + ExpireEarly: 2 * time.Second, + }) + freshToken, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "1"; freshToken.Value != want { + t.Errorf("got %q, want %q", freshToken.Value, want) + } + staleToken, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "1"; staleToken.Value != want { + t.Errorf("got %q, want %q", staleToken.Value, want) + } + // Allow time for async refresh. + time.Sleep(100 * time.Millisecond) + freshToken2, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "2"; freshToken2.Value != want { + t.Errorf("got %q, want %q", freshToken2.Value, want) + } +} diff --git a/auth/credentials/compute.go b/auth/credentials/compute.go index f3ec8882424f..656786e6955a 100644 --- a/auth/credentials/compute.go +++ b/auth/credentials/compute.go @@ -37,9 +37,10 @@ var ( // computeTokenProvider creates a [cloud.google.com/go/auth.TokenProvider] that // uses the metadata service to retrieve tokens. -func computeTokenProvider(earlyExpiry time.Duration, scope ...string) auth.TokenProvider { - return auth.NewCachedTokenProvider(computeProvider{scopes: scope}, &auth.CachedTokenProviderOptions{ - ExpireEarly: earlyExpiry, +func computeTokenProvider(opts *DetectOptions) auth.TokenProvider { + return auth.NewCachedTokenProvider(computeProvider{scopes: opts.Scopes}, &auth.CachedTokenProviderOptions{ + ExpireEarly: opts.EarlyTokenRefresh, + NonBlockingRefresh: opts.NonBlockingRefresh, }) } diff --git a/auth/credentials/compute_test.go b/auth/credentials/compute_test.go index 0b5eca6ce419..6d0d691839e1 100644 --- a/auth/credentials/compute_test.go +++ b/auth/credentials/compute_test.go @@ -37,7 +37,12 @@ func TestComputeTokenProvider(t *testing.T) { w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "token_type": "bearer", "expires_in": 86400}`)) })) t.Setenv(computeMetadataEnvVar, strings.TrimPrefix(ts.URL, "http://")) - tp := computeTokenProvider(0, scope) + tp := computeTokenProvider(&DetectOptions{ + EarlyTokenRefresh: 0, + Scopes: []string{ + scope, + }, + }) tok, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) @@ -46,6 +51,6 @@ func TestComputeTokenProvider(t *testing.T) { t.Errorf("got %q, want %q", tok.Value, want) } if want := "bearer"; tok.Type != want { - t.Errorf("got %q, want %q", tok.Value, want) + t.Errorf("got %q, want %q", tok.Type, want) } } diff --git a/auth/credentials/detect.go b/auth/credentials/detect.go index cb3f44f5873f..a30fe2e3123c 100644 --- a/auth/credentials/detect.go +++ b/auth/credentials/detect.go @@ -92,7 +92,7 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) { if OnGCE() { return auth.NewCredentials(&auth.CredentialsOptions{ - TokenProvider: computeTokenProvider(opts.EarlyTokenRefresh, opts.Scopes...), + TokenProvider: computeTokenProvider(opts), ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) { return metadata.ProjectID() }), @@ -116,8 +116,13 @@ type DetectOptions struct { // Optional. Subject string // EarlyTokenRefresh configures how early before a token expires that it - // should be refreshed. + // should be refreshed. Once the token’s time until expiration has entered + // this refresh window the token is considered valid but stale. If unset, + // the default value is 3 minutes and 45 seconds. Optional. EarlyTokenRefresh time.Duration + // NonBlockingRefresh configures an asynchronous workflow that refreshes + // stale tokens without blocking. The default is false. Optional. + NonBlockingRefresh bool // AuthHandlerOptions configures an authorization handler and other options // for 3LO flows. It is required, and only used, for client credential // flows. diff --git a/auth/go.mod b/auth/go.mod index b92034bb80f0..36a9515865b3 100644 --- a/auth/go.mod +++ b/auth/go.mod @@ -9,6 +9,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 go.opencensus.io v0.24.0 golang.org/x/net v0.26.0 + golang.org/x/sync v0.7.0 google.golang.org/grpc v1.64.0 google.golang.org/protobuf v1.34.1 ) @@ -18,7 +19,6 @@ require ( github.com/golang/protobuf v1.5.4 // indirect golang.org/x/crypto v0.24.0 // indirect golang.org/x/oauth2 v0.19.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect diff --git a/auth/internal/transport/transport_test.go b/auth/internal/transport/transport_test.go index 3518195a55ad..69b382c2b734 100644 --- a/auth/internal/transport/transport_test.go +++ b/auth/internal/transport/transport_test.go @@ -29,7 +29,7 @@ import ( // future. To make the test pass simply bump the int, but please also clone the // relevant fields. func TestCloneDetectOptions_FieldTest(t *testing.T) { - const WantNumberOfFields = 12 + const WantNumberOfFields = 13 o := credentials.DetectOptions{} got := reflect.TypeOf(o).NumField() if got != WantNumberOfFields { diff --git a/auth/threelegged.go b/auth/threelegged.go index 1b8d83c4b4fe..1ccdeff84d0f 100644 --- a/auth/threelegged.go +++ b/auth/threelegged.go @@ -62,7 +62,8 @@ type Options3LO struct { // Optional. Client *http.Client // EarlyTokenExpiry is the time before the token expires that it should be - // refreshed. If not set the default value is 10 seconds. Optional. + // refreshed. If not set the default value is 3 minutes and 45 seconds. + // Optional. EarlyTokenExpiry time.Duration // AuthHandlerOpts provides a set of options for doing a From 9cd60beae20f6441d81649f7b66a6433abbd7c2c Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Wed, 12 Jun 2024 13:30:10 -0600 Subject: [PATCH 02/11] make asynchronous (non-blocking) refresh the default --- auth/auth.go | 42 ++++++++++++++++++++----------------- auth/auth_test.go | 26 ++++++++++++++++++++++- auth/credentials/compute.go | 4 ++-- auth/credentials/detect.go | 6 +++--- 4 files changed, 53 insertions(+), 25 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 578a6aafb146..ce4d73c6b570 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -237,9 +237,9 @@ type CachedTokenProviderOptions struct { // should be refreshed. If unset, the default value is 3 minutes and 45 // seconds. Optional. ExpireEarly time.Duration - // NonBlockingRefresh configures an asynchronous workflow that refreshes - // stale tokens without blocking. The default is false. Optional. - NonBlockingRefresh bool + // BlockingRefresh configures a synchronous workflow that refreshes + // stale tokens while blocking. The default is false. Optional. + BlockingRefresh bool } func (ctpo *CachedTokenProviderOptions) autoRefresh() bool { @@ -256,34 +256,38 @@ func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { return ctpo.ExpireEarly } -func (ctpo *CachedTokenProviderOptions) nonBlockingRefresh() bool { +func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool { if ctpo == nil { return false } - return ctpo.NonBlockingRefresh + return ctpo.BlockingRefresh } // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned -// by the underlying provider. By default it will refresh tokens ten seconds -// before they expire, but this time can be configured with the optional -// options. +// by the underlying provider. By default it will refresh tokens asynchronously +// (non-blocking mode) within a window that starts 3 minutes and 45 seconds +// before they expire. The asynchronous (non-blocking) refresh can be changed to +// a synchronous (blocking) refresh using the +// CachedTokenProviderOptions.BlockingRefresh option. The time-before-expiry +// duration can be configured using the CachedTokenProviderOptions.ExpireEarly +// option. func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider { if ctp, ok := tp.(*cachedTokenProvider); ok { return ctp } return &cachedTokenProvider{ - tp: tp, - autoRefresh: opts.autoRefresh(), - expireEarly: opts.expireEarly(), - nonBlockingRefresh: opts.nonBlockingRefresh(), + tp: tp, + autoRefresh: opts.autoRefresh(), + expireEarly: opts.expireEarly(), + blockingRefresh: opts.blockingRefresh(), } } type cachedTokenProvider struct { - tp TokenProvider - autoRefresh bool - expireEarly time.Duration - nonBlockingRefresh bool + tp TokenProvider + autoRefresh bool + expireEarly time.Duration + blockingRefresh bool // loadGroup ensures that the non-blocking refresh will only happen on one // goroutine, even if multiple callers have entered the Token method. loadGroup singleflight.Group @@ -293,10 +297,10 @@ type cachedTokenProvider struct { } func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) { - if c.nonBlockingRefresh { - return c.tokenNonBlocking(ctx) + if c.blockingRefresh { + return c.tokenBlocking(ctx) } - return c.tokenBlocking(ctx) + return c.tokenNonBlocking(ctx) } func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) { diff --git a/auth/auth_test.go b/auth/auth_test.go index 58334de553cb..1a3a5f33e0af 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -526,7 +526,6 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { timeNow = func() time.Time { return now } defer func() { timeNow = time.Now }() tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ - NonBlockingRefresh: true, // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. ExpireEarly: 2 * time.Second, }) @@ -554,3 +553,28 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { t.Errorf("got %q, want %q", freshToken2.Value, want) } } +func TestComputeTokenProvider_BlockingRefresh(t *testing.T) { + // Freeze now for consistent results. + now := time.Now() + timeNow = func() time.Time { return now } + defer func() { timeNow = time.Now }() + tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ + BlockingRefresh: true, + // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. + ExpireEarly: 2 * time.Second, + }) + freshToken, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "1"; freshToken.Value != want { + t.Errorf("got %q, want %q", freshToken.Value, want) + } + freshToken2, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "2"; freshToken2.Value != want { + t.Errorf("got %q, want %q", freshToken2.Value, want) + } +} diff --git a/auth/credentials/compute.go b/auth/credentials/compute.go index 656786e6955a..6e31fbe8d10d 100644 --- a/auth/credentials/compute.go +++ b/auth/credentials/compute.go @@ -39,8 +39,8 @@ var ( // uses the metadata service to retrieve tokens. func computeTokenProvider(opts *DetectOptions) auth.TokenProvider { return auth.NewCachedTokenProvider(computeProvider{scopes: opts.Scopes}, &auth.CachedTokenProviderOptions{ - ExpireEarly: opts.EarlyTokenRefresh, - NonBlockingRefresh: opts.NonBlockingRefresh, + ExpireEarly: opts.EarlyTokenRefresh, + BlockingRefresh: opts.BlockingRefresh, }) } diff --git a/auth/credentials/detect.go b/auth/credentials/detect.go index a30fe2e3123c..037fe58a641e 100644 --- a/auth/credentials/detect.go +++ b/auth/credentials/detect.go @@ -120,9 +120,9 @@ type DetectOptions struct { // this refresh window the token is considered valid but stale. If unset, // the default value is 3 minutes and 45 seconds. Optional. EarlyTokenRefresh time.Duration - // NonBlockingRefresh configures an asynchronous workflow that refreshes - // stale tokens without blocking. The default is false. Optional. - NonBlockingRefresh bool + // BlockingRefresh configures a synchronous workflow that refreshes + // stale tokens while blocking. The default is false. Optional. + BlockingRefresh bool // AuthHandlerOptions configures an authorization handler and other options // for 3LO flows. It is required, and only used, for client credential // flows. From d702505767eddc00e3d8b69c1d7f56c4b15f5517 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 13 Jun 2024 15:05:18 -0600 Subject: [PATCH 03/11] use existing context from main goroutine --- auth/auth.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index ce4d73c6b570..2baffef1ac06 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -315,7 +315,7 @@ func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, err // wait on chan and read its err. Instead, allow all requests during // the refresh window to join serial attempts managed by singleflight. // If all fail, the Expired case should return the same error. - c.tokenAsync() + c.tokenAsync(ctx) // Return the stale token immediately to not block customer requests to Cloud services c.mu.Lock() defer c.mu.Unlock() @@ -348,12 +348,11 @@ func (c *cachedTokenProvider) tokenState() tokenState { // happens at a time, even if multiple callers have entered this function // concurrently. This avoids creating an arbitrary number of concurrent // goroutines. -func (c *cachedTokenProvider) tokenAsync() <-chan singleflight.Result { +func (c *cachedTokenProvider) tokenAsync(ctx context.Context) <-chan singleflight.Result { return c.loadGroup.DoChan(nonBlockingRefreshKey, func() (entry any, err error) { - // Use a new context with timeout. This allows metadata.GetWithContext - // to retry regardless of the original request context. - refreshCtx, refreshCancel := context.WithTimeout(context.Background(), nonBlockingRefreshTimeout) + // Set a 30s timeout. + refreshCtx, refreshCancel := context.WithTimeout(ctx, nonBlockingRefreshTimeout) defer refreshCancel() t, err := c.tp.Token(refreshCtx) From d72de41a4db03f35a5575353b7b544910df565ff Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Thu, 13 Jun 2024 15:56:23 -0600 Subject: [PATCH 04/11] remove context timeout --- auth/auth.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 2baffef1ac06..c898caaaf2bc 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -350,12 +350,7 @@ func (c *cachedTokenProvider) tokenState() tokenState { // goroutines. func (c *cachedTokenProvider) tokenAsync(ctx context.Context) <-chan singleflight.Result { return c.loadGroup.DoChan(nonBlockingRefreshKey, func() (entry any, err error) { - - // Set a 30s timeout. - refreshCtx, refreshCancel := context.WithTimeout(ctx, nonBlockingRefreshTimeout) - defer refreshCancel() - - t, err := c.tp.Token(refreshCtx) + t, err := c.tp.Token(ctx) if err != nil { // In order to return this err to callers of the main goroutine, a // call to tokenAsync would need to wait on the returned chan and From 17ae457df05ff70b03ed35db7a2a79fe1d50d448 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 17 Jun 2024 13:27:00 -0600 Subject: [PATCH 05/11] update docs --- auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/auth.go b/auth/auth.go index c898caaaf2bc..877c27b41791 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -314,7 +314,7 @@ func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, err // chan. In order to return an err from tokenAsync, we would need to // wait on chan and read its err. Instead, allow all requests during // the refresh window to join serial attempts managed by singleflight. - // If all fail, the Expired case should return the same error. + // If all fail, the invalid case should return the same error. c.tokenAsync(ctx) // Return the stale token immediately to not block customer requests to Cloud services c.mu.Lock() From 27c6bc5de52a57941d8daed3caa3797f4e177252 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Fri, 21 Jun 2024 11:03:03 -0600 Subject: [PATCH 06/11] replace singleflight with bool --- auth/auth.go | 59 +++++++++++++++++++++++++---------------------- auth/auth_test.go | 9 ++++++++ 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 877c27b41791..67c737147a19 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -27,8 +27,6 @@ import ( "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/jwt" - - "golang.org/x/sync/singleflight" ) const ( @@ -43,10 +41,6 @@ const ( // so we give it 15 seconds to refresh it's cache before attempting to refresh a token. defaultExpiryDelta = 225 * time.Second - // nonBlockingRefreshKey is the singleflight uniqueness key for asynchronous - // refresh of a token. It can be any value, since there is only one - // cachedTokenProvider.cachedToken. - nonBlockingRefreshKey = "computeProvider" // nonBlockingRefreshTimeout is the timeout for asynchronous refresh of a // token. nonBlockingRefreshTimeout = 30 * time.Second @@ -288,9 +282,12 @@ type cachedTokenProvider struct { autoRefresh bool expireEarly time.Duration blockingRefresh bool - // loadGroup ensures that the non-blocking refresh will only happen on one - // goroutine, even if multiple callers have entered the Token method. - loadGroup singleflight.Group + // isRefreshRunning ensures that the non-blocking refresh will only be + // attempted once, even if multiple callers enter the Token method. + isRefreshRunning bool + // isRefreshErr ensures that the non-blocking refresh will only be attempted + // once per refresh window if an error is encountered. + isRefreshErr bool mu sync.Mutex cachedToken *Token @@ -310,13 +307,8 @@ func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, err defer c.mu.Unlock() return c.cachedToken, nil case stale: - // Call singleflight's DoChan (via tokenAsync) but discard the returned - // chan. In order to return an err from tokenAsync, we would need to - // wait on chan and read its err. Instead, allow all requests during - // the refresh window to join serial attempts managed by singleflight. - // If all fail, the invalid case should return the same error. c.tokenAsync(ctx) - // Return the stale token immediately to not block customer requests to Cloud services + // Return the stale token immediately to not block customer requests to Cloud services. c.mu.Lock() defer c.mu.Unlock() return c.cachedToken, nil @@ -344,30 +336,41 @@ func (c *cachedTokenProvider) tokenState() tokenState { return fresh } -// tokenAsync uses singleflight to ensure that only one async token fetch +// tokenAsync uses a bool to ensure that only one non-blocking token refresh // happens at a time, even if multiple callers have entered this function // concurrently. This avoids creating an arbitrary number of concurrent -// goroutines. -func (c *cachedTokenProvider) tokenAsync(ctx context.Context) <-chan singleflight.Result { - return c.loadGroup.DoChan(nonBlockingRefreshKey, func() (entry any, err error) { +// goroutines. Retries should be attempted and managed within the Token method. +// If the refresh attempt fails, no further attempts are made until the refresh +// window expires and the token enters the invalid state, at which point the +// blocking call to Token should likely return the same error on the main goroutine. +func (c *cachedTokenProvider) tokenAsync(ctx context.Context) { + fn := func() { + c.mu.Lock() + c.isRefreshRunning = true + c.mu.Unlock() t, err := c.tp.Token(ctx) - if err != nil { - // In order to return this err to callers of the main goroutine, a - // call to tokenAsync would need to wait on the returned chan and - // read its err. Currently, it is ignored in tokenNonBlocking, above. - return nil, err - } - c.mu.Lock() defer c.mu.Unlock() + c.isRefreshRunning = false + if err != nil { + // Discard errors from the non-blocking refresh, but prevent further + // attempts. + c.isRefreshErr = true + return + } c.cachedToken = t - return t, nil - }) + } + c.mu.Lock() + if !c.isRefreshRunning && !c.isRefreshErr { + go fn() + } + defer c.mu.Unlock() } func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) { c.mu.Lock() defer c.mu.Unlock() + c.isRefreshErr = false if c.cachedToken.IsValid() || !c.autoRefresh { return c.cachedToken, nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index 1a3a5f33e0af..20e88027c1a0 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -552,6 +552,15 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { if want := "2"; freshToken2.Value != want { t.Errorf("got %q, want %q", freshToken2.Value, want) } + // Allow time for 2nd async refresh. + time.Sleep(100 * time.Millisecond) + freshToken3, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if want := "3"; freshToken3.Value != want { + t.Errorf("got %q, want %q", freshToken3.Value, want) + } } func TestComputeTokenProvider_BlockingRefresh(t *testing.T) { // Freeze now for consistent results. From e0e540cdfec2f2a537182f535fd9645f611fddff Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Fri, 21 Jun 2024 11:09:04 -0600 Subject: [PATCH 07/11] go mod tidy --- auth/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/go.mod b/auth/go.mod index 9612ab50c0ed..e3bf20248552 100644 --- a/auth/go.mod +++ b/auth/go.mod @@ -9,7 +9,6 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 go.opencensus.io v0.24.0 golang.org/x/net v0.26.0 - golang.org/x/sync v0.7.0 google.golang.org/grpc v1.64.0 google.golang.org/protobuf v1.34.2 ) @@ -19,6 +18,7 @@ require ( github.com/golang/protobuf v1.5.4 // indirect golang.org/x/crypto v0.24.0 // indirect golang.org/x/oauth2 v0.19.0 // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect From 3622735a7eaabf9e01beeabee4440022f27d1f84 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Fri, 21 Jun 2024 15:25:15 -0600 Subject: [PATCH 08/11] fix codyoss nits --- auth/auth.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 67c737147a19..abf9bb59b667 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -41,10 +41,6 @@ const ( // so we give it 15 seconds to refresh it's cache before attempting to refresh a token. defaultExpiryDelta = 225 * time.Second - // nonBlockingRefreshTimeout is the timeout for asynchronous refresh of a - // token. - nonBlockingRefreshTimeout = 30 * time.Second - universeDomainDefault = "googleapis.com" ) @@ -102,7 +98,7 @@ type Token struct { // expired. A token is considered expired if [Token.Expiry] has passed or will // pass in the next 225 seconds. func (t *Token) IsValid() bool { - return t.isValidWithEarlyExpiry(defaultExpiryDelta) // TODO(quartzmo): investigate why EarlyTokenRefresh, ExpireEarly isn't used here. Bug? + return t.isValidWithEarlyExpiry(defaultExpiryDelta) } func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool { @@ -282,15 +278,15 @@ type cachedTokenProvider struct { autoRefresh bool expireEarly time.Duration blockingRefresh bool + + mu sync.Mutex + cachedToken *Token // isRefreshRunning ensures that the non-blocking refresh will only be // attempted once, even if multiple callers enter the Token method. isRefreshRunning bool // isRefreshErr ensures that the non-blocking refresh will only be attempted // once per refresh window if an error is encountered. isRefreshErr bool - - mu sync.Mutex - cachedToken *Token } func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) { @@ -312,10 +308,8 @@ func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, err c.mu.Lock() defer c.mu.Unlock() return c.cachedToken, nil - case invalid: + default: // invalid return c.tokenBlocking(ctx) - default: - panic("unreachable") } } @@ -361,10 +355,10 @@ func (c *cachedTokenProvider) tokenAsync(ctx context.Context) { c.cachedToken = t } c.mu.Lock() + defer c.mu.Unlock() if !c.isRefreshRunning && !c.isRefreshErr { go fn() } - defer c.mu.Unlock() } func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) { From 707fbe2e21adee1e524a1dff06c6787143ef5ec5 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 24 Jun 2024 12:11:16 -0600 Subject: [PATCH 09/11] rename BlockingRefresh to DisableAsyncRefresh --- auth/auth.go | 6 +++--- auth/auth_test.go | 4 ++-- auth/credentials/compute.go | 4 ++-- auth/credentials/detect.go | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 69a64909e546..7d4f2322068f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -231,9 +231,9 @@ type CachedTokenProviderOptions struct { // should be refreshed. If unset, the default value is 3 minutes and 45 // seconds. Optional. ExpireEarly time.Duration - // BlockingRefresh configures a synchronous workflow that refreshes + // DisableAsyncRefresh configures a synchronous workflow that refreshes // stale tokens while blocking. The default is false. Optional. - BlockingRefresh bool + DisableAsyncRefresh bool } func (ctpo *CachedTokenProviderOptions) autoRefresh() bool { @@ -254,7 +254,7 @@ func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool { if ctpo == nil { return false } - return ctpo.BlockingRefresh + return ctpo.DisableAsyncRefresh } // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned diff --git a/auth/auth_test.go b/auth/auth_test.go index a5492638f0fb..447ca283606d 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -570,8 +570,8 @@ func TestComputeTokenProvider_BlockingRefresh(t *testing.T) { timeNow = func() time.Time { return now } defer func() { timeNow = time.Now }() tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ - BlockingRefresh: true, - DisableAutoRefresh: true, + DisableAsyncRefresh: true, + DisableAutoRefresh: true, // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. ExpireEarly: 2 * time.Millisecond, }) diff --git a/auth/credentials/compute.go b/auth/credentials/compute.go index 6e31fbe8d10d..6f70fa353b00 100644 --- a/auth/credentials/compute.go +++ b/auth/credentials/compute.go @@ -39,8 +39,8 @@ var ( // uses the metadata service to retrieve tokens. func computeTokenProvider(opts *DetectOptions) auth.TokenProvider { return auth.NewCachedTokenProvider(computeProvider{scopes: opts.Scopes}, &auth.CachedTokenProviderOptions{ - ExpireEarly: opts.EarlyTokenRefresh, - BlockingRefresh: opts.BlockingRefresh, + ExpireEarly: opts.EarlyTokenRefresh, + DisableAsyncRefresh: opts.DisableAsyncRefresh, }) } diff --git a/auth/credentials/detect.go b/auth/credentials/detect.go index 037fe58a641e..f9c8956ed90a 100644 --- a/auth/credentials/detect.go +++ b/auth/credentials/detect.go @@ -120,9 +120,9 @@ type DetectOptions struct { // this refresh window the token is considered valid but stale. If unset, // the default value is 3 minutes and 45 seconds. Optional. EarlyTokenRefresh time.Duration - // BlockingRefresh configures a synchronous workflow that refreshes + // DisableAsyncRefresh configures a synchronous workflow that refreshes // stale tokens while blocking. The default is false. Optional. - BlockingRefresh bool + DisableAsyncRefresh bool // AuthHandlerOptions configures an authorization handler and other options // for 3LO flows. It is required, and only used, for client credential // flows. From 5e13e7ee172898c14376955bf2ebc6d3c77ba1f0 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 24 Jun 2024 12:13:31 -0600 Subject: [PATCH 10/11] update docs for rename --- auth/auth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/auth.go b/auth/auth.go index 7d4f2322068f..36729b604aba 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -262,7 +262,7 @@ func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool { // (non-blocking mode) within a window that starts 3 minutes and 45 seconds // before they expire. The asynchronous (non-blocking) refresh can be changed to // a synchronous (blocking) refresh using the -// CachedTokenProviderOptions.BlockingRefresh option. The time-before-expiry +// CachedTokenProviderOptions.DisableAsyncRefresh option. The time-before-expiry // duration can be configured using the CachedTokenProviderOptions.ExpireEarly // option. func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider { From 05aa0ca2280dec3687ca95e53cb0f414533d5b6e Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Mon, 24 Jun 2024 15:42:50 -0600 Subject: [PATCH 11/11] improve tests --- auth/auth_test.go | 107 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 84 insertions(+), 23 deletions(-) diff --git a/auth/auth_test.go b/auth/auth_test.go index 447ca283606d..b25222bfafa7 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -514,8 +514,9 @@ type countingTestProvider struct { func (tp *countingTestProvider) Token(ctx context.Context) (*Token, error) { tok := &Token{ Value: fmt.Sprint(tp.count), - // Set expiry to q1 second from now. - Expiry: time.Now().Add(time.Second), + // Set expiry to count times seconds from now, so that as count increases + // to 2, token state changes from stale to fresh. + Expiry: time.Now().Add(time.Duration(tp.count) * time.Second), } tp.count++ return tok, nil @@ -527,13 +528,19 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { timeNow = func() time.Time { return now } defer func() { timeNow = time.Now }() tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ - // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. - ExpireEarly: 2 * time.Second, + // EarlyTokenRefresh ensures that token with early expiry just less than 2 seconds before now is already stale. + ExpireEarly: 1990 * time.Millisecond, }) + if state := tp.(*cachedTokenProvider).tokenState(); state != invalid { + t.Errorf("got %d, want %d", state, invalid) + } freshToken, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) } + if state := tp.(*cachedTokenProvider).tokenState(); state != stale { + t.Errorf("got %d, want %d", state, stale) + } if want := "1"; freshToken.Value != want { t.Errorf("got %q, want %q", freshToken.Value, want) } @@ -541,6 +548,9 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { if err != nil { t.Fatal(err) } + if state := tp.(*cachedTokenProvider).tokenState(); state != stale { + t.Errorf("got %d, want %d", state, stale) + } if want := "1"; staleToken.Value != want { t.Errorf("got %q, want %q", staleToken.Value, want) } @@ -550,6 +560,9 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { if err != nil { t.Fatal(err) } + if state := tp.(*cachedTokenProvider).tokenState(); state != fresh { + t.Errorf("got %d, want %d", state, fresh) + } if want := "2"; freshToken2.Value != want { t.Errorf("got %q, want %q", freshToken2.Value, want) } @@ -559,30 +572,78 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) { if err != nil { t.Fatal(err) } - if want := "3"; freshToken3.Value != want { + if state := tp.(*cachedTokenProvider).tokenState(); state != fresh { + t.Errorf("got %d, want %d", state, fresh) + } + if want := "2"; freshToken3.Value != want { t.Errorf("got %q, want %q", freshToken3.Value, want) } } func TestComputeTokenProvider_BlockingRefresh(t *testing.T) { - // Freeze now for consistent results. - now := time.Now() - timeNow = func() time.Time { return now } - defer func() { timeNow = time.Now }() - tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ - DisableAsyncRefresh: true, - DisableAutoRefresh: true, - // EarlyTokenRefresh ensures that token with expiry 1 second from now is already stale. - ExpireEarly: 2 * time.Millisecond, - }) - freshToken, err := tp.Token(context.Background()) - if err != nil { - t.Fatal(err) - } - if freshToken == nil { - t.Fatal("freshToken is nil") + tests := []struct { + name string + disableAutoRefresh bool + want1 string + want2 string + wantState2 tokenState + }{ + { + name: "disableAutoRefresh", + disableAutoRefresh: true, + want1: "1", + want2: "1", + // Because token "count" does not increase, it will always be stale. + wantState2: stale, + }, + { + name: "autoRefresh", + disableAutoRefresh: false, + want1: "1", + want2: "2", + // As token "count" increases to 2, it transitions to fresh. + wantState2: fresh, + }, } - if want := "1"; freshToken.Value != want { - t.Errorf("got %q, want %q", freshToken.Value, want) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Freeze now for consistent results. + now := time.Now() + timeNow = func() time.Time { return now } + defer func() { timeNow = time.Now }() + tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{ + DisableAsyncRefresh: true, + DisableAutoRefresh: tt.disableAutoRefresh, + // EarlyTokenRefresh ensures that token with early expiry just less than 2 seconds before now is already stale. + ExpireEarly: 1990 * time.Millisecond, + }) + if state := tp.(*cachedTokenProvider).tokenState(); state != invalid { + t.Errorf("got %d, want %d", state, invalid) + } + freshToken, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if freshToken == nil { + t.Fatal("freshToken is nil") + } + if state := tp.(*cachedTokenProvider).tokenState(); state != stale { + t.Errorf("got %d, want %d", state, stale) + } + if freshToken.Value != tt.want1 { + t.Errorf("got %q, want %q", freshToken.Value, tt.want1) + } + time.Sleep(100 * time.Millisecond) + freshToken2, err := tp.Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if state := tp.(*cachedTokenProvider).tokenState(); state != tt.wantState2 { + t.Errorf("got %d, want %d", state, tt.wantState2) + } + if freshToken2.Value != tt.want2 { + t.Errorf("got %q, want %q", freshToken2.Value, tt.want2) + } + }) } }