Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved JWT Authorizer JWKs fetching #726

Merged
merged 6 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .schema/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,26 @@
]
]
},
"jwks_max_wait": {
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwks_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "#/definitions/scopeStrategy"
},
Expand Down
20 changes: 20 additions & 0 deletions .schemas/authenticators.jwt.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@
"file://path/to/local/jwks.json"
]
},
"jwks_max_wait": {
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwks_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "https://raw.githubusercontent.com/ory/oathkeeper/master/.schemas/scope_strategy.schema.json#"
},
Expand Down
20 changes: 20 additions & 0 deletions .schemas/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,26 @@
]
]
},
"jwks_max_wait": {
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwks_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "#/definitions/scopeStrategy"
},
Expand Down
31 changes: 19 additions & 12 deletions credentials/fetcher_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ func NewFetcherDefault(l *logrusx.Logger, cancelAfter time.Duration, ttl time.Du
}

func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) ([]jose.JSONWebKeySet, error) {
if set := s.set(locations); set != nil {
if set := s.set(locations, false); set != nil {
return set, nil
}

s.fetchParallel(ctx, locations)
fetchError := s.fetchParallel(ctx, locations)

if set := s.set(locations); set != nil {
if set := s.set(locations, errors.Is(fetchError, context.DeadlineExceeded)); set != nil {
return set, nil
}

Expand All @@ -100,7 +100,7 @@ func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) (
)
}

func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) {
func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) error {
ctx, cancel := context.WithTimeout(ctx, s.cancelAfter)
defer cancel()
errs := make(chan error)
Expand All @@ -123,20 +123,22 @@ func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL)

select {
case <-ctx.Done():
s.l.Errorf("Ignoring JSON Web Keys from at least one URI because the request timed out waiting for a response.")
s.l.WithError(ctx.Err()).Errorf("Ignoring JSON Web Keys from at least one URI because the request timed out waiting for a response.")
return ctx.Err()
case <-done:
// We're done!
return nil
}
}

