diff --git a/flyteidl/clients/go/admin/config.go b/flyteidl/clients/go/admin/config.go index c57a5c0078c..623622d49b4 100644 --- a/flyteidl/clients/go/admin/config.go +++ b/flyteidl/clients/go/admin/config.go @@ -45,6 +45,7 @@ type Config struct { PerRetryTimeout config.Duration `json:"perRetryTimeout" pflag:",gRPC per retry timeout"` MaxRetries int `json:"maxRetries" pflag:",Max number of gRPC retries"` AuthType AuthType `json:"authType" pflag:"-,Type of OAuth2 flow used for communicating with admin."` + TokenRefreshWindow config.Duration `json:"tokenRefreshWindow" pflag:",Max duration between token refresh attempt and token expiry."` // Deprecated: settings will be discovered dynamically DeprecatedUseAuth bool `json:"useAuth" pflag:",Deprecated: Auth will be enabled/disabled based on admin's dynamically discovered information."` ClientID string `json:"clientId" pflag:",Client ID"` @@ -81,6 +82,7 @@ var ( TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute}, BrowserSessionTimeout: config.Duration{Duration: 15 * time.Second}, }, + TokenRefreshWindow: config.Duration{Duration: 0}, } configSection = config.MustRegisterSectionWithUpdates(configSectionKey, &defaultConfig, func(ctx context.Context, newValue config.Config) { diff --git a/flyteidl/clients/go/admin/config_flags.go b/flyteidl/clients/go/admin/config_flags.go index bb13d674e64..6d277e32636 100755 --- a/flyteidl/clients/go/admin/config_flags.go +++ b/flyteidl/clients/go/admin/config_flags.go @@ -57,15 +57,16 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "maxBackoffDelay"), defaultConfig.MaxBackoffDelay.String(), "Max delay for grpc backoff") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "perRetryTimeout"), defaultConfig.PerRetryTimeout.String(), "gRPC per retry timeout") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxRetries"), defaultConfig.MaxRetries, "Max number of gRPC retries") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenRefreshWindow"), defaultConfig.TokenRefreshWindow.String(), "Max duration between token refresh attempt and token expiry.") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "useAuth"), defaultConfig.DeprecatedUseAuth, "Deprecated: Auth will be enabled/disabled based on admin's dynamically discovered information.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientId"), defaultConfig.ClientID, "Client ID") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretLocation"), defaultConfig.ClientSecretLocation, "File containing the client secret") - cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), defaultConfig.Scopes, "List of scopes to request") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), []string{}, "List of scopes to request") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.DeprecatedAuthorizationServerURL, "This is the URL to your IdP's authorization server. It'll default to Endpoint") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenUrl"), defaultConfig.TokenURL, "OPTIONAL: Your IdP's token endpoint. It'll be discovered from flyte admin's OAuth Metadata endpoint if not provided.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.DeprecatedAuthorizationHeader, "Custom metadata header to pass JWT") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pkceConfig.timeout"), defaultConfig.PkceConfig.BrowserSessionTimeout.String(), "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pkceConfig.refreshTime"), defaultConfig.PkceConfig.TokenRefreshGracePeriod.String(), "") - cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "command"), defaultConfig.Command, "Command for external authentication token generation") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "command"), []string{}, "Command for external authentication token generation") return cmdFlags } diff --git a/flyteidl/clients/go/admin/config_flags_test.go b/flyteidl/clients/go/admin/config_flags_test.go index d6734406d70..fa4c192c220 100755 --- a/flyteidl/clients/go/admin/config_flags_test.go +++ b/flyteidl/clients/go/admin/config_flags_test.go @@ -197,6 +197,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_tokenRefreshWindow", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.TokenRefreshWindow.String() + + cmdFlags.Set("tokenRefreshWindow", testValue) + if vString, err := cmdFlags.GetString("tokenRefreshWindow"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.TokenRefreshWindow) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_useAuth", func(t *testing.T) { t.Run("Override", func(t *testing.T) { @@ -242,7 +256,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Test_scopes", func(t *testing.T) { t.Run("Override", func(t *testing.T) { - testValue := join_Config(defaultConfig.Scopes, ",") + testValue := join_Config("1,1", ",") cmdFlags.Set("scopes", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("scopes"); err == nil { @@ -326,7 +340,7 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Test_command", func(t *testing.T) { t.Run("Override", func(t *testing.T) { - testValue := join_Config(defaultConfig.Command, ",") + testValue := join_Config("1,1", ",") cmdFlags.Set("command", testValue) if vStringSlice, err := cmdFlags.GetStringSlice("command"); err == nil { diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index 5d78893e124..aab1f0948b5 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -5,6 +5,10 @@ import ( "fmt" "io/ioutil" "strings" + "sync" + "time" + + "k8s.io/apimachinery/pkg/util/wait" "github.com/flyteorg/flyteidl/clients/go/admin/externalprocess" @@ -124,7 +128,8 @@ func GetPKCEAuthTokenSource(ctx context.Context, tokenOrchestrator pkce.TokenOrc } type ClientCredentialsTokenSourceProvider struct { - ccConfig clientcredentials.Config + ccConfig clientcredentials.Config + TokenRefreshWindow time.Duration } func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, @@ -141,15 +146,67 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, if len(scopes) == 0 { scopes = clientMetadata.Scopes } - return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ ClientID: cfg.ClientID, ClientSecret: secret, TokenURL: tokenURL, - Scopes: scopes}}, nil + Scopes: scopes}, + TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil } func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { + if p.TokenRefreshWindow > 0 { + source := p.ccConfig.TokenSource(ctx) + return &customTokenSource{ + new: source, + mu: sync.Mutex{}, + t: nil, + tokenRefreshWindow: p.TokenRefreshWindow, + }, nil + } return p.ccConfig.TokenSource(ctx), nil } + +type customTokenSource struct { + new oauth2.TokenSource + mu sync.Mutex // guards everything else + t *oauth2.Token + refreshTime time.Time + failedToRefresh bool + tokenRefreshWindow time.Duration +} + +func (s *customTokenSource) Token() (*oauth2.Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.t.Valid() { + if time.Now().After(s.refreshTime) && !s.failedToRefresh { + t, err := s.new.Token() + if err != nil { + s.failedToRefresh = true // don't try to refresh again before expiry + return s.t, nil + } + s.t = t + s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) + s.failedToRefresh = false + return s.t, nil + } + return s.t, nil + } + t, err := s.new.Token() + if err != nil { + return nil, err + } + s.t = t + s.failedToRefresh = false + s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow)) + return t, nil +} + +// Get random duration between 0 and maxDuration +func getRandomDuration(maxDuration time.Duration) time.Duration { + // d is 1.0 to 2.0 times maxDuration + d := wait.Jitter(maxDuration, 1) + return d - maxDuration +}