Skip to content

Commit

Permalink
feat: introduce token caching for client credentials authentication (#…
Browse files Browse the repository at this point in the history
…922)

Right now every request via Oathkeeper that uses client credentials
authentication requests a new access token. This can introduce a lot
of latency in the critical path of an application in case of a slow
token endpoint.

This change introduces a cache similar to the one that is used in the
introspection authentication.

Closes #870

Co-authored-by: Marlin Cremers <[email protected]>
  • Loading branch information
aeneasr and Marlinc authored Feb 14, 2022
1 parent a3b5b28 commit 9a56154
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 54 deletions.
24 changes: 24 additions & 0 deletions .schema/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,30 @@
},
"retry": {
"$ref": "#/definitions/retry"
},
"cache": {
"additionalProperties": false,
"type": "object",
"properties": {
"enabled": {
"$ref": "#/definitions/handlerSwitch"
},
"ttl": {
"type": "string",
"pattern": "^[0-9]+(ns|us|ms|s|m|h)$",
"title": "Cache Time to Live",
"description": "Can override the default behaviour of using the token exp time, and specify a set time to live for the token in the cache. If the token exp time is lower than the set value the token exp time will be used instead.",
"examples": [
"5s"
]
},
"max_tokens": {
"type": "integer",
"default": 1000,
"title": "Maximum Cached Tokens",
"description": "Max number of tokens to cache."
}
}
}
},
"required": [
Expand Down
2 changes: 1 addition & 1 deletion driver/configuration/provider_viper_public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func TestViperProvider(t *testing.T) {
})