func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, kid string, use string) (*jose.JSONWebKey, error) {
if key := s.key(kid, locations, use); key != nil {
if key := s.key(kid, locations, use, false); key != nil {
return key, nil
}

s.fetchParallel(ctx, locations)
fetchError := s.fetchParallel(ctx, locations)

if key := s.key(kid, locations, use); key != nil {
if key := s.key(kid, locations, use, errors.Is(fetchError, context.DeadlineExceeded)); key != nil {
return key, nil
}

Expand All @@ -148,14 +150,14 @@ func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, ki
)
}

func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose.JSONWebKey {
func (s *FetcherDefault) key(kid string, locations []url.URL, use string, staleKeyAcceptable bool) *jose.JSONWebKey {
for _, l := range locations {
s.RLock()
keys, ok1 := s.keys[l.String()]
fetchedAt, ok2 := s.fetchedAt[l.String()]
s.RUnlock()

if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) {
if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) {
continue
}

Expand All @@ -169,15 +171,15 @@ func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose.
return nil
}

func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet {
func (s *FetcherDefault) set(locations []url.URL, staleKeyAcceptable bool) []jose.JSONWebKeySet {
var result []jose.JSONWebKeySet
for _, l := range locations {
s.RLock()
keys, ok1 := s.keys[l.String()]
fetchedAt, ok2 := s.fetchedAt[l.String()]
s.RUnlock()

if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) {
if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) {
continue
}

Expand All @@ -187,6 +189,11 @@ func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet {
return result
}

func (s *FetcherDefault) isKeyExpired(expiredKeyAcceptable bool, fetchedAt time.Time) bool {
return expiredKeyAcceptable == false &&
fetchedAt.Add(s.ttl).Before(time.Now().UTC())
}

func (s *FetcherDefault) resolveAll(done chan struct{}, errs chan error, locations []url.URL) {
var wg sync.WaitGroup

Expand Down
27 changes: 22 additions & 5 deletions credentials/fetcher_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,19 @@ var sets = [...]json.RawMessage{

func TestFetcherDefault(t *testing.T) {
const maxWait = time.Millisecond * 100
const JWKsTTL = maxWait * 7
const timeoutServerDelay = maxWait * 2

t.Cleanup(func() {
cloudstorage.SetCurrentTest(nil)
})

l := logrusx.New("", "", logrusx.ForceLevel(logrus.DebugLevel))
w := herodot.NewJSONWriter(l.Logger)
s := NewFetcherDefault(l, maxWait, maxWait*7)
s := NewFetcherDefault(l, maxWait, JWKsTTL)

timeOutServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
time.Sleep(maxWait * 2)
time.Sleep(timeoutServerDelay)
w.Write(rw, r, sets[0])
}))
defer timeOutServer.Close()
Expand Down Expand Up @@ -153,11 +156,25 @@ func TestFetcherDefault(t *testing.T) {
assert.True(t, check("8e884167-1300-4f58-8cc1-81af68f878a8"))
})

time.Sleep(
timeoutServerDelay +
JWKsTTL +
(time.Millisecond * 100)) // wait so the fetched key reaches ttl
// change "alg" for "c61308cc-faef-4b98-99c3-839f513ac296",
// so we are sure we get the "stale" data in `name=should find the previously fetched key if the refresh request times out`
sets[0] = json.RawMessage(`{"keys":[{"use":"sig","kty":"oct","kid":"c61308cc-faef-4b98-99c3-839f513ac296","k":"I2_YrZxll-Uq65GKjnJq4u7uNub8hG5cBvlHRz03w94","alg":"RS256"}]}`)

t.Run("name=should find the previously fetched key if the refresh request times out", func(t *testing.T) {
key, err := s.ResolveKey(context.Background(), uris, "c61308cc-faef-4b98-99c3-839f513ac296", "sig")
eroznik marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, err)
assert.Equal(t, "HS256", key.Algorithm)
})

t.Run("name=should fetch from s3 object storage", func(t *testing.T) {
ctx := context.Background()
cloudstorage.SetCurrentTest(t)

s := NewFetcherDefault(l, maxWait, maxWait*7)
s := NewFetcherDefault(l, maxWait, JWKsTTL)

key, err := s.ResolveKey(ctx, []url.URL{
*urlx.ParseOrPanic("s3://oathkeeper-test-bucket/path/prefix/jwks.json"),
Expand All @@ -170,7 +187,7 @@ func TestFetcherDefault(t *testing.T) {
ctx := context.Background()
cloudstorage.SetCurrentTest(t)

s := NewFetcherDefault(l, maxWait, maxWait*7)
s := NewFetcherDefault(l, maxWait, JWKsTTL)

key, err := s.ResolveKey(ctx, []url.URL{
*urlx.ParseOrPanic("gs://oathkeeper-test-bucket/path/prefix/jwks.json"),
Expand All @@ -183,7 +200,7 @@ func TestFetcherDefault(t *testing.T) {
ctx := context.Background()
cloudstorage.SetCurrentTest(t)

s := NewFetcherDefault(l, maxWait, maxWait*7)
s := NewFetcherDefault(l, maxWait, JWKsTTL)

jwkKey, err := s.ResolveKey(ctx, []url.URL{
*urlx.ParseOrPanic("azblob://path/prefix/jwks.json"),
Expand Down
6 changes: 6 additions & 0 deletions docs/docs/pipeline/authn.md
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,12 @@ JSON Web Token and tries to verify the signature of it.
JSON Web Keys from for validating the JSON Web Token. Usually something like
`https://my-keys.com/.well-known/jwks.json`. The response of that endpoint
must return a JSON Web Key Set (JWKS).
- `jwks_max_wait` (duration, optional) - The maximum time for which the JWK fetcher
should wait for the JWK request to complete. After the interval passes,
the JWK fetcher will return expired or no JWK at all. If the initial JWK request
finishes successfully, it will still refresh the cached JWKs. Defaults to "1s".
- `jwks_ttl` (duration, optional) - The duration for which fetched JWKs should be
cached internally. Defaults to "30s".
- `scope_strategy` (string, optional) - Sets the strategy to be used to
validate/match the scope. Supports "hierarchic", "exact", "wildcard", "none".
Defaults to "none".
Expand Down
20 changes: 20 additions & 0 deletions docs/docs/reference/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,26 @@ authenticators:
- https://my-other-website.com/.well-known/jwks.json
- file://path/to/local/jwks.json

## jwks_max_wait ##
#
# Set this value using environment variables on
# - Linux/macOS:
# $ export AUTHENTICATORS_JWT_CONFIG_JWKS_MAX_WAIT=<value>
# - Windows Command Line (CMD):
# > set AUTHENTICATORS_JWT_CONFIG_JWKS_MAX_WAIT=<value>
#
jwks_max_wait: 1s

## jwks_ttl ##
#
# Set this value using environment variables on
# - Linux/macOS:
# $ export AUTHENTICATORS_JWT_CONFIG_JWKS_TTL=<value>
# - Windows Command Line (CMD):
# > set AUTHENTICATORS_JWT_CONFIG_JWKS_TTL=<value>
#
jwks_ttl: 30s

## target_audience ##
#
# Set this value using environment variables on
Expand Down
3 changes: 2 additions & 1 deletion driver/configuration/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/gobuffalo/packr/v2"

"github.com/ory/fosite"
"github.com/ory/x/tracing"

Expand Down Expand Up @@ -73,6 +72,8 @@ type ProviderErrorHandlers interface {
type ProviderAuthenticators interface {
AuthenticatorConfig(id string, overrides json.RawMessage, destination interface{}) error
AuthenticatorIsEnabled(id string) bool
AuthenticatorJwtJwkMaxWait() time.Duration
AuthenticatorJwtJwkTtl() time.Duration
}

type ProviderAuthorizers interface {
Expand Down
12 changes: 11 additions & 1 deletion driver/configuration/provider_viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ const (
ViperKeyAuthenticatorCookieSessionIsEnabled = "authenticators.cookie_session.enabled"

// jwt
ViperKeyAuthenticatorJWTIsEnabled = "authenticators.jwt.enabled"
ViperKeyAuthenticatorJwtIsEnabled = "authenticators.jwt.enabled"
ViperKeyAuthenticatorJwtJwkMaxWait = "authenticators.jwt.config.jwks_max_wait"
ViperKeyAuthenticatorJwtJwkTtl = "authenticators.jwt.config.jwks_ttl"

// oauth2_client_credentials
ViperKeyAuthenticatorOAuth2ClientCredentialsIsEnabled = "authenticators.oauth2_client_credentials.enabled"
Expand Down Expand Up @@ -394,6 +396,14 @@ func (v *ViperProvider) AuthenticatorConfig(id string, override json.RawMessage,
return v.PipelineConfig("authenticators", id, override, dest)
}

func (v *ViperProvider) AuthenticatorJwtJwkMaxWait() time.Duration {
return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkMaxWait, time.Second)
}

func (v *ViperProvider) AuthenticatorJwtJwkTtl() time.Duration {
return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkTtl, time.Second*30)
}

func (v *ViperProvider) AuthorizerIsEnabled(id string) bool {
return v.pipelineIsEnabled("authorizers", id)
}
Expand Down
3 changes: 1 addition & 2 deletions driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package driver
import (
"context"
"sync"
"time"

"github.com/ory/oathkeeper/driver/health"
"github.com/ory/oathkeeper/pipeline"
Expand Down Expand Up @@ -208,7 +207,7 @@ func (r *RegistryMemory) DecisionHandler() *api.DecisionHandler {

func (r *RegistryMemory) CredentialsFetcher() credentials.Fetcher {
if r.credentialsFetcher == nil {
r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), time.Second, time.Second*30)
r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), r.c.AuthenticatorJwtJwkMaxWait(), r.c.AuthenticatorJwtJwkTtl())
}

return r.credentialsFetcher
Expand Down