diff --git a/.schema/config.schema.json b/.schema/config.schema.json index f3385ebdba..8fb2dff910 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -861,6 +861,12 @@ "examples": [ "5s" ] + }, + "max_cost": { + "type": "integer", + "default": 1000, + "title": "Max Cost", + "description": "Max number of tokens to cache." } } } diff --git a/driver/configuration/provider_viper_public_test.go b/driver/configuration/provider_viper_public_test.go index 406a50bf91..fec78a19e8 100644 --- a/driver/configuration/provider_viper_public_test.go +++ b/driver/configuration/provider_viper_public_test.go @@ -52,7 +52,7 @@ func TestPipelineConfig(t *testing.T) { p := setup(t) require.NoError(t, p.PipelineConfig("authenticators", "oauth2_introspection", nil, &res)) - assert.JSONEq(t, `{"cache":{"enabled":false},"introspection_url":"https://override/path","pre_authorization":{"client_id":"some_id","client_secret":"some_secret","enabled":true,"audience":"some_audience","scope":["foo","bar"],"token_url":"https://my-website.com/oauth2/token"},"retry":{"max_delay":"100ms", "give_up_after":"1s"},"scope_strategy":"exact"}`, string(res), "%s", res) + assert.JSONEq(t, `{"cache":{"enabled":false, "max_cost":1000},"introspection_url":"https://override/path","pre_authorization":{"client_id":"some_id","client_secret":"some_secret","enabled":true,"audience":"some_audience","scope":["foo","bar"],"token_url":"https://my-website.com/oauth2/token"},"retry":{"max_delay":"100ms", "give_up_after":"1s"},"scope_strategy":"exact"}`, string(res), "%s", res) // Cleanup require.NoError(t, os.Setenv("AUTHENTICATORS_OAUTH2_INTROSPECTION_CONFIG_INTROSPECTION_URL", "")) @@ -296,7 +296,7 @@ func TestViperProvider(t *testing.T) { }) t.Run("authenticator=oauth2_introspection", func(t *testing.T) { - a := authn.NewAuthenticatorOAuth2Introspection(p) + a := authn.NewAuthenticatorOAuth2Introspection(p, logger) assert.True(t, p.AuthenticatorIsEnabled(a.GetID())) require.NoError(t, a.Validate(nil)) @@ -431,7 +431,7 @@ func TestAuthenticatorOAuth2TokenIntrospectionPreAuthorization(t *testing.T) { {enabled: true, id: "a", secret: "b", turl: "https://some-url", err: false}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - a := authn.NewAuthenticatorOAuth2Introspection(v) + a := authn.NewAuthenticatorOAuth2Introspection(v, logrusx.New("", "")) config, err := a.Config(json.RawMessage(fmt.Sprintf(`{ "pre_authorization": { diff --git a/driver/registry_memory.go b/driver/registry_memory.go index 88495be87f..0ccdada2e1 100644 --- a/driver/registry_memory.go +++ b/driver/registry_memory.go @@ -371,7 +371,7 @@ func (r *RegistryMemory) prepareAuthn() { authn.NewAuthenticatorJWT(r.c, r), authn.NewAuthenticatorNoOp(r.c), authn.NewAuthenticatorOAuth2ClientCredentials(r.c), - authn.NewAuthenticatorOAuth2Introspection(r.c), + authn.NewAuthenticatorOAuth2Introspection(r.c, r.Logger()), authn.NewAuthenticatorUnauthorized(r.c), } diff --git a/pipeline/authn/authenticator_oauth2_introspection.go b/pipeline/authn/authenticator_oauth2_introspection.go index 9e1d46da05..f4c8fe2c31 100644 --- a/pipeline/authn/authenticator_oauth2_introspection.go +++ b/pipeline/authn/authenticator_oauth2_introspection.go @@ -19,6 +19,7 @@ import ( "github.com/ory/go-convenience/stringslice" "github.com/ory/x/httpx" + "github.com/ory/x/logrusx" "github.com/ory/oathkeeper/driver/configuration" "github.com/ory/oathkeeper/helper" @@ -55,6 +56,7 @@ type AuthenticatorOAuth2IntrospectionRetryConfiguration struct { type cacheConfig struct { Enabled bool `json:"enabled"` TTL string `json:"ttl"` + MaxCost int `json:"max_cost"` } type AuthenticatorOAuth2Introspection struct { @@ -64,19 +66,12 @@ type AuthenticatorOAuth2Introspection struct { tokenCache *ristretto.Cache cacheTTL *time.Duration + logger *logrusx.Logger } -func NewAuthenticatorOAuth2Introspection(c configuration.Provider) *AuthenticatorOAuth2Introspection { +func NewAuthenticatorOAuth2Introspection(c configuration.Provider, logger *logrusx.Logger) *AuthenticatorOAuth2Introspection { var rt http.RoundTripper - cache, _ := ristretto.NewCache(&ristretto.Config{ - // This will hold about 1000 unique mutation responses. - NumCounters: 10000, - // Allocate a max of 32MB - MaxCost: 1 << 25, - // This is a best-practice value. - BufferItems: 64, - }) - return &AuthenticatorOAuth2Introspection{c: c, client: httpx.NewResilientClientLatencyToleranceSmall(rt), tokenCache: cache} + return &AuthenticatorOAuth2Introspection{c: c, client: httpx.NewResilientClientLatencyToleranceSmall(rt), logger: logger} } func (a *AuthenticatorOAuth2Introspection) GetID() string { @@ -123,9 +118,9 @@ func (a *AuthenticatorOAuth2Introspection) tokenToCache(config *AuthenticatorOAu } if a.cacheTTL != nil { - a.tokenCache.SetWithTTL(token, i, 0, *a.cacheTTL) + a.tokenCache.SetWithTTL(token, i, 1, *a.cacheTTL) } else { - a.tokenCache.Set(token, i, 0) + a.tokenCache.Set(token, i, 1) } } @@ -312,5 +307,19 @@ func (a *AuthenticatorOAuth2Introspection) Config(config json.RawMessage) (*Auth a.cacheTTL = &cacheTTL } + if a.tokenCache == nil { + a.logger.Debugf("Creating cache with max cost: %d", c.Cache.MaxCost) + cache, _ := ristretto.NewCache(&ristretto.Config{ + // This will hold about 1000 unique mutation responses. + NumCounters: 10000, + // Allocate a max + MaxCost: int64(c.Cache.MaxCost), + // This is a best-practice value. + BufferItems: 64, + }) + + a.tokenCache = cache + } + return &c, nil }