diff --git a/.schema/config.schema.json b/.schema/config.schema.json index e7b8b91f48..fa4338a4fc 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -574,6 +574,26 @@ ] ] }, + "jwks_max_wait": { + "title": "Max await interval for the JWK fetch", + "type": "string", + "description": "The configuration which sets the max wait threshold when fetching new JWKs", + "default" : "1s", + "examples": [ + "100ms", + "1s" + ] + }, + "jwks_ttl": { + "title": "JWK cache TTL configuration", + "type": "string", + "description": "The time interval for which fetched JWKs are cached", + "default" : "30s", + "examples": [ + "30m", + "6h" + ] + }, "scope_strategy": { "$ref": "#/definitions/scopeStrategy" }, diff --git a/.schemas/authenticators.jwt.schema.json b/.schemas/authenticators.jwt.schema.json index 0cd4965a00..ff8fe0d108 100644 --- a/.schemas/authenticators.jwt.schema.json +++ b/.schemas/authenticators.jwt.schema.json @@ -47,6 +47,26 @@ "file://path/to/local/jwks.json" ] }, + "jwks_max_wait": { + "title": "Max await interval for the JWK fetch", + "type": "string", + "description": "The configuration which sets the max wait threshold when fetching new JWKs", + "default" : "1s", + "examples": [ + "100ms", + "1s" + ] + }, + "jwks_ttl": { + "title": "JWK cache TTL configuration", + "type": "string", + "description": "The time interval for which fetched JWKs are cached", + "default" : "30s", + "examples": [ + "30m", + "6h" + ] + }, "scope_strategy": { "$ref": "https://raw.githubusercontent.com/ory/oathkeeper/master/.schemas/scope_strategy.schema.json#" }, diff --git a/.schemas/config.schema.json b/.schemas/config.schema.json index 87c03d5157..0d521b4986 100644 --- a/.schemas/config.schema.json +++ b/.schemas/config.schema.json @@ -439,6 +439,26 @@ ] ] }, + "jwks_max_wait": { + "title": "Max await interval for the JWK fetch", + "type": "string", + "description": "The configuration which sets the max wait threshold when fetching new JWKs", + "default" : "1s", + "examples": [ + "100ms", + "1s" + ] + }, + "jwks_ttl": { + "title": "JWK cache TTL configuration", + "type": "string", + "description": "The time interval for which fetched JWKs are cached", + "default" : "30s", + "examples": [ + "30m", + "6h" + ] + }, "scope_strategy": { "$ref": "#/definitions/scopeStrategy" }, diff --git a/credentials/fetcher_default.go b/credentials/fetcher_default.go index b075a1cc66..8dfb9a9099 100644 --- a/credentials/fetcher_default.go +++ b/credentials/fetcher_default.go @@ -84,13 +84,13 @@ func NewFetcherDefault(l *logrusx.Logger, cancelAfter time.Duration, ttl time.Du } func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) ([]jose.JSONWebKeySet, error) { - if set := s.set(locations); set != nil { + if set := s.set(locations, false); set != nil { return set, nil } - s.fetchParallel(ctx, locations) + fetchError := s.fetchParallel(ctx, locations) - if set := s.set(locations); set != nil { + if set := s.set(locations, errors.Is(fetchError, context.DeadlineExceeded)); set != nil { return set, nil } @@ -100,7 +100,7 @@ func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) ( ) } -func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) { +func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) error { ctx, cancel := context.WithTimeout(ctx, s.cancelAfter) defer cancel() errs := make(chan error) @@ -123,20 +123,22 @@ func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) select { case <-ctx.Done(): - s.l.Errorf("Ignoring JSON Web Keys from at least one URI because the request timed out waiting for a response.") + s.l.WithError(ctx.Err()).Errorf("Ignoring JSON Web Keys from at least one URI because the request timed out waiting for a response.") + return ctx.Err() case <-done: // We're done! + return nil } } func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, kid string, use string) (*jose.JSONWebKey, error) { - if key := s.key(kid, locations, use); key != nil { + if key := s.key(kid, locations, use, false); key != nil { return key, nil } - s.fetchParallel(ctx, locations) + fetchError := s.fetchParallel(ctx, locations) - if key := s.key(kid, locations, use); key != nil { + if key := s.key(kid, locations, use, errors.Is(fetchError, context.DeadlineExceeded)); key != nil { return key, nil } @@ -148,14 +150,14 @@ func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, ki ) } -func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose.JSONWebKey { +func (s *FetcherDefault) key(kid string, locations []url.URL, use string, staleKeyAcceptable bool) *jose.JSONWebKey { for _, l := range locations { s.RLock() keys, ok1 := s.keys[l.String()] fetchedAt, ok2 := s.fetchedAt[l.String()] s.RUnlock() - if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) { + if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) { continue } @@ -169,7 +171,7 @@ func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose. return nil } -func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet { +func (s *FetcherDefault) set(locations []url.URL, staleKeyAcceptable bool) []jose.JSONWebKeySet { var result []jose.JSONWebKeySet for _, l := range locations { s.RLock() @@ -177,7 +179,7 @@ func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet { fetchedAt, ok2 := s.fetchedAt[l.String()] s.RUnlock() - if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) { + if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) { continue } @@ -187,6 +189,11 @@ func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet { return result } +func (s *FetcherDefault) isKeyExpired(expiredKeyAcceptable bool, fetchedAt time.Time) bool { + return expiredKeyAcceptable == false && + fetchedAt.Add(s.ttl).Before(time.Now().UTC()) +} + func (s *FetcherDefault) resolveAll(done chan struct{}, errs chan error, locations []url.URL) { var wg sync.WaitGroup diff --git a/credentials/fetcher_default_test.go b/credentials/fetcher_default_test.go index cfd0cad521..abfc940d9a 100644 --- a/credentials/fetcher_default_test.go +++ b/credentials/fetcher_default_test.go @@ -31,16 +31,19 @@ var sets = [...]json.RawMessage{ func TestFetcherDefault(t *testing.T) { const maxWait = time.Millisecond * 100 + const JWKsTTL = maxWait * 7 + const timeoutServerDelay = maxWait * 2 + t.Cleanup(func() { cloudstorage.SetCurrentTest(nil) }) l := logrusx.New("", "", logrusx.ForceLevel(logrus.DebugLevel)) w := herodot.NewJSONWriter(l.Logger) - s := NewFetcherDefault(l, maxWait, maxWait*7) + s := NewFetcherDefault(l, maxWait, JWKsTTL) timeOutServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - time.Sleep(maxWait * 2) + time.Sleep(timeoutServerDelay) w.Write(rw, r, sets[0]) })) defer timeOutServer.Close() @@ -153,11 +156,25 @@ func TestFetcherDefault(t *testing.T) { assert.True(t, check("8e884167-1300-4f58-8cc1-81af68f878a8")) }) + time.Sleep( + timeoutServerDelay + + JWKsTTL + + (time.Millisecond * 100)) // wait so the fetched key reaches ttl + // change "alg" for "c61308cc-faef-4b98-99c3-839f513ac296", + // so we are sure we get the "stale" data in `name=should find the previously fetched key if the refresh request times out` + sets[0] = json.RawMessage(`{"keys":[{"use":"sig","kty":"oct","kid":"c61308cc-faef-4b98-99c3-839f513ac296","k":"I2_YrZxll-Uq65GKjnJq4u7uNub8hG5cBvlHRz03w94","alg":"RS256"}]}`) + + t.Run("name=should find the previously fetched key if the refresh request times out", func(t *testing.T) { + key, err := s.ResolveKey(context.Background(), uris, "c61308cc-faef-4b98-99c3-839f513ac296", "sig") + require.NoError(t, err) + assert.Equal(t, "HS256", key.Algorithm) + }) + t.Run("name=should fetch from s3 object storage", func(t *testing.T) { ctx := context.Background() cloudstorage.SetCurrentTest(t) - s := NewFetcherDefault(l, maxWait, maxWait*7) + s := NewFetcherDefault(l, maxWait, JWKsTTL) key, err := s.ResolveKey(ctx, []url.URL{ *urlx.ParseOrPanic("s3://oathkeeper-test-bucket/path/prefix/jwks.json"), @@ -170,7 +187,7 @@ func TestFetcherDefault(t *testing.T) { ctx := context.Background() cloudstorage.SetCurrentTest(t) - s := NewFetcherDefault(l, maxWait, maxWait*7) + s := NewFetcherDefault(l, maxWait, JWKsTTL) key, err := s.ResolveKey(ctx, []url.URL{ *urlx.ParseOrPanic("gs://oathkeeper-test-bucket/path/prefix/jwks.json"), @@ -183,7 +200,7 @@ func TestFetcherDefault(t *testing.T) { ctx := context.Background() cloudstorage.SetCurrentTest(t) - s := NewFetcherDefault(l, maxWait, maxWait*7) + s := NewFetcherDefault(l, maxWait, JWKsTTL) jwkKey, err := s.ResolveKey(ctx, []url.URL{ *urlx.ParseOrPanic("azblob://path/prefix/jwks.json"), diff --git a/docs/docs/pipeline/authn.md b/docs/docs/pipeline/authn.md index b85f67463a..30f2b9a333 100644 --- a/docs/docs/pipeline/authn.md +++ b/docs/docs/pipeline/authn.md @@ -781,6 +781,12 @@ JSON Web Token and tries to verify the signature of it. JSON Web Keys from for validating the JSON Web Token. Usually something like `https://my-keys.com/.well-known/jwks.json`. The response of that endpoint must return a JSON Web Key Set (JWKS). +- `jwks_max_wait` (duration, optional) - The maximum time for which the JWK fetcher + should wait for the JWK request to complete. After the interval passes, + the JWK fetcher will return expired or no JWK at all. If the initial JWK request + finishes successfully, it will still refresh the cached JWKs. Defaults to "1s". +- `jwks_ttl` (duration, optional) - The duration for which fetched JWKs should be + cached internally. Defaults to "30s". - `scope_strategy` (string, optional) - Sets the strategy to be used to validate/match the scope. Supports "hierarchic", "exact", "wildcard", "none". Defaults to "none". diff --git a/docs/docs/reference/configuration.md b/docs/docs/reference/configuration.md index 9e812f42da..393c8da233 100644 --- a/docs/docs/reference/configuration.md +++ b/docs/docs/reference/configuration.md @@ -304,6 +304,26 @@ authenticators: - https://my-other-website.com/.well-known/jwks.json - file://path/to/local/jwks.json + ## jwks_max_wait ## + # + # Set this value using environment variables on + # - Linux/macOS: + # $ export AUTHENTICATORS_JWT_CONFIG_JWKS_MAX_WAIT= + # - Windows Command Line (CMD): + # > set AUTHENTICATORS_JWT_CONFIG_JWKS_MAX_WAIT= + # + jwks_max_wait: 1s + + ## jwks_ttl ## + # + # Set this value using environment variables on + # - Linux/macOS: + # $ export AUTHENTICATORS_JWT_CONFIG_JWKS_TTL= + # - Windows Command Line (CMD): + # > set AUTHENTICATORS_JWT_CONFIG_JWKS_TTL= + # + jwks_ttl: 30s + ## target_audience ## # # Set this value using environment variables on diff --git a/driver/configuration/provider.go b/driver/configuration/provider.go index 7da85efa0b..7812f3b6c5 100644 --- a/driver/configuration/provider.go +++ b/driver/configuration/provider.go @@ -6,7 +6,6 @@ import ( "time" "github.com/gobuffalo/packr/v2" - "github.com/ory/fosite" "github.com/ory/x/tracing" @@ -73,6 +72,8 @@ type ProviderErrorHandlers interface { type ProviderAuthenticators interface { AuthenticatorConfig(id string, overrides json.RawMessage, destination interface{}) error AuthenticatorIsEnabled(id string) bool + AuthenticatorJwtJwkMaxWait() time.Duration + AuthenticatorJwtJwkTtl() time.Duration } type ProviderAuthorizers interface { diff --git a/driver/configuration/provider_viper.go b/driver/configuration/provider_viper.go index 1f61d5fa0e..4a68415825 100644 --- a/driver/configuration/provider_viper.go +++ b/driver/configuration/provider_viper.go @@ -95,7 +95,9 @@ const ( ViperKeyAuthenticatorCookieSessionIsEnabled = "authenticators.cookie_session.enabled" // jwt - ViperKeyAuthenticatorJWTIsEnabled = "authenticators.jwt.enabled" + ViperKeyAuthenticatorJwtIsEnabled = "authenticators.jwt.enabled" + ViperKeyAuthenticatorJwtJwkMaxWait = "authenticators.jwt.config.jwks_max_wait" + ViperKeyAuthenticatorJwtJwkTtl = "authenticators.jwt.config.jwks_ttl" // oauth2_client_credentials ViperKeyAuthenticatorOAuth2ClientCredentialsIsEnabled = "authenticators.oauth2_client_credentials.enabled" @@ -394,6 +396,14 @@ func (v *ViperProvider) AuthenticatorConfig(id string, override json.RawMessage, return v.PipelineConfig("authenticators", id, override, dest) } +func (v *ViperProvider) AuthenticatorJwtJwkMaxWait() time.Duration { + return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkMaxWait, time.Second) +} + +func (v *ViperProvider) AuthenticatorJwtJwkTtl() time.Duration { + return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkTtl, time.Second*30) +} + func (v *ViperProvider) AuthorizerIsEnabled(id string) bool { return v.pipelineIsEnabled("authorizers", id) } diff --git a/driver/registry_memory.go b/driver/registry_memory.go index 0ccdada2e1..62b1fe7e8d 100644 --- a/driver/registry_memory.go +++ b/driver/registry_memory.go @@ -3,7 +3,6 @@ package driver import ( "context" "sync" - "time" "github.com/ory/oathkeeper/driver/health" "github.com/ory/oathkeeper/pipeline" @@ -208,7 +207,7 @@ func (r *RegistryMemory) DecisionHandler() *api.DecisionHandler { func (r *RegistryMemory) CredentialsFetcher() credentials.Fetcher { if r.credentialsFetcher == nil { - r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), time.Second, time.Second*30) + r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), r.c.AuthenticatorJwtJwkMaxWait(), r.c.AuthenticatorJwtJwkTtl()) } return r.credentialsFetcher