Skip to content

Commit

Permalink
Add custom token source that allows preemptive token refresh (#262)
Browse files Browse the repository at this point in the history
* Add custom token source that allows preemptive token refresh
Signed-off-by: Sean Lin <[email protected]>

* Switch to apimachinery jitter
Signed-off-by: Sean Lin <[email protected]>

* Switch back to max because min doesnt make sense
Signed-off-by: Sean Lin <[email protected]>

* fix lint
Signed-off-by: Sean Lin <[email protected]>

* goimport
Signed-off-by: Sean Lin <[email protected]>

* minor fix
Signed-off-by: Sean Lin <[email protected]>

* Rename and trim config
Signed-off-by: Sean Lin <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
mayitbeegh authored and eapolinario committed Sep 13, 2023
1 parent 463eaec commit 43f2637
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
2 changes: 2 additions & 0 deletions flyteidl/clients/go/admin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type Config struct {
PerRetryTimeout config.Duration `json:"perRetryTimeout" pflag:",gRPC per retry timeout"`
MaxRetries int `json:"maxRetries" pflag:",Max number of gRPC retries"`
AuthType AuthType `json:"authType" pflag:"-,Type of OAuth2 flow used for communicating with admin."`
TokenRefreshWindow config.Duration `json:"tokenRefreshWindow" pflag:",Max duration between token refresh attempt and token expiry."`
// Deprecated: settings will be discovered dynamically
DeprecatedUseAuth bool `json:"useAuth" pflag:",Deprecated: Auth will be enabled/disabled based on admin's dynamically discovered information."`
ClientID string `json:"clientId" pflag:",Client ID"`
Expand Down Expand Up @@ -81,6 +82,7 @@ var (
TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute},
BrowserSessionTimeout: config.Duration{Duration: 15 * time.Second},
},
TokenRefreshWindow: config.Duration{Duration: 0},
}

configSection = config.MustRegisterSectionWithUpdates(configSectionKey, &defaultConfig, func(ctx context.Context, newValue config.Config) {
Expand Down
5 changes: 3 additions & 2 deletions flyteidl/clients/go/admin/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 16 additions & 2 deletions flyteidl/clients/go/admin/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 60 additions & 3 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"fmt"
"io/ioutil"
"strings"
"sync"
"time"

"k8s.io/apimachinery/pkg/util/wait"

"github.com/flyteorg/flyteidl/clients/go/admin/externalprocess"

Expand Down Expand Up @@ -124,7 +128,8 @@ func GetPKCEAuthTokenSource(ctx context.Context, tokenOrchestrator pkce.TokenOrc
}

type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
ccConfig clientcredentials.Config
TokenRefreshWindow time.Duration
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config,
Expand All @@ -141,15 +146,67 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config,
if len(scopes) == 0 {
scopes = clientMetadata.Scopes
}

return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: secret,
TokenURL: tokenURL,
Scopes: scopes}}, nil
Scopes: scopes},
TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
if p.TokenRefreshWindow > 0 {
source := p.ccConfig.TokenSource(ctx)
return &customTokenSource{
new: source,
mu: sync.Mutex{},
t: nil,
tokenRefreshWindow: p.TokenRefreshWindow,
}, nil
}
return p.ccConfig.TokenSource(ctx), nil
}

type customTokenSource struct {
new oauth2.TokenSource
mu sync.Mutex // guards everything else
t *oauth2.Token
refreshTime time.Time
failedToRefresh bool
tokenRefreshWindow time.Duration
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t.Valid() {
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
t, err := s.new.Token()
if err != nil {
s.failedToRefresh = true // don't try to refresh again before expiry
return s.t, nil
}
s.t = t
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
s.failedToRefresh = false
return s.t, nil
}
return s.t, nil
}
t, err := s.new.Token()
if err != nil {
return nil, err
}
s.t = t
s.failedToRefresh = false
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return t, nil
}

// Get random duration between 0 and maxDuration
func getRandomDuration(maxDuration time.Duration) time.Duration {
// d is 1.0 to 2.0 times maxDuration
d := wait.Jitter(maxDuration, 1)
return d - maxDuration
}

0 comments on commit 43f2637

Please sign in to comment.