t.Run("authenticator=oauth2_client_credentials", func(t *testing.T) {
a := authn.NewAuthenticatorOAuth2ClientCredentials(p)
a := authn.NewAuthenticatorOAuth2ClientCredentials(p, logger)
assert.True(t, p.AuthenticatorIsEnabled(a.GetID()))
require.NoError(t, a.Validate(nil))

Expand Down
2 changes: 1 addition & 1 deletion driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ func (r *RegistryMemory) prepareAuthn() {
authn.NewAuthenticatorBearerToken(r.c),
authn.NewAuthenticatorJWT(r.c, r),
authn.NewAuthenticatorNoOp(r.c),
authn.NewAuthenticatorOAuth2ClientCredentials(r.c),
authn.NewAuthenticatorOAuth2ClientCredentials(r.c, r.Logger()),
authn.NewAuthenticatorOAuth2Introspection(r.c, r.Logger()),
authn.NewAuthenticatorUnauthorized(r.c),
}
Expand Down
2 changes: 1 addition & 1 deletion internal/httpclient/models/health_not_ready_status.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/httpclient/models/json_web_key_set.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/httpclient/models/rule.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/httpclient/models/rule_handler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/httpclient/models/upstream.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

161 changes: 136 additions & 25 deletions pipeline/authn/authenticator_oauth2_client_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package authn
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/dgraph-io/ristretto"
"golang.org/x/oauth2"

"github.com/ory/x/logrusx"

"github.com/ory/x/httpx"

"github.com/ory/oathkeeper/driver/configuration"
Expand All @@ -22,23 +27,34 @@ import (
)

type AuthenticatorOAuth2Configuration struct {
Scopes []string `json:"required_scope"`
TokenURL string `json:"token_url"`
Retry *AuthenticatorOAuth2ClientCredentialsRetryConfiguration
Scopes []string `json:"required_scope"`
TokenURL string `json:"token_url"`
Retry *AuthenticatorOAuth2ClientCredentialsRetryConfiguration `json:"retry,omitempty"`
Cache clientCredentialsCacheConfig `json:"cache"`
}

type clientCredentialsCacheConfig struct {
Enabled bool `json:"enabled"`
TTL string `json:"ttl"`
MaxTokens int `json:"max_tokens"`
}

type AuthenticatorOAuth2ClientCredentials struct {
c configuration.Provider
client *http.Client

tokenCache *ristretto.Cache
cacheTTL *time.Duration
logger *logrusx.Logger
}

type AuthenticatorOAuth2ClientCredentialsRetryConfiguration struct {
Timeout string `json:"max_delay"`
MaxWait string `json:"give_up_after"`
}

func NewAuthenticatorOAuth2ClientCredentials(c configuration.Provider) *AuthenticatorOAuth2ClientCredentials {
return &AuthenticatorOAuth2ClientCredentials{c: c}
func NewAuthenticatorOAuth2ClientCredentials(c configuration.Provider, logger *logrusx.Logger) *AuthenticatorOAuth2ClientCredentials {
return &AuthenticatorOAuth2ClientCredentials{c: c, logger: logger}
}

func (a *AuthenticatorOAuth2ClientCredentials) GetID() string {
Expand Down Expand Up @@ -86,9 +102,96 @@ func (a *AuthenticatorOAuth2ClientCredentials) Config(config json.RawMessage) (*
timeout := time.Millisecond * duration
a.client = httpx.NewResilientClientLatencyToleranceConfigurable(nil, timeout, maxWait)

if c.Cache.TTL != "" {
cacheTTL, err := time.ParseDuration(c.Cache.TTL)
if err != nil {
return nil, err
}
a.cacheTTL = &cacheTTL
}

if a.tokenCache == nil {
maxTokens := int64(c.Cache.MaxTokens)
if maxTokens == 0 {
maxTokens = 1000
}
a.logger.Debugf("Creating cache with max tokens: %d", maxTokens)
cache, err := ristretto.NewCache(&ristretto.Config{
// This will hold about 1000 unique mutation responses.
NumCounters: 10 * maxTokens,
// Allocate a maximum amount of tokens to cache
MaxCost: maxTokens,
// This is a best-practice value.
BufferItems: 64,
// Use a static cost of 1, so we can limit the amount of tokens that can be stored
Cost: func(value interface{}) int64 {
return 1
},
})
if err != nil {
return nil, err
}

a.tokenCache = cache
}

return &c, nil
}

func clientCredentialsConfigToKey(cc clientcredentials.Config) string {
return fmt.Sprintf("%s|%s|%s:%s", cc.TokenURL, strings.Join(cc.Scopes, " "), cc.ClientID, cc.ClientSecret)
}

func (a *AuthenticatorOAuth2ClientCredentials) tokenFromCache(config *AuthenticatorOAuth2Configuration, clientCredentials clientcredentials.Config) *oauth2.Token {
if !config.Cache.Enabled {
return nil
}

item, found := a.tokenCache.Get(clientCredentialsConfigToKey(clientCredentials))
if !found {
return nil
}

i, ok := item.([]byte)
if !ok {
return nil
}

var v oauth2.Token
if err := json.Unmarshal(i, &v); err != nil {
return nil
}
return &v
}

func (a *AuthenticatorOAuth2ClientCredentials) tokenToCache(config *AuthenticatorOAuth2Configuration, clientCredentials clientcredentials.Config, token oauth2.Token) {
if !config.Cache.Enabled {
return
}

key := clientCredentialsConfigToKey(clientCredentials)

if v, err := json.Marshal(token); err != nil {
return
} else if a.cacheTTL != nil {
// Allow up-to at most the cache TTL, otherwise use token expiry
ttl := token.Expiry.Sub(time.Now())
if ttl > *a.cacheTTL {
ttl = *a.cacheTTL
}

a.tokenCache.SetWithTTL(key, v, 1, ttl)
} else {
// If token has no expiry apply the same to the cache
ttl := time.Duration(0)
if !token.Expiry.IsZero() {
ttl = token.Expiry.Sub(time.Now())
}

a.tokenCache.SetWithTTL(key, v, 1, ttl)
}
}

func (a *AuthenticatorOAuth2ClientCredentials) Authenticate(r *http.Request, session *AuthenticationSession, config json.RawMessage, _ pipeline.Rule) error {
cf, err := a.Config(config)
if err != nil {
Expand All @@ -110,37 +213,45 @@ func (a *AuthenticatorOAuth2ClientCredentials) Authenticate(r *http.Request, ses
return errors.Wrapf(helper.ErrUnauthorized, err.Error())
}

c := &clientcredentials.Config{
c := clientcredentials.Config{
ClientID: user,
ClientSecret: password,
Scopes: cf.Scopes,
TokenURL: cf.TokenURL,
AuthStyle: oauth2.AuthStyleInHeader,
}

token, err := c.Token(context.WithValue(
r.Context(),
oauth2.HTTPClient,
c.Client,
))
token := a.tokenFromCache(cf, c)

if err != nil {
if rErr, ok := err.(*oauth2.RetrieveError); ok {
switch httpStatusCode := rErr.Response.StatusCode; httpStatusCode {
case http.StatusServiceUnavailable:
if token == nil {
t, err := c.Token(context.WithValue(
r.Context(),
oauth2.HTTPClient,
c.Client,
))

if err != nil {
if rErr, ok := err.(*oauth2.RetrieveError); ok {
switch httpStatusCode := rErr.Response.StatusCode; httpStatusCode {
case http.StatusServiceUnavailable:
return errors.Wrapf(helper.ErrUpstreamServiceNotAvailable, err.Error())
case http.StatusInternalServerError:
return errors.Wrapf(helper.ErrUpstreamServiceInternalServerError, err.Error())
case http.StatusGatewayTimeout:
return errors.Wrapf(helper.ErrUpstreamServiceTimeout, err.Error())
case http.StatusNotFound:
return errors.Wrapf(helper.ErrUpstreamServiceNotFound, err.Error())
default:
return errors.Wrapf(helper.ErrUnauthorized, err.Error())
}
} else {
return errors.Wrapf(helper.ErrUpstreamServiceNotAvailable, err.Error())
case http.StatusInternalServerError:
return errors.Wrapf(helper.ErrUpstreamServiceInternalServerError, err.Error())
case http.StatusGatewayTimeout:
return errors.Wrapf(helper.ErrUpstreamServiceTimeout, err.Error())
case http.StatusNotFound:
return errors.Wrapf(helper.ErrUpstreamServiceNotFound, err.Error())
default:
return errors.Wrapf(helper.ErrUnauthorized, err.Error())
}
} else {
return errors.Wrapf(helper.ErrUpstreamServiceNotAvailable, err.Error())
}

token = t

a.tokenToCache(cf, c, *token)
}

if token.AccessToken == "" {
Expand Down
Loading

0 comments on commit 9a56154

Please sign in to comment.