Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved JWT Authorizer JWKs fetching #726

Merged
merged 6 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .schema/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,26 @@
]
]
},
"jwk_max_wait": {
eroznik marked this conversation as resolved.
Show resolved Hide resolved
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwk_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "#/definitions/scopeStrategy"
},
Expand Down
20 changes: 20 additions & 0 deletions .schemas/authenticators.jwt.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@
"file://path/to/local/jwks.json"
]
},
"jwk_max_wait": {
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwk_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "https://raw.githubusercontent.com/ory/oathkeeper/master/.schemas/scope_strategy.schema.json#"
},
Expand Down
20 changes: 20 additions & 0 deletions .schemas/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,26 @@
]
]
},
"jwk_max_wait": {
"title": "Max await interval for the JWK fetch",
"type": "string",
"description": "The configuration which sets the max wait threshold when fetching new JWKs",
"default" : "1s",
"examples": [
"100ms",
"1s"
]
},
"jwk_ttl": {
"title": "JWK cache TTL configuration",
"type": "string",
"description": "The time interval for which fetched JWKs are cached",
"default" : "30s",
"examples": [
"30m",
"6h"
]
},
"scope_strategy": {
"$ref": "#/definitions/scopeStrategy"
},
Expand Down
33 changes: 22 additions & 11 deletions credentials/fetcher_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ type reasoner interface {
Reason() string
}

type fetchResult struct {
finishedWithTimeout bool
}

var _ Fetcher = new(FetcherDefault)

type FetcherDefault struct {
Expand Down Expand Up @@ -84,13 +88,13 @@ func NewFetcherDefault(l *logrusx.Logger, cancelAfter time.Duration, ttl time.Du
}

func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) ([]jose.JSONWebKeySet, error) {
if set := s.set(locations); set != nil {
if set := s.set(locations, false); set != nil {
return set, nil
}

s.fetchParallel(ctx, locations)
fetchResult := s.fetchParallel(ctx, locations)

if set := s.set(locations); set != nil {
if set := s.set(locations, fetchResult.finishedWithTimeout); set != nil {
return set, nil
}

Expand All @@ -100,7 +104,7 @@ func (s *FetcherDefault) ResolveSets(ctx context.Context, locations []url.URL) (
)
}

func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) {
func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL) fetchResult {
ctx, cancel := context.WithTimeout(ctx, s.cancelAfter)
defer cancel()
errs := make(chan error)
Expand All @@ -124,19 +128,21 @@ func (s *FetcherDefault) fetchParallel(ctx context.Context, locations []url.URL)
select {
case <-ctx.Done():
s.l.Errorf("Ignoring JSON Web Keys from at least one URI because the request timed out waiting for a response.")
return fetchResult{true}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe return an error here instead? And then we check for errors in the caller code? That would be a bit more idiomatic go code :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the honest feedback, I really appreciate comments like this - I am not that proficient with Go yet :)
Changed to return error, is now ok - or should I even improve the implementation with a custom error?

case <-done:
// We're done!
return fetchResult{false}
}
}

