From e0243a201982d1b355ee43a372f0f099c4d7ba99 Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Tue, 31 Jan 2023 10:28:25 +0100 Subject: [PATCH] feat: DELETE must now at least specify the namespace --- internal/relationtuple/transact_server.go | 10 ++-- internal/x/validate/validate.go | 72 +++++++++++++++++++---- internal/x/validate/validate_test.go | 46 ++++++++++++++- ketoapi/public_api_definitions.go | 2 +- 4 files changed, 109 insertions(+), 21 deletions(-) diff --git a/internal/relationtuple/transact_server.go b/internal/relationtuple/transact_server.go index 58860e5a9..b84dd92e3 100644 --- a/internal/relationtuple/transact_server.go +++ b/internal/relationtuple/transact_server.go @@ -176,11 +176,11 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() - if err := validate.NoExtraQueryParams(r, ketoapi.RelationQueryKeys...); err != nil { - h.d.Writer().WriteError(w, r, err) - return - } - if err := validate.HasEmptyBody(r); err != nil { + if err := validate.All(r, + validate.NoExtraQueryParams(ketoapi.RelationQueryKeys...), + validate.QueryParamsContainsOneOf(ketoapi.NamespaceKey), + validate.HasEmptyBody(), + ); err != nil { h.d.Writer().WriteError(w, r, err) return } diff --git a/internal/x/validate/validate.go b/internal/x/validate/validate.go index a2e59c58c..e29993ac1 100644 --- a/internal/x/validate/validate.go +++ b/internal/x/validate/validate.go @@ -1,29 +1,75 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package validate import ( + "fmt" "io" "net/http" + "strings" "github.com/ory/herodot" ) -func NoExtraQueryParams(req *http.Request, except ...string) error { - allowed := make(map[string]struct{}, len(except)) - for _, e := range except { - allowed[e] = struct{}{} - } - for key := range req.URL.Query() { - if _, found := allowed[key]; !found { - return herodot.ErrBadRequest.WithReasonf("query parameter key %q unknown", key) +type Validator func(r *http.Request) (ok bool, reason string) + +// All runs all validators and returns an error if any of them fail. It returns +// a ErrBadRequest with all failed validation messages. +func All(r *http.Request, validator ...Validator) error { + var reasons []string + for _, v := range validator { + if ok, msg := v(r); !ok { + reasons = append(reasons, msg) } } + if len(reasons) > 0 { + return herodot.ErrBadRequest.WithReason(strings.Join(reasons, "; ")) + } return nil } -func HasEmptyBody(r *http.Request) error { - _, err := r.Body.Read([]byte{}) - if err != io.EOF { - return herodot.ErrBadRequest.WithReason("body is not empty") +// NoExtraQueryParams returns a validator that checks if the request has any +// query parameters that are not in the except list. +func NoExtraQueryParams(except ...string) Validator { + return func(req *http.Request) (ok bool, reason string) { + allowed := make(map[string]struct{}, len(except)) + for _, e := range except { + allowed[e] = struct{}{} + } + for key := range req.URL.Query() { + if _, found := allowed[key]; !found { + return false, fmt.Sprintf("query parameter key %q unknown", key) + } + } + return true, "" + } +} + +// QueryParamsContainsOneOf returns a validator that checks if the request has +// at least one of the specified query parameters. +func QueryParamsContainsOneOf(keys ...string) Validator { + return func(req *http.Request) (ok bool, reason string) { + oneOfKeys := make(map[string]struct{}, len(keys)) + for _, k := range keys { + oneOfKeys[k] = struct{}{} + } + for key := range req.URL.Query() { + if _, found := oneOfKeys[key]; found { + return true, "" + } + } + return false, fmt.Sprintf("quey parameters must specify at least one of the following: %s", strings.Join(keys, ", ")) + } +} + +// HasEmptyBody returns a validator that checks if the request body is empty. +func HasEmptyBody() Validator { + return func(r *http.Request) (ok bool, reason string) { + _, err := r.Body.Read([]byte{}) + if err != io.EOF { + return false, "body is not empty" + } + return true, "" } - return nil } diff --git a/internal/x/validate/validate_test.go b/internal/x/validate/validate_test.go index 2f7192fff..50ff5d8f7 100644 --- a/internal/x/validate/validate_test.go +++ b/internal/x/validate/validate_test.go @@ -1,3 +1,6 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + package validate_test import ( @@ -42,7 +45,46 @@ func TestValidateNoExtraParams(t *testing.T) { }, } { t.Run("case="+tt.name, func(t *testing.T) { - err := validate.NoExtraQueryParams(tt.req, "foo", "bar") + err := validate.All(tt.req, validate.NoExtraQueryParams("foo", "bar")) + tt.assertErr(t, err) + }) + } +} + +func TestQueryParamsContainsOneOf(t *testing.T) { + for _, tt := range []struct { + name string + req *http.Request + assertErr assert.ErrorAssertionFunc + }{ + { + name: "empty", + req: &http.Request{URL: toURL(t, "https://example.com")}, + assertErr: assert.Error, + }, + { + name: "other", + req: &http.Request{URL: toURL(t, "https://example.com?a=1&b=2&c=3")}, + assertErr: assert.Error, + }, + { + name: "one", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1")}, + assertErr: assert.NoError, + }, + { + name: "all params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")}, + assertErr: assert.NoError, + }, + { + name: "extra params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")}, + assertErr: assert.NoError, + }, + } { + t.Run("case="+tt.name, func(t *testing.T) { + err := validate.All(tt.req, validate.QueryParamsContainsOneOf("foo", "bar")) tt.assertErr(t, err) }) } @@ -66,7 +108,7 @@ func TestValidateHasEmptyBody(t *testing.T) { }, } { t.Run("case="+tt.name, func(t *testing.T) { - err := validate.HasEmptyBody(tt.req) + err := validate.All(tt.req, validate.HasEmptyBody()) tt.assertErr(t, err) }) } diff --git a/ketoapi/public_api_definitions.go b/ketoapi/public_api_definitions.go index 0afd1cedd..1e96eac56 100644 --- a/ketoapi/public_api_definitions.go +++ b/ketoapi/public_api_definitions.go @@ -138,7 +138,7 @@ var RelationQueryKeys = []string{ SubjectSetNamespaceKey, SubjectSetObjectKey, SubjectSetRelationKey, - "subject", // We have a mnore specific error message for this key. + "subject", // We have a more specific error message for this key. } // Paginated Relationship List