diff --git a/README.md b/README.md index 8add9da07..9d0ec6d4b 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ This library considered and implemented: - [Proof Key for Code Exchange by OAuth Public Clients](https://tools.ietf.org/html/rfc7636) - [OAuth 2.0 for Native Apps](https://tools.ietf.org/html/rfc8252) - [OpenID Connect Core 1.0](https://openid.net/specs/openid-connect-core-1_0.html) +- [OAuth 2.0 Pushed Authorization Request](https://datatracker.ietf.org/doc/html/rfc9126) OAuth2 and OpenID Connect are difficult protocols. If you want quick wins, we strongly encourage you to look at [Hydra](https://github.com/ory-am/hydra). diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 481d2b347..b0984982f 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -47,7 +47,7 @@ func wrapSigningKeyFailure(outer *RFC6749Error, inner error) *RFC6749Error { return outer } -func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(ctx context.Context, request *AuthorizeRequest) error { +func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(ctx context.Context, request *AuthorizeRequest, isPARRequest bool) error { var scope Arguments = RemoveEmpty(strings.Split(request.Form.Get("scope"), " ")) // Even if a scope parameter is present in the Request Object value, a scope parameter MUST always be passed using @@ -155,6 +155,12 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(ctx context. } claims := token.Claims + // Reject the request if the "request_uri" authorization request + // parameter is provided. + if requestURI, _ := claims["request_uri"].(string); isPARRequest && requestURI != "" { + return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Pushed Authorization Requests can not contain the 'request_uri' parameter.")) + } + for k, v := range claims { request.Form.Set(k, fmt.Sprintf("%s", v)) } @@ -272,7 +278,57 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest return nil } +func (f *Fosite) authorizeRequestFromPAR(ctx context.Context, r *http.Request, request *AuthorizeRequest) (bool, error) { + configProvider, ok := f.Config.(PushedAuthorizeRequestConfigProvider) + if !ok { + // If the config provider is not implemented, PAR cannot be used. + return false, nil + } + + requestURI := r.Form.Get("request_uri") + if requestURI == "" || !strings.HasPrefix(requestURI, configProvider.GetPushedAuthorizeRequestURIPrefix(ctx)) { + // nothing to do here + return false, nil + } + + clientID := r.Form.Get("client_id") + + storage, ok := f.Store.(PARStorage) + if !ok { + return false, errorsx.WithStack(ErrServerError.WithHint(ErrorPARNotSupported).WithDebug(DebugPARStorageInvalid)) + } + + // hydrate the requester + var parRequest AuthorizeRequester + var err error + if parRequest, err = storage.GetPARSession(ctx, requestURI); err != nil { + return false, errorsx.WithStack(ErrInvalidRequestURI.WithHint("Invalid PAR session").WithWrap(err).WithDebug(err.Error())) + } + + // hydrate the request object + request.Merge(parRequest) + request.RedirectURI = parRequest.GetRedirectURI() + request.ResponseTypes = parRequest.GetResponseTypes() + request.State = parRequest.GetState() + request.ResponseMode = parRequest.GetResponseMode() + + if err := storage.DeletePARSession(ctx, requestURI); err != nil { + return false, errorsx.WithStack(ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + // validate the clients match + if clientID != request.GetClient().GetID() { + return false, errorsx.WithStack(ErrInvalidRequest.WithHint("The 'client_id' must match the one sent in the pushed authorization request.")) + } + + return true, nil +} + func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) { + return f.newAuthorizeRequest(ctx, r, false) +} + +func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPARRequest bool) (AuthorizeRequester, error) { request := NewAuthorizeRequest() request.Request.Lang = i18n.GetLangFromRequest(f.Config.GetMessageCatalog(ctx), r) @@ -287,6 +343,18 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth // Save state to the request to be returned in error conditions (https://github.com/ory/hydra/issues/1642) request.State = request.Form.Get("state") + // Check if this is a continuation from a pushed authorization request + if !isPARRequest { + if isPAR, err := f.authorizeRequestFromPAR(ctx, r, request); err != nil { + return request, err + } else if isPAR { + // No need to continue + return request, nil + } else if configProvider, ok := f.Config.(PushedAuthorizeRequestConfigProvider); ok && configProvider.EnforcePushedAuthorize(ctx) { + return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Pushed Authorization Requests are enforced but no such request was sent.")) + } + } + client, err := f.Store.GetClient(ctx, request.GetRequestForm().Get("client_id")) if err != nil { return request, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithWrap(err).WithDebug(err.Error())) @@ -298,7 +366,7 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth // // All other parse methods should come afterwards so that we ensure that the data is taken // from the request_object if set. - if err := f.authorizeRequestParametersFromOpenIDConnectRequest(ctx, request); err != nil { + if err := f.authorizeRequestParametersFromOpenIDConnectRequest(ctx, request, isPARRequest); err != nil { return request, err } diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index ad90e330e..9733a2a8d 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -213,7 +213,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { }, } - err := f.authorizeRequestParametersFromOpenIDConnectRequest(context.Background(), req) + err := f.authorizeRequestParametersFromOpenIDConnectRequest(context.Background(), req, false) if tc.expectErr != nil { require.EqualError(t, err, tc.expectErr.Error(), "%+v", err) if tc.expectErrReason != "" { diff --git a/compose/compose.go b/compose/compose.go index dca99dcdf..683564537 100644 --- a/compose/compose.go +++ b/compose/compose.go @@ -54,7 +54,6 @@ type Factory func(config fosite.Configurator, storage interface{}, strategy inte // Compose makes use of interface{} types in order to be able to handle a all types of stores, strategies and handlers. func Compose(config *fosite.Config, storage interface{}, strategy interface{}, factories ...Factory) fosite.OAuth2Provider { f := fosite.NewOAuth2Provider(storage.(fosite.Storage), config) - for _, factory := range factories { res := factory(config, storage, strategy) if ah, ok := res.(fosite.AuthorizeEndpointHandler); ok { @@ -69,6 +68,9 @@ func Compose(config *fosite.Config, storage interface{}, strategy interface{}, f if rh, ok := res.(fosite.RevocationHandler); ok { config.RevocationHandlers.Append(rh) } + if ph, ok := res.(fosite.PushedAuthorizeEndpointHandler); ok { + config.PushedAuthorizeEndpointHandlers.Append(ph) + } } return f @@ -103,5 +105,6 @@ func ComposeAllEnabled(config *fosite.Config, storage interface{}, key interface OAuth2TokenRevocationFactory, OAuth2PKCEFactory, + PushedAuthorizeHandlerFactory, ) } diff --git a/compose/compose_par.go b/compose/compose_par.go new file mode 100644 index 000000000..7e36a3a2a --- /dev/null +++ b/compose/compose_par.go @@ -0,0 +1,14 @@ +package compose + +import ( + "github.com/ory/fosite" + "github.com/ory/fosite/handler/par" +) + +// PushedAuthorizeHandlerFactory creates the basic PAR handler +func PushedAuthorizeHandlerFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { + return &par.PushedAuthorizeHandler{ + Storage: storage, + Config: config, + } +} diff --git a/config.go b/config.go index 3c5b1196e..1edc5bc5f 100644 --- a/config.go +++ b/config.go @@ -266,6 +266,12 @@ type RevocationHandlersProvider interface { GetRevocationHandlers(ctx context.Context) RevocationHandlers } +// PushedAuthorizeEndpointHandlersProvider returns the provider for configuring the PAR handlers. +type PushedAuthorizeRequestHandlersProvider interface { + // GetPushedAuthorizeEndpointHandlers returns the handlers. + GetPushedAuthorizeEndpointHandlers(ctx context.Context) PushedAuthorizeEndpointHandlers +} + // UseLegacyErrorFormatProvider returns the provider for configuring whether to use the legacy error format. // // DEPRECATED: Do not use this flag anymore. @@ -275,3 +281,19 @@ type UseLegacyErrorFormatProvider interface { // DEPRECATED: Do not use this flag anymore. GetUseLegacyErrorFormat(ctx context.Context) bool } + +// PushedAuthorizeRequestConfigProvider is the configuration provider for pushed +// authorization request. +type PushedAuthorizeRequestConfigProvider interface { + // GetPushedAuthorizeRequestURIPrefix is the request URI prefix. This is + // usually 'urn:ietf:params:oauth:request_uri:'. + GetPushedAuthorizeRequestURIPrefix(ctx context.Context) string + + // GetPushedAuthorizeContextLifespan is the lifespan of the short-lived PAR context. + GetPushedAuthorizeContextLifespan(ctx context.Context) time.Duration + + // EnforcePushedAuthorize indicates if PAR is enforced. In this mode, a client + // cannot pass authorize parameters at the 'authorize' endpoint. The 'authorize' endpoint + // must contain the PAR request_uri. + EnforcePushedAuthorize(ctx context.Context) bool +} diff --git a/config_default.go b/config_default.go index 5dfdcd3ce..9ff7dc941 100644 --- a/config_default.go +++ b/config_default.go @@ -35,6 +35,11 @@ import ( "github.com/ory/fosite/i18n" ) +const ( + defaultPARPrefix = "urn:ietf:params:oauth:request_uri:" + defaultPARContextLifetime = 5 * time.Minute +) + var ( _ AuthorizeCodeLifespanProvider = (*Config)(nil) _ RefreshTokenLifespanProvider = (*Config)(nil) @@ -73,6 +78,8 @@ var ( _ TokenEndpointHandlersProvider = (*Config)(nil) _ TokenIntrospectionHandlersProvider = (*Config)(nil) _ RevocationHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) ) type Config struct { @@ -202,6 +209,9 @@ type Config struct { // RevocationHandlers is a list of handlers that are called before the revocation endpoint is served. RevocationHandlers RevocationHandlers + // PushedAuthorizeEndpointHandlers is a list of handlers that are called before the PAR endpoint is served. + PushedAuthorizeEndpointHandlers PushedAuthorizeEndpointHandlers + // GlobalSecret is the global secret used to sign and verify signatures. GlobalSecret []byte @@ -210,6 +220,16 @@ type Config struct { // HMACHasher is the hasher used to generate HMAC signatures. HMACHasher func() hash.Hash + + // PushedAuthorizeRequestURIPrefix is the URI prefix for the PAR request_uri. + // This is defaulted to 'urn:ietf:params:oauth:request_uri:'. + PushedAuthorizeRequestURIPrefix string + + // PushedAuthorizeContextLifespan is the lifespan of the PAR context + PushedAuthorizeContextLifespan time.Duration + + // IsPushedAuthorizeEnforced enforces pushed authorization request for /authorize + IsPushedAuthorizeEnforced bool } func (c *Config) GetGlobalSecret(ctx context.Context) []byte { @@ -455,3 +475,34 @@ func (c *Config) GetClientAuthenticationStrategy(_ context.Context) ClientAuthen func (c *Config) GetDisableRefreshTokenValidation(_ context.Context) bool { return c.DisableRefreshTokenValidation } + +// GetPushedAuthorizeEndpointHandlers returns the handlers. +func (c *Config) GetPushedAuthorizeEndpointHandlers(ctx context.Context) PushedAuthorizeEndpointHandlers { + return c.PushedAuthorizeEndpointHandlers +} + +// GetPushedAuthorizeRequestURIPrefix is the request URI prefix. This is +// usually 'urn:ietf:params:oauth:request_uri:'. +func (c *Config) GetPushedAuthorizeRequestURIPrefix(ctx context.Context) string { + if c.PushedAuthorizeRequestURIPrefix == "" { + return defaultPARPrefix + } + + return c.PushedAuthorizeRequestURIPrefix +} + +// GetPushedAuthorizeContextLifespan is the lifespan of the short-lived PAR context. +func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Duration { + if c.PushedAuthorizeContextLifespan <= 0 { + return defaultPARContextLifetime + } + + return c.PushedAuthorizeContextLifespan +} + +// EnforcePushedAuthorize indicates if PAR is enforced. In this mode, a client +// cannot pass authorize parameters at the 'authorize' endpoint. The 'authorize' endpoint +// must contain the PAR request_uri. +func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool { + return c.IsPushedAuthorizeEnforced +} diff --git a/context.go b/context.go index 48558e9a3..df877cfed 100644 --- a/context.go +++ b/context.go @@ -35,4 +35,6 @@ const ( AccessResponseContextKey = ContextKey("accessResponse") AuthorizeRequestContextKey = ContextKey("authorizeRequest") AuthorizeResponseContextKey = ContextKey("authorizeResponse") + // PushedAuthorizeResponseContextKey is the response context + PushedAuthorizeResponseContextKey = ContextKey("pushedAuthorizeResponse") ) diff --git a/fosite.go b/fosite.go index afb4a3d63..285a6be3f 100644 --- a/fosite.go +++ b/fosite.go @@ -86,6 +86,20 @@ func (t *RevocationHandlers) Append(h RevocationHandler) { *t = append(*t, h) } +// PushedAuthorizeEndpointHandlers is a list of PushedAuthorizeEndpointHandler +type PushedAuthorizeEndpointHandlers []PushedAuthorizeEndpointHandler + +// Append adds an AuthorizeEndpointHandler to this list. Ignores duplicates based on reflect.TypeOf. +func (a *PushedAuthorizeEndpointHandlers) Append(h PushedAuthorizeEndpointHandler) { + for _, this := range *a { + if reflect.TypeOf(this) == reflect.TypeOf(h) { + return + } + } + + *a = append(*a, h) +} + var _ OAuth2Provider = (*Fosite)(nil) type Configurator interface { diff --git a/fosite_test.go b/fosite_test.go index 0e0e2504c..0c21ce37a 100644 --- a/fosite_test.go +++ b/fosite_test.go @@ -30,6 +30,7 @@ import ( . "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/par" ) func TestAuthorizeEndpointHandlers(t *testing.T) { @@ -65,6 +66,15 @@ func TestAuthorizedRequestValidators(t *testing.T) { assert.Equal(t, hs[0], h) } +func TestPushedAuthorizedRequestHandlers(t *testing.T) { + h := &par.PushedAuthorizeHandler{} + hs := PushedAuthorizeEndpointHandlers{} + hs.Append(h) + hs.Append(h) + require.Len(t, hs, 1) + assert.Equal(t, hs[0], h) +} + func TestMinParameterEntropy(t *testing.T) { f := Fosite{Config: new(Config)} assert.Equal(t, MinParameterEntropy, f.GetMinParameterEntropy(context.Background())) diff --git a/handler.go b/handler.go index 9ee834cea..f1319a610 100644 --- a/handler.go +++ b/handler.go @@ -76,3 +76,11 @@ type RevocationHandler interface { // RevokeToken handles access and refresh token revocation. RevokeToken(ctx context.Context, token string, tokenType TokenType, client Client) error } + +// PushedAuthorizeEndpointHandler is the interface that handles PAR (https://datatracker.ietf.org/doc/html/rfc9126) +type PushedAuthorizeEndpointHandler interface { + // HandlePushedAuthorizeRequest handles a pushed authorize endpoint request. To extend the handler's capabilities, the http request + // is passed along, if further information retrieval is required. If the handler feels that he is not responsible for + // the pushed authorize request, he must return nil and NOT modify session nor responder neither requester. + HandlePushedAuthorizeEndpointRequest(ctx context.Context, requester AuthorizeRequester, responder PushedAuthorizeResponder) error +} diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index 7e849eef8..1abdefd6d 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -77,7 +77,7 @@ func (c *AuthorizeExplicitGrantHandler) HandleAuthorizeEndpointRequest(ctx conte // } if !c.secureChecker(ctx)(ctx, ar.GetRedirectURI()) { - return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Redirect URL is using an insecure protocol, http is only allowed for hosts with suffix `localhost`, for example: http://myapp.localhost/.")) + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Redirect URL is using an insecure protocol, http is only allowed for hosts with suffix 'localhost', for example: http://myapp.localhost/.")) } client := ar.GetClient() diff --git a/handler/par/flow_pushed_authorize.go b/handler/par/flow_pushed_authorize.go new file mode 100644 index 000000000..4cd67c409 --- /dev/null +++ b/handler/par/flow_pushed_authorize.go @@ -0,0 +1,89 @@ +package par + +import ( + "context" + "encoding/base64" + "fmt" + "net/url" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/token/hmac" + "github.com/ory/x/errorsx" +) + +const ( + defaultPARKeyLength = 32 +) + +var b64 = base64.URLEncoding.WithPadding(base64.NoPadding) + +// PushedAuthorizeHandler handles the PAR request +type PushedAuthorizeHandler struct { + Storage interface{} + Config fosite.Configurator +} + +// HandlePushedAuthorizeEndpointRequest handles a pushed authorize endpoint request. To extend the handler's capabilities, the http request +// is passed along, if further information retrieval is required. If the handler feels that he is not responsible for +// the pushed authorize request, he must return nil and NOT modify session nor responder neither requester. +func (c *PushedAuthorizeHandler) HandlePushedAuthorizeEndpointRequest(ctx context.Context, ar fosite.AuthorizeRequester, resp fosite.PushedAuthorizeResponder) error { + configProvider, ok := c.Config.(fosite.PushedAuthorizeRequestConfigProvider) + if !ok { + return errorsx.WithStack(fosite.ErrServerError.WithHint(fosite.ErrorPARNotSupported).WithDebug(fosite.DebugPARConfigMissing)) + } + + storage, ok := c.Storage.(fosite.PARStorage) + if !ok { + return errorsx.WithStack(fosite.ErrServerError.WithHint(fosite.ErrorPARNotSupported).WithDebug(fosite.DebugPARStorageInvalid)) + } + + if !ar.GetResponseTypes().HasOneOf("token", "code", "id_token") { + return nil + } + + if !c.secureChecker(ctx, ar.GetRedirectURI()) { + return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Redirect URL is using an insecure protocol, http is only allowed for hosts with suffix 'localhost', for example: http://myapp.localhost/.")) + } + + client := ar.GetClient() + for _, scope := range ar.GetRequestedScopes() { + if !c.Config.GetScopeStrategy(ctx)(client.GetScopes(), scope) { + return errorsx.WithStack(fosite.ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) + } + } + + if err := c.Config.GetAudienceStrategy(ctx)(client.GetAudience(), ar.GetRequestedAudience()); err != nil { + return err + } + + expiresIn := configProvider.GetPushedAuthorizeContextLifespan(ctx) + if ar.GetSession() != nil { + ar.GetSession().SetExpiresAt(fosite.PushedAuthorizeRequestContext, time.Now().UTC().Add(expiresIn)) + } + + // generate an ID + stateKey, err := hmac.RandomBytes(defaultPARKeyLength) + if err != nil { + return errorsx.WithStack(fosite.ErrInsufficientEntropy.WithHint("Unable to generate the random part of the request_uri.").WithWrap(err).WithDebug(err.Error())) + } + + requestURI := fmt.Sprintf("%s%s", configProvider.GetPushedAuthorizeRequestURIPrefix(ctx), b64.EncodeToString(stateKey)) + + // store + if err = storage.CreatePARSession(ctx, requestURI, ar); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithHint("Unable to store the PAR session").WithWrap(err).WithDebug(err.Error())) + } + + resp.SetRequestURI(requestURI) + resp.SetExpiresIn(int(expiresIn.Seconds())) + return nil +} + +func (c *PushedAuthorizeHandler) secureChecker(ctx context.Context, u *url.URL) bool { + isRedirectURISecure := c.Config.GetRedirectSecureChecker(ctx) + if isRedirectURISecure == nil { + isRedirectURISecure = fosite.IsRedirectURISecure + } + return isRedirectURISecure(ctx, u) +} diff --git a/handler/par/flow_pushed_authorize_test.go b/handler/par/flow_pushed_authorize_test.go new file mode 100644 index 000000000..c9853b4b5 --- /dev/null +++ b/handler/par/flow_pushed_authorize_test.go @@ -0,0 +1,133 @@ +package par_test + +import ( + "context" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/fosite/storage" + + "github.com/ory/fosite" + . "github.com/ory/fosite/handler/par" +) + +func parseURL(uu string) *url.URL { + u, _ := url.Parse(uu) + return u +} + +func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) { + requestURIPrefix := "urn:ietf:params:oauth:request_uri_diff:" + store := storage.NewMemoryStore() + handler := PushedAuthorizeHandler{ + Storage: store, + Config: &fosite.Config{ + PushedAuthorizeContextLifespan: 30 * time.Minute, + PushedAuthorizeRequestURIPrefix: requestURIPrefix, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + }, + } + for _, c := range []struct { + handler PushedAuthorizeHandler + areq *fosite.AuthorizeRequest + description string + expectErr error + expect func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.PushedAuthorizeResponse) + }{ + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{""}, + Request: *fosite.NewRequest(), + }, + description: "should pass because not responsible for handling an empty response type", + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"foo"}, + Request: *fosite.NewRequest(), + }, + description: "should pass because not responsible for handling an invalid response type", + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ResponseTypes: fosite.Arguments{"code"}, + RedirectURIs: []string{"http://asdf.com/cb"}, + }, + }, + RedirectURI: parseURL("http://asdf.com/cb"), + }, + description: "should fail because redirect uri is not https", + expectErr: fosite.ErrInvalidRequest, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ResponseTypes: fosite.Arguments{"code"}, + RedirectURIs: []string{"https://asdf.com/cb"}, + Audience: []string{"https://www.ory.sh/api"}, + }, + RequestedAudience: []string{"https://www.ory.sh/not-api"}, + }, + RedirectURI: parseURL("https://asdf.com/cb"), + }, + description: "should fail because audience doesn't match", + expectErr: fosite.ErrInvalidRequest, + }, + { + handler: handler, + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ResponseTypes: fosite.Arguments{"code"}, + RedirectURIs: []string{"https://asdf.de/cb"}, + Audience: []string{"https://www.ory.sh/api"}, + }, + RequestedAudience: []string{"https://www.ory.sh/api"}, + GrantedScope: fosite.Arguments{"a", "b"}, + Session: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{fosite.AccessToken: time.Now().UTC().Add(time.Hour)}, + }, + RequestedAt: time.Now().UTC(), + }, + State: "superstate", + RedirectURI: parseURL("https://asdf.de/cb"), + }, + description: "should pass", + expect: func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.PushedAuthorizeResponse) { + requestURI := aresp.RequestURI + assert.NotEmpty(t, requestURI) + assert.True(t, strings.HasPrefix(requestURI, requestURIPrefix), "requestURI does not match: %s", requestURI) + }, + }, + } { + t.Run("case="+c.description, func(t *testing.T) { + aresp := &fosite.PushedAuthorizeResponse{} + err := c.handler.HandlePushedAuthorizeEndpointRequest(context.Background(), c.areq, aresp) + if c.expectErr != nil { + require.EqualError(t, err, c.expectErr.Error()) + } else { + require.NoError(t, err) + } + + if c.expect != nil { + c.expect(t, c.areq, aresp) + } + }) + } +} diff --git a/integration/helper_endpoints_test.go b/integration/helper_endpoints_test.go index d8afb359d..fcd970cff 100644 --- a/integration/helper_endpoints_test.go +++ b/integration/helper_endpoints_test.go @@ -173,3 +173,27 @@ func tokenEndpointHandler(t *testing.T, provider fosite.OAuth2Provider) func(rw provider.WriteAccessResponse(req.Context(), rw, accessRequest, response) } } + +func pushedAuthorizeRequestHandler(t *testing.T, oauth2 fosite.OAuth2Provider, session fosite.Session) func(rw http.ResponseWriter, req *http.Request) { + return func(rw http.ResponseWriter, req *http.Request) { + ctx := fosite.NewContext() + + ar, err := oauth2.NewPushedAuthorizeRequest(ctx, req) + if err != nil { + t.Logf("PAR request failed because: %+v", err) + t.Logf("Request: %+v", ar) + oauth2.WritePushedAuthorizeError(ctx, rw, ar, err) + return + } + + response, err := oauth2.NewPushedAuthorizeResponse(ctx, ar, session) + if err != nil { + t.Logf("PAR response failed because: %+v", err) + t.Logf("Request: %+v", ar) + oauth2.WritePushedAuthorizeError(ctx, rw, ar, err) + return + } + + oauth2.WritePushedAuthorizeResponse(ctx, rw, ar, response) + } +} diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index 186e22efb..e1a325cd5 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -117,6 +117,7 @@ var fositeStore = &storage.MemoryStore{ IDSessions: map[string]fosite.Requester{}, AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, + PARSessions: map[string]fosite.AuthorizeRequester{}, } type defaultSession struct { @@ -207,6 +208,7 @@ func mockServer(t *testing.T, f fosite.OAuth2Provider, session fosite.Session) * router.HandleFunc("/info", tokenInfoHandler(t, f, session)) router.HandleFunc("/introspect", tokenIntrospectionHandler(t, f, session)) router.HandleFunc("/revoke", tokenRevocationHandler(t, f, session)) + router.HandleFunc("/par", pushedAuthorizeRequestHandler(t, f, session)) ts := httptest.NewServer(router) return ts diff --git a/integration/pushed_authorize_code_grant_test.go b/integration/pushed_authorize_code_grant_test.go new file mode 100644 index 000000000..a9f5a5de2 --- /dev/null +++ b/integration/pushed_authorize_code_grant_test.go @@ -0,0 +1,209 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + goauth "golang.org/x/oauth2" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" +) + +func TestPushedAuthorizeCodeFlow(t *testing.T) { + for _, strategy := range []oauth2.AccessTokenStrategy{ + hmacStrategy, + } { + runPushedAuthorizeCodeGrantTest(t, strategy) + } +} + +func runPushedAuthorizeCodeGrantTest(t *testing.T, strategy interface{}) { + f := compose.Compose(new(fosite.Config), fositeStore, strategy, compose.OAuth2AuthorizeExplicitFactory, compose.OAuth2TokenIntrospectionFactory, compose.PushedAuthorizeHandlerFactory) + ts := mockServer(t, f, &fosite.DefaultSession{Subject: "foo-sub"}) + defer ts.Close() + + oauthClient := newOAuth2Client(ts) + fositeStore.Clients["my-client"].(*fosite.DefaultClient).RedirectURIs[0] = ts.URL + "/callback" + + var state string + for k, c := range []struct { + description string + setup func() + check func(t *testing.T, r *http.Response) + params map[string]string + authStatusCode int + parStatusCode int + }{ + { + description: "should fail because of audience", + params: map[string]string{"audience": "https://www.ory.sh/not-api"}, + setup: func() { + oauthClient = newOAuth2Client(ts) + state = "12345678901234567890" + }, + parStatusCode: http.StatusBadRequest, + authStatusCode: http.StatusNotAcceptable, + }, + { + description: "should fail because of scope", + params: nil, + setup: func() { + oauthClient = newOAuth2Client(ts) + oauthClient.Scopes = []string{"not-exist"} + state = "12345678901234567890" + }, + parStatusCode: http.StatusBadRequest, + authStatusCode: http.StatusNotAcceptable, + }, + { + description: "should pass with proper audience", + params: map[string]string{"audience": "https://www.ory.sh/api"}, + setup: func() { + oauthClient = newOAuth2Client(ts) + state = "12345678901234567890" + }, + check: func(t *testing.T, r *http.Response) { + var b fosite.AccessRequest + b.Client = new(fosite.DefaultClient) + b.Session = new(defaultSession) + require.NoError(t, json.NewDecoder(r.Body).Decode(&b)) + assert.EqualValues(t, fosite.Arguments{"https://www.ory.sh/api"}, b.RequestedAudience) + assert.EqualValues(t, fosite.Arguments{"https://www.ory.sh/api"}, b.GrantedAudience) + assert.EqualValues(t, "foo-sub", b.Session.(*defaultSession).Subject) + }, + parStatusCode: http.StatusCreated, + authStatusCode: http.StatusOK, + }, + { + description: "should pass", + setup: func() { + oauthClient = newOAuth2Client(ts) + state = "12345678901234567890" + }, + parStatusCode: http.StatusCreated, + authStatusCode: http.StatusOK, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { + c.setup() + + // build request from the OAuth client + data := url.Values{} + data.Set("client_id", oauthClient.ClientID) + data.Set("client_secret", oauthClient.ClientSecret) + data.Set("response_type", "code") + data.Set("state", state) + data.Set("scope", strings.Join(oauthClient.Scopes, " ")) + data.Set("redirect_uri", oauthClient.RedirectURL) + for k, v := range c.params { + data.Set(k, v) + } + + req, err := http.NewRequest("POST", ts.URL+"/par", strings.NewReader(data.Encode())) + require.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + resp, err := http.DefaultClient.Do(req) + + require.NoError(t, err) + + body, err := checkStatusAndGetBody(t, resp, c.parStatusCode) + require.NoError(t, err, "Unable to get body after PAR. Err=%v", err) + + if resp.StatusCode != http.StatusCreated { + return + } + + m := map[string]interface{}{} + err = json.Unmarshal(body, &m) + + assert.NoError(t, err, "Error occurred when unamrshaling the body: %v", err) + + // validate request_uri + requestURI, _ := m["request_uri"].(string) + assert.NotEmpty(t, requestURI, "request_uri is empty") + assert.Condition(t, func() bool { + return strings.HasPrefix(requestURI, "urn:ietf:params:oauth:request_uri:") + }, "PAR Prefix is incorrect: %s", requestURI) + + // validate expires_in + assert.EqualValues(t, 300, int(m["expires_in"].(float64)), "Invalid expires_in value=%v", m["expires_in"]) + + // call authorize + data = url.Values{} + data.Set("client_id", oauthClient.ClientID) + data.Set("request_uri", m["request_uri"].(string)) + req, err = http.NewRequest("POST", ts.URL+"/auth", strings.NewReader(data.Encode())) + require.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err = http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, c.authStatusCode, resp.StatusCode) + if resp.StatusCode != http.StatusOK { + return + } + + require.NotEmpty(t, resp.Request.URL.Query().Get("code"), "Auth code is empty") + + token, err := oauthClient.Exchange(goauth.NoContext, resp.Request.URL.Query().Get("code")) + require.NoError(t, err) + require.NotEmpty(t, token.AccessToken) + + httpClient := oauthClient.Client(goauth.NoContext, token) + resp, err = httpClient.Get(ts.URL + "/info") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + if c.check != nil { + c.check(t, resp) + } + }) + } +} + +func checkStatusAndGetBody(t *testing.T, resp *http.Response, expectedStatusCode int) ([]byte, error) { + defer resp.Body.Close() + + require.Equal(t, expectedStatusCode, resp.StatusCode) + b, err := ioutil.ReadAll(resp.Body) + if err == nil { + fmt.Printf("PAR response: body=%s\n", string(b)) + } + if expectedStatusCode != resp.StatusCode { + return nil, fmt.Errorf("Invalid status code %d", resp.StatusCode) + } + + return b, err +} diff --git a/internal/pushed_authorize_handler.go b/internal/pushed_authorize_handler.go new file mode 100644 index 000000000..77b77a143 --- /dev/null +++ b/internal/pushed_authorize_handler.go @@ -0,0 +1,47 @@ +package internal + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + + fosite "github.com/ory/fosite" +) + +// MockPushedAuthorizeEndpointHandler is a mock of PushedAuthorizeEndpointHandler interface +type MockPushedAuthorizeEndpointHandler struct { + ctrl *gomock.Controller + recorder *MockPushedAuthorizeEndpointHandlerMockRecorder +} + +// MockPushedAuthorizeEndpointHandlerMockRecorder is the mock recorder for PushedMockAuthorizeEndpointHandler +type MockPushedAuthorizeEndpointHandlerMockRecorder struct { + mock *MockPushedAuthorizeEndpointHandler +} + +// NewMockPushedAuthorizeEndpointHandler creates a new mock instance +func NewMockPushedAuthorizeEndpointHandler(ctrl *gomock.Controller) *MockPushedAuthorizeEndpointHandler { + mock := &MockPushedAuthorizeEndpointHandler{ctrl: ctrl} + mock.recorder = &MockPushedAuthorizeEndpointHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPushedAuthorizeEndpointHandler) EXPECT() *MockPushedAuthorizeEndpointHandlerMockRecorder { + return m.recorder +} + +// HandlePushedAuthorizeEndpointRequest mocks base method +func (m *MockPushedAuthorizeEndpointHandler) HandlePushedAuthorizeEndpointRequest(arg0 context.Context, arg1 fosite.AuthorizeRequester, arg2 fosite.PushedAuthorizeResponder) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandlePushedAuthorizeEndpointRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandlePushedAuthorizeEndpointRequest indicates an expected call of HandlePushedAuthorizeEndpointRequest +func (mr *MockPushedAuthorizeEndpointHandlerMockRecorder) HandlePushedAuthorizeEndpointRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePushedAuthorizeEndpointRequest", reflect.TypeOf((*MockPushedAuthorizeEndpointHandler)(nil).HandlePushedAuthorizeEndpointRequest), arg0, arg1, arg2) +} diff --git a/oauth2.go b/oauth2.go index 46519402c..f7ca95995 100644 --- a/oauth2.go +++ b/oauth2.go @@ -39,12 +39,14 @@ const ( RefreshToken TokenType = "refresh_token" AuthorizeCode TokenType = "authorize_code" IDToken TokenType = "id_token" + // PushedAuthorizeRequestContext represents the PAR context object + PushedAuthorizeRequestContext TokenType = "par_context" BearerAccessToken string = "bearer" ) // OAuth2Provider is an interface that enables you to write OAuth2 handlers with only a few lines of code. -// Check fosite.Fosite for an implementation of this interface. +// Check Fosite for an implementation of this interface. type OAuth2Provider interface { // NewAuthorizeRequest returns an AuthorizeRequest. // @@ -162,6 +164,18 @@ type OAuth2Provider interface { // WriteIntrospectionResponse responds with token metadata discovered by token introspection as defined in // https://tools.ietf.org/search/rfc7662#section-2.2 WriteIntrospectionResponse(ctx context.Context, rw http.ResponseWriter, r IntrospectionResponder) + + // NewPushedAuthorizeRequest validates the request and produces an AuthorizeRequester object that can be stored + NewPushedAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) + + // NewPushedAuthorizeResponse executes the handlers and builds the response + NewPushedAuthorizeResponse(ctx context.Context, ar AuthorizeRequester, session Session) (PushedAuthorizeResponder, error) + + // WritePushedAuthorizeResponse writes the PAR response + WritePushedAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, resp PushedAuthorizeResponder) + + // WritePushedAuthorizeError writes the PAR error + WritePushedAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) } // IntrospectionResponder is the response object that will be returned when token introspection was successful, @@ -328,6 +342,33 @@ type AuthorizeResponder interface { AddParameter(key, value string) } +// PushedAuthorizeResponder is the response object for PAR +type PushedAuthorizeResponder interface { + // GetRequestURI returns the request_uri + GetRequestURI() string + // SetRequestURI sets the request_uri + SetRequestURI(requestURI string) + // GetExpiresIn gets the expires_in + GetExpiresIn() int + // SetExpiresIn sets the expires_in + SetExpiresIn(seconds int) + + // GetHeader returns the response's header + GetHeader() (header http.Header) + + // AddHeader adds an header key value pair to the response + AddHeader(key, value string) + + // SetExtra sets a key value pair for the response. + SetExtra(key string, value interface{}) + + // GetExtra returns a key's value. + GetExtra(key string) interface{} + + // ToMap converts the response to a map. + ToMap() map[string]interface{} +} + // G11NContext is the globalization context type G11NContext interface { // GetLang returns the current language in the context diff --git a/pushed_authorize_request_handler.go b/pushed_authorize_request_handler.go new file mode 100644 index 000000000..8e2f47a6a --- /dev/null +++ b/pushed_authorize_request_handler.go @@ -0,0 +1,70 @@ +package fosite + +import ( + "context" + "errors" + "net/http" + + "github.com/ory/fosite/i18n" + "github.com/ory/x/errorsx" +) + +const ( + ErrorPARNotSupported = "The OAuth 2.0 provider does not support Pushed Authorization Requests" + DebugPARStorageInvalid = "'PARStorage' not implemented" + DebugPARConfigMissing = "'PushedAuthorizeRequestConfigProvider' not implemented" + DebugPARRequestsHandlerMissing = "'PushedAuthorizeRequestHandlersProvider' not implemented" +) + +// NewPushedAuthorizeRequest validates the request and produces an AuthorizeRequester object that can be stored +func (f *Fosite) NewPushedAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) { + request := NewAuthorizeRequest() + request.Request.Lang = i18n.GetLangFromRequest(f.Config.GetMessageCatalog(ctx), r) + + if r.Method != "POST" { + return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method)) + } + + if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { + return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) + } + request.Form = r.Form + request.State = request.Form.Get("state") + + // Authenticate the client in the same way as at the token endpoint + // (Section 2.3 of [RFC6749]). + client, err := f.AuthenticateClient(ctx, r, r.Form) + if err != nil { + var rfcerr *RFC6749Error + if errors.As(err, &rfcerr) && rfcerr.ErrorField != ErrInvalidClient.ErrorField { + return request, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client could not be authenticated.").WithWrap(err).WithDebug(err.Error())) + } + + return request, err + } + request.Client = client + + // Reject the request if the "request_uri" authorization request + // parameter is provided. + if r.Form.Get("request_uri") != "" { + return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The request must not contain 'request_uri'.")) + } + + // For private_key_jwt or basic auth client authentication, "client_id" may not inside the form + // However this is required by NewAuthorizeRequest implementation + if len(r.Form.Get("client_id")) == 0 { + r.Form.Set("client_id", client.GetID()) + } + + // Validate as if this is a new authorize request + fr, err := f.newAuthorizeRequest(ctx, r, true) + if err != nil { + return fr, err + } + + if fr.GetRequestedScopes().Has("openid") && r.Form.Get("redirect_uri") == "" { + return fr, errorsx.WithStack(ErrInvalidRequest.WithHint("Query parameter 'redirect_uri' is required when performing an OpenID Connect flow.")) + } + + return fr, nil +} diff --git a/pushed_authorize_request_handler_test.go b/pushed_authorize_request_handler_test.go new file mode 100644 index 000000000..27a8e3501 --- /dev/null +++ b/pushed_authorize_request_handler_test.go @@ -0,0 +1,664 @@ +package fosite_test + +import ( + "fmt" + "net/http" + "net/url" + "runtime/debug" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + . "github.com/ory/fosite" + "github.com/ory/fosite/internal" +) + +// Should pass +// +// * https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#Terminology +// The OAuth 2.0 specification allows for registration of space-separated response_type parameter values. +// If a Response Type contains one of more space characters (%20), it is compared as a space-delimited list of +// values in which the order of values does not matter. +func TestNewPushedAuthorizeRequest(t *testing.T) { + ctrl := gomock.NewController(t) + store := internal.NewMockStorage(ctrl) + hasher := internal.NewMockHasher(ctrl) + defer ctrl.Finish() + + config := &Config{ + ScopeStrategy: ExactScopeStrategy, + AudienceMatchingStrategy: DefaultAudienceMatchingStrategy, + ClientSecretsHasher: hasher, + } + + fosite := &Fosite{ + Store: store, + Config: config, + } + + ctx := NewContext() + + redir, _ := url.Parse("https://foo.bar/cb") + specialCharRedir, _ := url.Parse("web+application://callback") + for _, c := range []struct { + desc string + conf *Fosite + r *http.Request + query url.Values + expectedError error + mock func() + expect *AuthorizeRequest + }{ + /* empty request */ + { + desc: "empty request fails", + conf: fosite, + r: &http.Request{ + Method: "POST", + }, + expectedError: ErrInvalidClient, + mock: func() {}, + }, + /* invalid redirect uri */ + { + desc: "invalid redirect uri fails", + conf: fosite, + query: url.Values{"redirect_uri": []string{"invalid"}}, + expectedError: ErrInvalidClient, + mock: func() {}, + }, + /* invalid client */ + { + desc: "invalid client fails", + conf: fosite, + query: url.Values{"redirect_uri": []string{"https://foo.bar/cb"}}, + expectedError: ErrInvalidClient, + mock: func() {}, + }, + /* redirect client mismatch */ + { + desc: "client and request redirects mismatch", + conf: fosite, + query: url.Values{ + "client_id": []string{"1234"}, + "client_secret": []string{"1234"}, + }, + expectedError: ErrInvalidRequest, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + }, + /* redirect client mismatch */ + { + desc: "client and request redirects mismatch", + conf: fosite, + query: url.Values{ + "redirect_uri": []string{""}, + "client_id": []string{"1234"}, + "client_secret": []string{"1234"}, + }, + expectedError: ErrInvalidRequest, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + }, + /* redirect client mismatch */ + { + desc: "client and request redirects mismatch", + conf: fosite, + query: url.Values{ + "redirect_uri": []string{"https://foo.bar/cb"}, + "client_id": []string{"1234"}, + "client_secret": []string{"1234"}, + }, + expectedError: ErrInvalidRequest, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"invalid"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + }, + /* no state */ + { + desc: "no state", + conf: fosite, + query: url.Values{ + "redirect_uri": []string{"https://foo.bar/cb"}, + "client_id": []string{"1234"}, + "client_secret": []string{"1234"}, + "response_type": []string{"code"}, + }, + expectedError: ErrInvalidState, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + }, + /* short state */ + { + desc: "short state", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code"}, + "state": {"short"}, + }, + expectedError: ErrInvalidState, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + }, + /* fails because scope not given */ + { + desc: "should fail because client does not have scope baz", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar baz"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrInvalidScope, + }, + /* fails because scope not given */ + { + desc: "should fail because client does not have scope baz", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrInvalidRequest, + }, + /* success case */ + { + desc: "should pass", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + ResponseTypes: []string{"code token"}, + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultClient{ + ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* repeated audience parameter */ + { + desc: "repeated audience parameter", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + ResponseTypes: []string{"code token"}, + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultClient{ + ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* repeated audience parameter with tricky values */ + { + desc: "repeated audience parameter with tricky values", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"test value", ""}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + ResponseTypes: []string{"code token"}, + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"test value"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultClient{ + ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"test value"}, + Secret: []byte("1234"), + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"test value"}, + }, + }, + }, + /* redirect_uri with special character in protocol*/ + { + desc: "redirect_uri with special character", + conf: fosite, + query: url.Values{ + "redirect_uri": {"web+application://callback"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + ResponseTypes: []string{"code token"}, + RedirectURIs: []string{"web+application://callback"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: specialCharRedir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultClient{ + ResponseTypes: []string{"code token"}, RedirectURIs: []string{"web+application://callback"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* audience with double spaces between values */ + { + desc: "audience with double spaces between values", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{ + ResponseTypes: []string{"code token"}, + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultClient{ + ResponseTypes: []string{"code token"}, RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* fails because unknown response_mode*/ + { + desc: "should fail because unknown response_mode", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"unknown"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* fails because response_mode is requested but the OAuth 2.0 client doesn't support response mode */ + { + desc: "should fail because response_mode is requested but the OAuth 2.0 client doesn't support response mode", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* fails because requested response mode is not allowed */ + { + desc: "should fail because requested response mode is not allowed", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeQuery}, + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* success with response mode */ + { + desc: "success with response mode", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeFormPost}, + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeFormPost}, + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* determine correct response mode if default */ + { + desc: "success with response mode", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeQuery}, + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code"}, + State: "strong-state", + Request: Request{ + Client: &DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeQuery}, + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* determine correct response mode if default */ + { + desc: "success with response mode", + conf: fosite, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeFragment}, + }, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + Secret: []byte("1234"), + }, + ResponseModes: []ResponseModeType{ResponseModeFragment}, + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, + /* fails because request_uri is included */ + { + desc: "should fail because request_uri is provided in the request", + conf: fosite, + query: url.Values{ + "request_uri": {"https://foo.bar/ru"}, + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("1234"))).Return(nil) + }, + expectedError: ErrInvalidRequest.WithHint("The request must not contain 'request_uri'."), + }, + /* fails because of invalid client credentials */ + { + desc: "should fail because of invalid client creds", + conf: fosite, + query: url.Values{ + "request_uri": {"https://foo.bar/ru"}, + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "client_secret": []string{"4321"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, Secret: []byte("1234")}, nil).MaxTimes(2) + hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("1234")), gomock.Eq([]byte("4321"))).Return(fmt.Errorf("invalid hash")) + }, + expectedError: ErrInvalidClient, + }, + } { + t.Run(fmt.Sprintf("case=%s", c.desc), func(t *testing.T) { + ctx := NewContext() + + c.mock() + if c.r == nil { + c.r = &http.Request{ + Header: http.Header{}, + Method: "POST", + } + if c.query != nil { + c.r.URL = &url.URL{RawQuery: c.query.Encode()} + } + } + + ar, err := c.conf.NewPushedAuthorizeRequest(ctx, c.r) + if c.expectedError != nil { + assert.EqualError(t, err, c.expectedError.Error(), "Stack: %s", string(debug.Stack())) + // https://github.com/ory/hydra/issues/1642 + AssertObjectKeysEqual(t, &AuthorizeRequest{State: c.query.Get("state")}, ar, "State") + } else { + require.NoError(t, err) + AssertObjectKeysEqual(t, c.expect, ar, "ResponseTypes", "RequestedAudience", "RequestedScope", "Client", "RedirectURI", "State") + assert.NotNil(t, ar.GetRequestedAt()) + } + }) + } +} diff --git a/pushed_authorize_response.go b/pushed_authorize_response.go new file mode 100644 index 000000000..62bf2b151 --- /dev/null +++ b/pushed_authorize_response.go @@ -0,0 +1,58 @@ +package fosite + +import "net/http" + +// PushedAuthorizeResponse is the response object for PAR +type PushedAuthorizeResponse struct { + RequestURI string `json:"request_uri"` + ExpiresIn int `json:"expires_in"` + Header http.Header + Extra map[string]interface{} +} + +// GetRequestURI gets +func (a *PushedAuthorizeResponse) GetRequestURI() string { + return a.RequestURI +} + +// SetRequestURI sets +func (a *PushedAuthorizeResponse) SetRequestURI(requestURI string) { + a.RequestURI = requestURI +} + +// GetExpiresIn gets +func (a *PushedAuthorizeResponse) GetExpiresIn() int { + return a.ExpiresIn +} + +// SetExpiresIn sets +func (a *PushedAuthorizeResponse) SetExpiresIn(seconds int) { + a.ExpiresIn = seconds +} + +// GetHeader gets +func (a *PushedAuthorizeResponse) GetHeader() http.Header { + return a.Header +} + +// AddHeader adds +func (a *PushedAuthorizeResponse) AddHeader(key, value string) { + a.Header.Add(key, value) +} + +// SetExtra sets +func (a *PushedAuthorizeResponse) SetExtra(key string, value interface{}) { + a.Extra[key] = value +} + +// GetExtra gets +func (a *PushedAuthorizeResponse) GetExtra(key string) interface{} { + return a.Extra[key] +} + +// ToMap converts to a map +func (a *PushedAuthorizeResponse) ToMap() map[string]interface{} { + a.Extra["request_uri"] = a.RequestURI + a.Extra["expires_in"] = a.ExpiresIn + return a.Extra +} diff --git a/pushed_authorize_response_writer.go b/pushed_authorize_response_writer.go new file mode 100644 index 000000000..c4ae0b74c --- /dev/null +++ b/pushed_authorize_response_writer.go @@ -0,0 +1,86 @@ +package fosite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/ory/x/errorsx" +) + +// NewPushedAuthorizeResponse executes the handlers and builds the response +func (f *Fosite) NewPushedAuthorizeResponse(ctx context.Context, ar AuthorizeRequester, session Session) (PushedAuthorizeResponder, error) { + // Get handlers. If no handlers are defined, this is considered a misconfigured Fosite instance. + handlersProvider, ok := f.Config.(PushedAuthorizeRequestHandlersProvider) + if !ok { + return nil, errorsx.WithStack(ErrServerError.WithHint(ErrorPARNotSupported).WithDebug(DebugPARRequestsHandlerMissing)) + } + + var resp = &PushedAuthorizeResponse{ + Header: http.Header{}, + Extra: map[string]interface{}{}, + } + + ctx = context.WithValue(ctx, AuthorizeRequestContextKey, ar) + ctx = context.WithValue(ctx, PushedAuthorizeResponseContextKey, resp) + + ar.SetSession(session) + for _, h := range handlersProvider.GetPushedAuthorizeEndpointHandlers(ctx) { + if err := h.HandlePushedAuthorizeEndpointRequest(ctx, ar, resp); err != nil { + return nil, err + } + } + + return resp, nil +} + +// WritePushedAuthorizeResponse writes the PAR response +func (f *Fosite) WritePushedAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, resp PushedAuthorizeResponder) { + // Set custom headers, e.g. "X-MySuperCoolCustomHeader" or "X-DONT-CACHE-ME"... + wh := rw.Header() + rh := resp.GetHeader() + for k := range rh { + wh.Set(k, rh.Get(k)) + } + + wh.Set("Cache-Control", "no-store") + wh.Set("Pragma", "no-cache") + wh.Set("Content-Type", "application/json;charset=UTF-8") + + js, err := json.Marshal(resp.ToMap()) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + rw.Header().Set("Content-Type", "application/json;charset=UTF-8") + + rw.WriteHeader(http.StatusCreated) + _, _ = rw.Write(js) +} + +// WritePushedAuthorizeError writes the PAR error +func (f *Fosite) WritePushedAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) { + rw.Header().Set("Cache-Control", "no-store") + rw.Header().Set("Pragma", "no-cache") + rw.Header().Set("Content-Type", "application/json;charset=UTF-8") + + sendDebugMessagesToClient := f.Config.GetSendDebugMessagesToClients(ctx) + rfcerr := ErrorToRFC6749Error(err).WithLegacyFormat(f.Config.GetUseLegacyErrorFormat(ctx)). + WithExposeDebug(sendDebugMessagesToClient).WithLocalizer(f.Config.GetMessageCatalog(ctx), getLangFromRequester(ar)) + + js, err := json.Marshal(rfcerr) + if err != nil { + if sendDebugMessagesToClient { + errorMessage := EscapeJSONString(err.Error()) + http.Error(rw, fmt.Sprintf(`{"error":"server_error","error_description":"%s"}`, errorMessage), http.StatusInternalServerError) + } else { + http.Error(rw, `{"error":"server_error"}`, http.StatusInternalServerError) + } + return + } + + rw.WriteHeader(rfcerr.CodeField) + _, _ = rw.Write(js) +} diff --git a/pushed_authorize_response_writer_test.go b/pushed_authorize_response_writer_test.go new file mode 100644 index 000000000..2e62006ed --- /dev/null +++ b/pushed_authorize_response_writer_test.go @@ -0,0 +1,59 @@ +package fosite_test + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + . "github.com/ory/fosite" + . "github.com/ory/fosite/internal" +) + +func TestNewPushedAuthorizeResponse(t *testing.T) { + ctrl := gomock.NewController(t) + handlers := []*MockPushedAuthorizeEndpointHandler{NewMockPushedAuthorizeEndpointHandler(ctrl)} + ar := NewMockAuthorizeRequester(ctrl) + defer ctrl.Finish() + + ctx := context.Background() + oauth2 := &Fosite{ + Config: &Config{ + PushedAuthorizeEndpointHandlers: PushedAuthorizeEndpointHandlers{handlers[0]}, + }, + } + ar.EXPECT().SetSession(gomock.Eq(new(DefaultSession))).AnyTimes() + fooErr := errors.New("foo") + for k, c := range []struct { + isErr bool + mock func() + expectErr error + }{ + { + mock: func() { + handlers[0].EXPECT().HandlePushedAuthorizeEndpointRequest(gomock.Any(), gomock.Eq(ar), gomock.Any()).Return(fooErr) + }, + isErr: true, + expectErr: fooErr, + }, + { + mock: func() { + handlers[0].EXPECT().HandlePushedAuthorizeEndpointRequest(gomock.Any(), gomock.Eq(ar), gomock.Any()).Return(nil) + }, + isErr: false, + }, + } { + c.mock() + responder, err := oauth2.NewPushedAuthorizeResponse(ctx, ar, new(DefaultSession)) + assert.Equal(t, c.isErr, err != nil, "%d: %s", k, err) + if err != nil { + assert.Equal(t, c.expectErr, err, "%d: %s", k, err) + assert.Nil(t, responder, "%d", k) + } else { + assert.NotNil(t, responder, "%d", k) + } + t.Logf("Passed test case %d", k) + } +} diff --git a/storage.go b/storage.go index ddec91dec..efd98e784 100644 --- a/storage.go +++ b/storage.go @@ -21,7 +21,19 @@ package fosite +import "context" + // Storage defines fosite's minimal storage interface. type Storage interface { ClientManager } + +// PARStorage holds information needed to store and retrieve PAR context. +type PARStorage interface { + // CreatePARSession stores the pushed authorization request context. The requestURI is used to derive the key. + CreatePARSession(ctx context.Context, requestURI string, request AuthorizeRequester) error + // GetPARSession gets the push authorization request context. The caller is expected to merge the AuthorizeRequest. + GetPARSession(ctx context.Context, requestURI string) (AuthorizeRequester, error) + // DeletePARSession deletes the context. + DeletePARSession(ctx context.Context, requestURI string) (err error) +} diff --git a/storage/memory.go b/storage/memory.go index 3705df313..98f7f2083 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -65,6 +65,7 @@ type MemoryStore struct { RefreshTokenRequestIDs map[string]string // Public keys to check signature in auth grant jwt assertion. IssuerPublicKeys map[string]IssuerPublicKeys + PARSessions map[string]fosite.AuthorizeRequester clientsMutex sync.RWMutex authorizeCodesMutex sync.RWMutex @@ -77,6 +78,7 @@ type MemoryStore struct { accessTokenRequestIDsMutex sync.RWMutex refreshTokenRequestIDsMutex sync.RWMutex issuerPublicKeysMutex sync.RWMutex + parSessionsMutex sync.RWMutex } func NewMemoryStore() *MemoryStore { @@ -92,6 +94,7 @@ func NewMemoryStore() *MemoryStore { RefreshTokenRequestIDs: make(map[string]string), BlacklistedJTIs: make(map[string]time.Time), IssuerPublicKeys: make(map[string]IssuerPublicKeys), + PARSessions: make(map[string]fosite.AuthorizeRequester), } } @@ -454,3 +457,35 @@ func (s *MemoryStore) IsJWTUsed(ctx context.Context, jti string) (bool, error) { func (s *MemoryStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error { return s.SetClientAssertionJWT(ctx, jti, exp) } + +// CreatePARSession stores the pushed authorization request context. The requestURI is used to derive the key. +func (s *MemoryStore) CreatePARSession(ctx context.Context, requestURI string, request fosite.AuthorizeRequester) error { + s.parSessionsMutex.Lock() + defer s.parSessionsMutex.Unlock() + + s.PARSessions[requestURI] = request + return nil +} + +// GetPARSession gets the push authorization request context. If the request is nil, a new request object +// is created. Otherwise, the same object is updated. +func (s *MemoryStore) GetPARSession(ctx context.Context, requestURI string) (fosite.AuthorizeRequester, error) { + s.parSessionsMutex.RLock() + defer s.parSessionsMutex.RUnlock() + + r, ok := s.PARSessions[requestURI] + if !ok { + return nil, fosite.ErrNotFound + } + + return r, nil +} + +// DeletePARSession deletes the context. +func (s *MemoryStore) DeletePARSession(ctx context.Context, requestURI string) (err error) { + s.parSessionsMutex.Lock() + defer s.parSessionsMutex.Unlock() + + delete(s.PARSessions, requestURI) + return nil +}