diff --git a/flytectl/cmd/core/cmd.go b/flytectl/cmd/core/cmd.go index b2bd77c317..989f4b7ebb 100644 --- a/flytectl/cmd/core/cmd.go +++ b/flytectl/cmd/core/cmd.go @@ -73,10 +73,10 @@ func generateCommandFunc(cmdEntry CommandEntry) func(cmd *cobra.Command, args [] cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout()) if !cmdEntry.DisableFlyteClient { clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)). - WithTokenCache(pkce.NewTokenCacheKeyringProvider( - pkce.KeyRingServiceName, - fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser), - )).Build(ctx) + WithTokenCache(pkce.TokenCacheKeyringProvider{ + ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser), + ServiceName: pkce.KeyRingServiceName, + }).Build(ctx) if err != nil { return err } diff --git a/flytectl/pkg/pkce/token_cache_keyring.go b/flytectl/pkg/pkce/token_cache_keyring.go index ff547827bd..119fea5033 100644 --- a/flytectl/pkg/pkce/token_cache_keyring.go +++ b/flytectl/pkg/pkce/token_cache_keyring.go @@ -3,68 +3,23 @@ package pkce import ( "encoding/json" "fmt" - "sync" - - "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" "github.com/zalando/go-keyring" "golang.org/x/oauth2" ) -const ( - KeyRingServiceUser = "flytectl-user" - KeyRingServiceName = "flytectl" -) - // TokenCacheKeyringProvider wraps the logic to save and retrieve tokens from the OS's keyring implementation. type TokenCacheKeyringProvider struct { ServiceName string ServiceUser string - mu *sync.Mutex - cond *sync.Cond -} - -func (t *TokenCacheKeyringProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) { - if existingBytes, err := json.Marshal(existing); err != nil { - return false, fmt.Errorf("unable to marshal token to save in cache due to %w", err) - } else if tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser); err != nil { - if err.Error() == "secret not found in keyring" { - return false, fmt.Errorf("unable to read token from cache. Error: %w", cache.ErrNotFound) - } - - return false, fmt.Errorf("unable to read token from cache. Error: %w", err) - } else if tokenJSON != string(existingBytes) { - return false, nil - } - - _ = keyring.Delete(t.ServiceName, t.ServiceUser) - return true, nil } -func (t *TokenCacheKeyringProvider) Lock() { - t.mu.Lock() -} - -func (t *TokenCacheKeyringProvider) Unlock() { - t.mu.Unlock() -} - -// TryLock the cache. -func (t *TokenCacheKeyringProvider) TryLock() bool { - return t.mu.TryLock() -} - -// CondWait waits for the condition to be true. -func (t *TokenCacheKeyringProvider) CondWait() { - t.cond.Wait() -} - -// CondBroadcast signals the condition. -func (t *TokenCacheKeyringProvider) CondBroadcast() { - t.cond.Broadcast() -} +const ( + KeyRingServiceUser = "flytectl-user" + KeyRingServiceName = "flytectl" +) -func (t *TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { +func (t TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { var tokenBytes []byte if token.AccessToken == "" { return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry) @@ -83,7 +38,7 @@ func (t *TokenCacheKeyringProvider) SaveToken(token *oauth2.Token) error { return nil } -func (t *TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { +func (t TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { // get saved token tokenJSON, err := keyring.Get(t.ServiceName, t.ServiceUser) if len(tokenJSON) == 0 { @@ -101,12 +56,3 @@ func (t *TokenCacheKeyringProvider) GetToken() (*oauth2.Token, error) { return &token, nil } - -func NewTokenCacheKeyringProvider(serviceName, serviceUser string) *TokenCacheKeyringProvider { - return &TokenCacheKeyringProvider{ - mu: &sync.Mutex{}, - cond: sync.NewCond(&sync.Mutex{}), - ServiceName: serviceName, - ServiceUser: serviceUser, - } -} diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index 4cebf6440f..8a0024b319 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -20,8 +20,7 @@ const ProxyAuthorizationHeader = "proxy-authorization" // MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. -func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, - perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture) if err != nil { return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) @@ -43,17 +42,11 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) if err != nil { - return fmt.Errorf("failed to get token source. Error: %w", err) - } - - _, err = tokenSource.Token() - if err != nil { - return fmt.Errorf("failed to issue token. Error: %w", err) + return err } wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) perRPCCredentials.Store(wrappedTokenSource) - return nil } @@ -141,15 +134,6 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) - // If there is already a token in the cache (e.g. key-ring), we should use it immediately... - t, _ := tokenCache.GetToken() - if t != nil { - err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) - if err != nil { - return fmt.Errorf("failed to materialize credentials. Error: %v", err) - } - } - err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) @@ -157,34 +141,12 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut if st, ok := status.FromError(err); ok { // If the error we receive from executing the request expects if shouldAttemptToAuthenticate(st.Code()) { - err = func() error { - if !tokenCache.TryLock() { - tokenCache.CondWait() - return nil - } - - defer tokenCache.Unlock() - _, err := tokenCache.PurgeIfEquals(t) - if err != nil && !errors.Is(err, cache.ErrNotFound) { - logger.Errorf(ctx, "Failed to purge cache. Error [%v]", err) - return fmt.Errorf("failed to purge cache. Error: %w", err) - } - - logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) - newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) - if newErr != nil { - errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr) - logger.Errorf(ctx, errString) - return fmt.Errorf(errString) - } - - tokenCache.CondBroadcast() - return nil - }() - - if err != nil { - return err + logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) + if newErr != nil { + return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) } + return invoker(ctx, method, req, reply, cc, opts...) } } diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index 10c96625b7..ce99c99270 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -2,14 +2,13 @@ package admin import ( "context" - "encoding/json" "errors" "fmt" "io" "net" "net/http" + "net/http/httptest" "net/url" - "os" "strings" "sync" "testing" @@ -32,11 +31,10 @@ import ( // authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one // initialized through mockery) and starts a local server that uses it to respond to grpc requests. type authMetadataServer struct { + s *httptest.Server t testing.TB - grpcPort int - httpPort int + port int grpcServer *grpc.Server - httpServer *http.Server netListener net.Listener impl service.AuthMetadataServiceServer lck *sync.RWMutex @@ -72,49 +70,27 @@ func (s authMetadataServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } -func (s *authMetadataServer) tokenHandler(w http.ResponseWriter, r *http.Request) { - tokenJSON := []byte(`{"access_token": "exampletoken", "token_type": "bearer"}`) - w.Header().Set("Content-Type", "application/json") - _, err := w.Write(tokenJSON) - assert.NoError(s.t, err) -} - func (s *authMetadataServer) Start(_ context.Context) error { s.lck.Lock() defer s.lck.Unlock() /***** Set up the server serving channelz service. *****/ - - lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.grpcPort)) + lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.port)) if err != nil { - return fmt.Errorf("failed to listen on port [%v]: %w", s.grpcPort, err) + return fmt.Errorf("failed to listen on port [%v]: %w", s.port, err) } - s.netListener = lis grpcS := grpc.NewServer() service.RegisterAuthMetadataServiceServer(grpcS, s) go func() { - defer grpcS.Stop() _ = grpcS.Serve(lis) + //assert.NoError(s.t, err) }() + s.grpcServer = grpcS - mux := http.NewServeMux() - // Attach the handler to the /oauth2/token path - mux.HandleFunc("/oauth2/token", s.tokenHandler) - - //nolint:gosec - s.httpServer = &http.Server{ - Addr: fmt.Sprintf("localhost:%d", s.httpPort), - Handler: mux, - } + s.netListener = lis - go func() { - defer s.httpServer.Close() - err := s.httpServer.ListenAndServe() - if err != nil { - panic(err) - } - }() + s.s = httptest.NewServer(s) return nil } @@ -122,30 +98,25 @@ func (s *authMetadataServer) Start(_ context.Context) error { func (s *authMetadataServer) Close() { s.lck.RLock() defer s.lck.RUnlock() + s.grpcServer.Stop() + s.s.Close() } -func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl service.AuthMetadataServiceServer) *authMetadataServer { +func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServiceServer) *authMetadataServer { return &authMetadataServer{ - grpcPort: grpcPort, - httpPort: httpPort, - t: t, - impl: impl, - lck: &sync.RWMutex{}, + port: port, + t: t, + impl: impl, + lck: &sync.RWMutex{}, } } func Test_newAuthInterceptor(t *testing.T) { - plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json") - var tokenData oauth2.Token - err := json.Unmarshal(plan, &tokenData) - assert.NoError(t, err) t.Run("Other Error", func(t *testing.T) { f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - mockTokenCache := &mocks.TokenCache{} - mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) - interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p) + interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p) otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Canceled, "").Err() } @@ -158,43 +129,35 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - httpPort := rand.IntnRange(10000, 60000) - grpcPort := rand.IntnRange(10000, 60000) + port := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), }, nil) - m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, grpcPort, httpPort, m) + s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - c := &mocks.TokenCache{} - c.OnGetTokenMatch().Return(nil, nil) - c.OnTryLockMatch().Return(true) - c.OnSaveTokenMatch(mock.Anything).Return(nil) - c.On("CondBroadcast").Return() - c.On("Unlock").Return() - c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, c, f, p) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Unauthenticated, "").Err() } + err = interceptor(ctx, "POST", nil, nil, nil, unauthenticated) assert.Error(t, err) assert.Truef(t, f.IsInitialized(), "PerRPCCredentialFuture should be initialized") @@ -206,26 +169,24 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - httpPort := rand.IntnRange(10000, 60000) - grpcPort := rand.IntnRange(10000, 60000) + port := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} - s := newAuthMetadataServer(t, grpcPort, httpPort, m) + s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - c := &mocks.TokenCache{} - c.OnGetTokenMatch().Return(nil, nil) + interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, c, f, p) + }, &mocks.TokenCache{}, f, p) authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil } @@ -240,39 +201,33 @@ func Test_newAuthInterceptor(t *testing.T) { Level: logger.DebugLevel, })) - httpPort := rand.IntnRange(10000, 60000) - grpcPort := rand.IntnRange(10000, 60000) + port := rand.IntnRange(10000, 60000) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ - AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort), - TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), - JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort), + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), }, nil) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) - s := newAuthMetadataServer(t, grpcPort, httpPort, m) + s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() - c := &mocks.TokenCache{} - c.OnGetTokenMatch().Return(nil, nil) - c.OnTryLockMatch().Return(true) - c.OnSaveTokenMatch(mock.Anything).Return(nil) - c.OnPurgeIfEqualsMatch(mock.Anything).Return(true, nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, c, f, p) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Aborted, "").Err() } @@ -284,21 +239,17 @@ func Test_newAuthInterceptor(t *testing.T) { } func TestMaterializeCredentials(t *testing.T) { + port := rand.IntnRange(10000, 60000) t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) { - httpPort := rand.IntnRange(10000, 60000) - grpcPort := rand.IntnRange(10000, 60000) - c := &mocks.TokenCache{} - c.OnGetTokenMatch().Return(nil, nil) - c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) - s := newAuthMetadataServer(t, grpcPort, httpPort, m) + s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -308,29 +259,24 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort), + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), Scopes: []string{"all"}, Audience: "http://localhost:30081", AuthorizationHeader: "authorization", - }, c, f, p) + }, &mocks.TokenCache{}, f, p) assert.NoError(t, err) }) t.Run("Failed to fetch client metadata", func(t *testing.T) { - httpPort := rand.IntnRange(10000, 60000) - grpcPort := rand.IntnRange(10000, 60000) - c := &mocks.TokenCache{} - c.OnGetTokenMatch().Return(nil, nil) - c.OnSaveTokenMatch(mock.Anything).Return(nil) m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) failedPublicClientConfigLookup := errors.New("expected err") m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) - s := newAuthMetadataServer(t, grpcPort, httpPort, m) + s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) defer s.Close() - u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort)) + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) assert.NoError(t, err) f := NewPerRPCCredentialsFuture() @@ -340,9 +286,9 @@ func TestMaterializeCredentials(t *testing.T) { Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort), + TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), Scopes: []string{"all"}, - }, c, f, p) + }, &mocks.TokenCache{}, f, p) assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err") }) } diff --git a/flyteidl/clients/go/admin/cache/mocks/token_cache.go b/flyteidl/clients/go/admin/cache/mocks/token_cache.go index 88a1bef81c..0af58b381f 100644 --- a/flyteidl/clients/go/admin/cache/mocks/token_cache.go +++ b/flyteidl/clients/go/admin/cache/mocks/token_cache.go @@ -12,16 +12,6 @@ type TokenCache struct { mock.Mock } -// CondBroadcast provides a mock function with given fields: -func (_m *TokenCache) CondBroadcast() { - _m.Called() -} - -// CondWait provides a mock function with given fields: -func (_m *TokenCache) CondWait() { - _m.Called() -} - type TokenCache_GetToken struct { *mock.Call } @@ -63,50 +53,6 @@ func (_m *TokenCache) GetToken() (*oauth2.Token, error) { return r0, r1 } -// Lock provides a mock function with given fields: -func (_m *TokenCache) Lock() { - _m.Called() -} - -type TokenCache_PurgeIfEquals struct { - *mock.Call -} - -func (_m TokenCache_PurgeIfEquals) Return(_a0 bool, _a1 error) *TokenCache_PurgeIfEquals { - return &TokenCache_PurgeIfEquals{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *TokenCache) OnPurgeIfEquals(t *oauth2.Token) *TokenCache_PurgeIfEquals { - c_call := _m.On("PurgeIfEquals", t) - return &TokenCache_PurgeIfEquals{Call: c_call} -} - -func (_m *TokenCache) OnPurgeIfEqualsMatch(matchers ...interface{}) *TokenCache_PurgeIfEquals { - c_call := _m.On("PurgeIfEquals", matchers...) - return &TokenCache_PurgeIfEquals{Call: c_call} -} - -// PurgeIfEquals provides a mock function with given fields: t -func (_m *TokenCache) PurgeIfEquals(t *oauth2.Token) (bool, error) { - ret := _m.Called(t) - - var r0 bool - if rf, ok := ret.Get(0).(func(*oauth2.Token) bool); ok { - r0 = rf(t) - } else { - r0 = ret.Get(0).(bool) - } - - var r1 error - if rf, ok := ret.Get(1).(func(*oauth2.Token) error); ok { - r1 = rf(t) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - type TokenCache_SaveToken struct { *mock.Call } @@ -138,40 +84,3 @@ func (_m *TokenCache) SaveToken(token *oauth2.Token) error { return r0 } - -type TokenCache_TryLock struct { - *mock.Call -} - -func (_m TokenCache_TryLock) Return(_a0 bool) *TokenCache_TryLock { - return &TokenCache_TryLock{Call: _m.Call.Return(_a0)} -} - -func (_m *TokenCache) OnTryLock() *TokenCache_TryLock { - c_call := _m.On("TryLock") - return &TokenCache_TryLock{Call: c_call} -} - -func (_m *TokenCache) OnTryLockMatch(matchers ...interface{}) *TokenCache_TryLock { - c_call := _m.On("TryLock", matchers...) - return &TokenCache_TryLock{Call: c_call} -} - -// TryLock provides a mock function with given fields: -func (_m *TokenCache) TryLock() bool { - ret := _m.Called() - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -// Unlock provides a mock function with given fields: -func (_m *TokenCache) Unlock() { - _m.Called() -} diff --git a/flyteidl/clients/go/admin/cache/token_cache.go b/flyteidl/clients/go/admin/cache/token_cache.go index f2d55fc0dd..e4e2b7e17f 100644 --- a/flyteidl/clients/go/admin/cache/token_cache.go +++ b/flyteidl/clients/go/admin/cache/token_cache.go @@ -1,40 +1,14 @@ package cache -import ( - "fmt" - - "golang.org/x/oauth2" -) +import "golang.org/x/oauth2" //go:generate mockery -all -case=underscore -var ( - ErrNotFound = fmt.Errorf("secret not found in keyring") -) - // TokenCache defines the interface needed to cache and retrieve oauth tokens. type TokenCache interface { // SaveToken saves the token securely to cache. SaveToken(token *oauth2.Token) error - // GetToken retrieves the token from the cache. + // Retrieves the token from the cache. GetToken() (*oauth2.Token, error) - - // PurgeIfEquals purges the token from the cache. - PurgeIfEquals(t *oauth2.Token) (bool, error) - - // Lock the cache. - Lock() - - // TryLock tries to lock the cache. - TryLock() bool - - // Unlock the cache. - Unlock() - - // CondWait waits for the condition to be true. - CondWait() - - // CondSignalCondBroadcast signals the condition. - CondBroadcast() } diff --git a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go index 46e134ec55..9c6223fc06 100644 --- a/flyteidl/clients/go/admin/cache/token_cache_inmemory.go +++ b/flyteidl/clients/go/admin/cache/token_cache_inmemory.go @@ -2,62 +2,23 @@ package cache import ( "fmt" - "sync" - "sync/atomic" "golang.org/x/oauth2" ) type TokenCacheInMemoryProvider struct { - token atomic.Value - mu *sync.Mutex - cond *sync.Cond + token *oauth2.Token } func (t *TokenCacheInMemoryProvider) SaveToken(token *oauth2.Token) error { - t.token.Store(token) + t.token = token return nil } -func (t *TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { - tkn := t.token.Load() - if tkn == nil { +func (t TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) { + if t.token == nil { return nil, fmt.Errorf("cannot find token in cache") } - return tkn.(*oauth2.Token), nil -} - -func (t *TokenCacheInMemoryProvider) PurgeIfEquals(existing *oauth2.Token) (bool, error) { - return t.token.CompareAndSwap(existing, nil), nil -} - -func (t *TokenCacheInMemoryProvider) Lock() { - t.mu.Lock() -} - -func (t *TokenCacheInMemoryProvider) TryLock() bool { - return t.mu.TryLock() -} - -func (t *TokenCacheInMemoryProvider) Unlock() { - t.mu.Unlock() -} - -// CondWait waits for the condition to be true. -func (t *TokenCacheInMemoryProvider) CondWait() { - t.cond.Wait() -} - -// CondBroadcast signals the condition. -func (t *TokenCacheInMemoryProvider) CondBroadcast() { - t.cond.Broadcast() -} - -func NewTokenCacheInMemoryProvider() *TokenCacheInMemoryProvider { - return &TokenCacheInMemoryProvider{ - mu: &sync.Mutex{}, - token: atomic.Value{}, - cond: sync.NewCond(&sync.Mutex{}), - } + return t.token, nil } diff --git a/flyteidl/clients/go/admin/client_builder.go b/flyteidl/clients/go/admin/client_builder.go index 0d1341bf7b..25b263ecf1 100644 --- a/flyteidl/clients/go/admin/client_builder.go +++ b/flyteidl/clients/go/admin/client_builder.go @@ -40,7 +40,7 @@ func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetB // Build the clientset using the current state of the ClientsetBuilder func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) { if cb.tokenCache == nil { - cb.tokenCache = cache.NewTokenCacheInMemoryProvider() + cb.tokenCache = &cache.TokenCacheInMemoryProvider{} } if cb.config == nil { diff --git a/flyteidl/clients/go/admin/client_builder_test.go b/flyteidl/clients/go/admin/client_builder_test.go index 89bcc38550..c871bcb326 100644 --- a/flyteidl/clients/go/admin/client_builder_test.go +++ b/flyteidl/clients/go/admin/client_builder_test.go @@ -17,9 +17,9 @@ func TestClientsetBuilder_Build(t *testing.T) { cb := NewClientsetBuilder().WithConfig(&Config{ UseInsecureConnection: true, Endpoint: config.URL{URL: *u}, - }).WithTokenCache(cache.NewTokenCacheInMemoryProvider()) + }).WithTokenCache(&cache.TokenCacheInMemoryProvider{}) ctx := context.Background() _, err := cb.Build(ctx) assert.NoError(t, err) - assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(cache.NewTokenCacheInMemoryProvider())) + assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&cache.TokenCacheInMemoryProvider{})) } diff --git a/flyteidl/clients/go/admin/client_test.go b/flyteidl/clients/go/admin/client_test.go index 042a826692..eb19b76f47 100644 --- a/flyteidl/clients/go/admin/client_test.go +++ b/flyteidl/clients/go/admin/client_test.go @@ -255,8 +255,6 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) { mockAuthClient := new(mocks.AuthMetadataServiceClient) mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) - mockTokenCache.On("Lock").Return() - mockTokenCache.On("Unlock").Return() mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil) mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) tokenSourceProvider, err := NewTokenSourceProvider(ctx, adminServiceConfig, mockTokenCache, mockAuthClient) @@ -290,7 +288,7 @@ func Test_getPkceAuthTokenSource(t *testing.T) { assert.NoError(t, err) // populate the cache - tokenCache := cache.NewTokenCacheInMemoryProvider() + tokenCache := &cache.TokenCacheInMemoryProvider{} assert.NoError(t, tokenCache.SaveToken(&tokenData)) baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{ diff --git a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go index 9f20fb3ef5..5c1dc5f2bd 100644 --- a/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/deviceflow/token_orchestrator_test.go @@ -23,7 +23,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := cache.NewTokenCacheInMemoryProvider() + tokenCache := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ @@ -97,7 +97,7 @@ func TestFetchFromAuthFlow(t *testing.T) { })) defer fakeServer.Close() - tokenCache := cache.NewTokenCacheInMemoryProvider() + tokenCache := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go index ca1973ea66..dc1c80f63a 100644 --- a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go +++ b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator_test.go @@ -16,7 +16,7 @@ import ( func TestFetchFromAuthFlow(t *testing.T) { ctx := context.Background() t.Run("fetch from auth flow", func(t *testing.T) { - tokenCache := cache.NewTokenCacheInMemoryProvider() + tokenCache := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ Config: &oauth2.Config{ diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index 83df542082..d4f4a31a5a 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -188,7 +188,7 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s } secret = strings.TrimSpace(secret) if tokenCache == nil { - tokenCache = cache.NewTokenCacheInMemoryProvider() + tokenCache = &cache.TokenCacheInMemoryProvider{} } return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ @@ -227,14 +227,14 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { token, err := s.new.Token() if err != nil { - logger.Warnf(s.ctx, "failed to get token: %v", err) + logger.Warnf(s.ctx, "failed to get token: %w", err) return nil, fmt.Errorf("failed to get token: %w", err) } logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) err = s.tokenCache.SaveToken(token) if err != nil { - logger.Warnf(s.ctx, "failed to cache token: %v", err) + logger.Warnf(s.ctx, "failed to cache token: %w", err) } return token, nil diff --git a/flyteidl/clients/go/admin/token_source_provider_test.go b/flyteidl/clients/go/admin/token_source_provider_test.go index 43d0fdd928..63fc1aa56e 100644 --- a/flyteidl/clients/go/admin/token_source_provider_test.go +++ b/flyteidl/clients/go/admin/token_source_provider_test.go @@ -127,9 +127,7 @@ func TestCustomTokenSource_Token(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { tokenCache := &tokenCacheMocks.TokenCache{} - tokenCache.OnGetToken().Return(test.token, nil).Maybe() - tokenCache.On("Lock").Return().Maybe() - tokenCache.On("Unlock").Return().Maybe() + tokenCache.OnGetToken().Return(test.token, nil).Once() provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") assert.NoError(t, err) source, err := provider.GetTokenSource(ctx) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go index 4fd3fa476c..c4891b13ae 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go @@ -3,6 +3,7 @@ package tokenorchestrator import ( "context" "fmt" + "time" "golang.org/x/oauth2" @@ -52,21 +53,16 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, err } - if token.Valid() { - return token, nil - } - - t.TokenCache.Lock() - defer t.TokenCache.Unlock() - - token, err = t.TokenCache.GetToken() - if err != nil { - return nil, err + if !token.Valid() { + return nil, fmt.Errorf("token from cache is invalid") } - if token.Valid() { + // If token doesn't need to be refreshed, return it. + if time.Now().Before(token.Expiry.Add(-tokenRefreshGracePeriod.Duration)) { + logger.Infof(ctx, "found the token in the cache") return token, nil } + token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) token, err = t.RefreshToken(ctx, token) if err != nil { @@ -77,8 +73,6 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, fmt.Errorf("refreshed token is invalid") } - token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration) - err = t.TokenCache.SaveToken(token) if err != nil { return nil, fmt.Errorf("failed to save token in the token cache. Error: %w", err) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go index 0a1a9f4985..ed4afa0ff0 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go @@ -26,7 +26,7 @@ func TestRefreshTheToken(t *testing.T) { ClientID: "dummyClient", }, } - tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() + tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} orchestrator := BaseTokenOrchestrator{ ClientConfig: clientConf, TokenCache: tokenCacheProvider, @@ -58,7 +58,7 @@ func TestFetchFromCache(t *testing.T) { mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) t.Run("no token in cache", func(t *testing.T) { - tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() + tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) @@ -69,7 +69,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("token in cache", func(t *testing.T) { - tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() + tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json") @@ -86,7 +86,7 @@ func TestFetchFromCache(t *testing.T) { }) t.Run("expired token in cache", func(t *testing.T) { - tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() + tokenCacheProvider := &cache.TokenCacheInMemoryProvider{} orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) fileData, _ := os.ReadFile("testdata/token.json")