func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, kid string, use string) (*jose.JSONWebKey, error) {
if key := s.key(kid, locations, use); key != nil {
if key := s.key(kid, locations, use, false); key != nil {
return key, nil
}

s.fetchParallel(ctx, locations)
fetchResult := s.fetchParallel(ctx, locations)

if key := s.key(kid, locations, use); key != nil {
if key := s.key(kid, locations, use, fetchResult.finishedWithTimeout); key != nil {
return key, nil
}

Expand All @@ -148,14 +154,14 @@ func (s *FetcherDefault) ResolveKey(ctx context.Context, locations []url.URL, ki
)
}

func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose.JSONWebKey {
func (s *FetcherDefault) key(kid string, locations []url.URL, use string, staleKeyAcceptable bool) *jose.JSONWebKey {
for _, l := range locations {
s.RLock()
keys, ok1 := s.keys[l.String()]
fetchedAt, ok2 := s.fetchedAt[l.String()]
s.RUnlock()

if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) {
if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) {
continue
}

Expand All @@ -169,15 +175,15 @@ func (s *FetcherDefault) key(kid string, locations []url.URL, use string) *jose.
return nil
}

func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet {
func (s *FetcherDefault) set(locations []url.URL, staleKeyAcceptable bool) []jose.JSONWebKeySet {
var result []jose.JSONWebKeySet
for _, l := range locations {
s.RLock()
keys, ok1 := s.keys[l.String()]
fetchedAt, ok2 := s.fetchedAt[l.String()]
s.RUnlock()

if !ok1 || !ok2 || fetchedAt.Add(s.ttl).Before(time.Now().UTC()) {
if !ok1 || !ok2 || s.isKeyExpired(staleKeyAcceptable, fetchedAt) {
continue
}

Expand All @@ -187,6 +193,11 @@ func (s *FetcherDefault) set(locations []url.URL) []jose.JSONWebKeySet {
return result
}

func (s *FetcherDefault) isKeyExpired(expiredKeyAcceptable bool, fetchedAt time.Time) bool {
return expiredKeyAcceptable == false &&
fetchedAt.Add(s.ttl).Before(time.Now().UTC())
}

func (s *FetcherDefault) resolveAll(done chan struct{}, errs chan error, locations []url.URL) {
var wg sync.WaitGroup

Expand Down
8 changes: 8 additions & 0 deletions credentials/fetcher_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ func TestFetcherDefault(t *testing.T) {
assert.True(t, check("8e884167-1300-4f58-8cc1-81af68f878a8"))
})

time.Sleep(maxWait * 7) // wait so the fetched key reaches ttl

t.Run("name=should find the previously fetched key if the refresh request times out", func(t *testing.T) {
key, err := s.ResolveKey(context.Background(), uris, "c61308cc-faef-4b98-99c3-839f513ac296", "sig")
eroznik marked this conversation as resolved.
Show resolved Hide resolved
require.NoError(t, err)
assert.Equal(t, "c61308cc-faef-4b98-99c3-839f513ac296", key.KeyID)
})

t.Run("name=should fetch from s3 object storage", func(t *testing.T) {
ctx := context.Background()
cloudstorage.SetCurrentTest(t)
Expand Down
6 changes: 6 additions & 0 deletions docs/docs/pipeline/authn.md
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,12 @@ JSON Web Token and tries to verify the signature of it.
JSON Web Keys from for validating the JSON Web Token. Usually something like
`https://my-keys.com/.well-known/jwks.json`. The response of that endpoint
must return a JSON Web Key Set (JWKS).
- `jwk_max_wait` (duration, optional) - The maximum time for which the JWK fetcher
should wait for the JWK request to complete. After the interval passes,
the JWK fetcher will return expired or no JWK at all. If the initial JWK request
finishes successfully, it will still refresh the cached JWKs. Defaults to "1s".
- `jwk_ttl` (duration, optional) - The duration for which fetched JWKs should be
cached internally. Defaults to "30s".
- `scope_strategy` (string, optional) - Sets the strategy to be used to
validate/match the scope. Supports "hierarchic", "exact", "wildcard", "none".
Defaults to "none".
Expand Down
20 changes: 20 additions & 0 deletions docs/docs/reference/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,26 @@ authenticators:
- https://my-other-website.com/.well-known/jwks.json
- file://path/to/local/jwks.json

## jwk_max_wait ##
#
# Set this value using environment variables on
# - Linux/macOS:
# $ export AUTHENTICATORS_JWT_CONFIG_JWK_MAX_WAIT=<value>
# - Windows Command Line (CMD):
# > set AUTHENTICATORS_JWT_CONFIG_JWK_MAX_WAIT=<value>
#
jwk_max_wait: 1s

## jwk_ttl ##
#
# Set this value using environment variables on
# - Linux/macOS:
# $ export AUTHENTICATORS_JWT_CONFIG_JWK_TTL=<value>
# - Windows Command Line (CMD):
# > set AUTHENTICATORS_JWT_CONFIG_JWK_TTL=<value>
#
jwk_ttl: 30s

## target_audience ##
#
# Set this value using environment variables on
Expand Down
3 changes: 2 additions & 1 deletion driver/configuration/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/gobuffalo/packr/v2"

"github.com/ory/fosite"
"github.com/ory/x/tracing"

Expand Down Expand Up @@ -73,6 +72,8 @@ type ProviderErrorHandlers interface {
type ProviderAuthenticators interface {
AuthenticatorConfig(id string, overrides json.RawMessage, destination interface{}) error
AuthenticatorIsEnabled(id string) bool
AuthenticatorJwtJwkMaxWait() time.Duration
AuthenticatorJwtJwkTtl() time.Duration
}

type ProviderAuthorizers interface {
Expand Down
12 changes: 11 additions & 1 deletion driver/configuration/provider_viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ const (
ViperKeyAuthenticatorCookieSessionIsEnabled = "authenticators.cookie_session.enabled"

// jwt
ViperKeyAuthenticatorJWTIsEnabled = "authenticators.jwt.enabled"
ViperKeyAuthenticatorJwtIsEnabled = "authenticators.jwt.enabled"
ViperKeyAuthenticatorJwtJwkMaxWait = "authenticators.jwt.config.jwk_max_wait"
ViperKeyAuthenticatorJwtJwkTtl = "authenticators.jwt.config.jwk_ttl"

// oauth2_client_credentials
ViperKeyAuthenticatorOAuth2ClientCredentialsIsEnabled = "authenticators.oauth2_client_credentials.enabled"
Expand Down Expand Up @@ -394,6 +396,14 @@ func (v *ViperProvider) AuthenticatorConfig(id string, override json.RawMessage,
return v.PipelineConfig("authenticators", id, override, dest)
}

func (v *ViperProvider) AuthenticatorJwtJwkMaxWait() time.Duration {
return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkMaxWait, time.Second)
}

func (v *ViperProvider) AuthenticatorJwtJwkTtl() time.Duration {
return viperx.GetDuration(v.l, ViperKeyAuthenticatorJwtJwkTtl, time.Second*30)
}

func (v *ViperProvider) AuthorizerIsEnabled(id string) bool {
return v.pipelineIsEnabled("authorizers", id)
}
Expand Down
3 changes: 1 addition & 2 deletions driver/registry_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package driver
import (
"context"
"sync"
"time"

"github.com/ory/oathkeeper/driver/health"
"github.com/ory/oathkeeper/pipeline"
Expand Down Expand Up @@ -208,7 +207,7 @@ func (r *RegistryMemory) DecisionHandler() *api.DecisionHandler {

func (r *RegistryMemory) CredentialsFetcher() credentials.Fetcher {
if r.credentialsFetcher == nil {
r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), time.Second, time.Second*30)
r.credentialsFetcher = credentials.NewFetcherDefault(r.Logger(), r.c.AuthenticatorJwtJwkMaxWait(), r.c.AuthenticatorJwtJwkTtl())
}

return r.credentialsFetcher
Expand Down