diff --git a/internal/relationtuple/transact_server.go b/internal/relationtuple/transact_server.go index 1e08332a2..b84dd92e3 100644 --- a/internal/relationtuple/transact_server.go +++ b/internal/relationtuple/transact_server.go @@ -8,6 +8,7 @@ import ( "encoding/json" "net/http" + "github.com/ory/keto/internal/x/validate" "github.com/ory/keto/ketoapi" rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" @@ -175,6 +176,15 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { ctx := r.Context() + if err := validate.All(r, + validate.NoExtraQueryParams(ketoapi.RelationQueryKeys...), + validate.QueryParamsContainsOneOf(ketoapi.NamespaceKey), + validate.HasEmptyBody(), + ); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } + q := r.URL.Query() query, err := (&ketoapi.RelationQuery{}).FromURLQuery(q) if err != nil { diff --git a/internal/relationtuple/transact_server_test.go b/internal/relationtuple/transact_server_test.go index a8ed8cf5a..86f6839c2 100644 --- a/internal/relationtuple/transact_server_test.go +++ b/internal/relationtuple/transact_server_test.go @@ -11,23 +11,20 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "github.com/julienschmidt/httprouter" "github.com/ory/x/pointerx" - - "github.com/ory/keto/ketoapi" - - "github.com/ory/keto/internal/driver/config" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/driver" + "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" ) func TestWriteHandlers(t *testing.T) { @@ -218,6 +215,85 @@ func TestWriteHandlers(t *testing.T) { require.NoError(t, err) assert.Equal(t, []*relationtuple.RelationTuple{}, actualRTs) }) + + t.Run("suite=bad requests", func(t *testing.T) { + nspace := addNamespace(t) + + rts := []*ketoapi.RelationTuple{ + { + Namespace: nspace.Name, + Object: "deleted obj", + Relation: "deleted rel", + SubjectID: pointerx.Ptr("deleted subj 1"), + }, + { + Namespace: nspace.Name, + Object: "deleted obj", + Relation: "deleted rel", + SubjectID: pointerx.Ptr("deleted subj 2"), + }, + } + + relationtuple.MapAndWriteTuples(t, reg, rts...) + + assertBadRequest := func(t *testing.T, req *http.Request) { + resp, err := ts.Client().Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + } + + assertTuplesExist := func(t *testing.T) { + mappedQuery, err := reg.Mapper().FromQuery(ctx, &ketoapi.RelationQuery{ + Namespace: &nspace.Name, + }) + require.NoError(t, err) + + actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(ctx, mappedQuery, x.WithSize(10)) + require.NoError(t, err) + mappedRTs, err := reg.Mapper().ToTuple(ctx, actualRTs...) + require.NoError(t, err) + assert.ElementsMatch(t, rts, mappedRTs) + } + + t.Run("case=bad request if body sent", func(t *testing.T) { + q := url.Values{ + "namespace": {nspace.Name}, + "object": {"deleted obj"}, + "relation": {"deleted rel"}, + } + req, err := http.NewRequest( + http.MethodDelete, + ts.URL+relationtuple.WriteRouteBase+"?"+q.Encode(), + strings.NewReader("some body")) + require.NoError(t, err) + + assertBadRequest(t, req) + assertTuplesExist(t) + }) + + t.Run("case=bad request query param misspelled", func(t *testing.T) { + req, err := http.NewRequest( + http.MethodDelete, + ts.URL+relationtuple.WriteRouteBase+"?invalid=param", + nil) + require.NoError(t, err) + + assertBadRequest(t, req) + assertTuplesExist(t) + }) + + t.Run("case=bad request if query params misssing", func(t *testing.T) { + req, err := http.NewRequest( + http.MethodDelete, + ts.URL+relationtuple.WriteRouteBase, + nil) + require.NoError(t, err) + + assertBadRequest(t, req) + assertTuplesExist(t) + }) + }) + }) t.Run("method=patch", func(t *testing.T) { diff --git a/internal/x/validate/validate.go b/internal/x/validate/validate.go new file mode 100644 index 000000000..e29993ac1 --- /dev/null +++ b/internal/x/validate/validate.go @@ -0,0 +1,75 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package validate + +import ( + "fmt" + "io" + "net/http" + "strings" + + "github.com/ory/herodot" +) + +type Validator func(r *http.Request) (ok bool, reason string) + +// All runs all validators and returns an error if any of them fail. It returns +// a ErrBadRequest with all failed validation messages. +func All(r *http.Request, validator ...Validator) error { + var reasons []string + for _, v := range validator { + if ok, msg := v(r); !ok { + reasons = append(reasons, msg) + } + } + if len(reasons) > 0 { + return herodot.ErrBadRequest.WithReason(strings.Join(reasons, "; ")) + } + return nil +} + +// NoExtraQueryParams returns a validator that checks if the request has any +// query parameters that are not in the except list. +func NoExtraQueryParams(except ...string) Validator { + return func(req *http.Request) (ok bool, reason string) { + allowed := make(map[string]struct{}, len(except)) + for _, e := range except { + allowed[e] = struct{}{} + } + for key := range req.URL.Query() { + if _, found := allowed[key]; !found { + return false, fmt.Sprintf("query parameter key %q unknown", key) + } + } + return true, "" + } +} + +// QueryParamsContainsOneOf returns a validator that checks if the request has +// at least one of the specified query parameters. +func QueryParamsContainsOneOf(keys ...string) Validator { + return func(req *http.Request) (ok bool, reason string) { + oneOfKeys := make(map[string]struct{}, len(keys)) + for _, k := range keys { + oneOfKeys[k] = struct{}{} + } + for key := range req.URL.Query() { + if _, found := oneOfKeys[key]; found { + return true, "" + } + } + return false, fmt.Sprintf("quey parameters must specify at least one of the following: %s", strings.Join(keys, ", ")) + } +} + +// HasEmptyBody returns a validator that checks if the request body is empty. +func HasEmptyBody() Validator { + return func(r *http.Request) (ok bool, reason string) { + _, err := r.Body.Read([]byte{}) + if err != io.EOF { + return false, "body is not empty" + } + return true, "" + } +} diff --git a/internal/x/validate/validate_test.go b/internal/x/validate/validate_test.go new file mode 100644 index 000000000..50ff5d8f7 --- /dev/null +++ b/internal/x/validate/validate_test.go @@ -0,0 +1,116 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package validate_test + +import ( + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/x/validate" +) + +func toURL(t *testing.T, s string) *url.URL { + u, err := url.Parse(s) + require.NoError(t, err) + return u +} + +func TestValidateNoExtraParams(t *testing.T) { + for _, tt := range []struct { + name string + req *http.Request + assertErr assert.ErrorAssertionFunc + }{ + { + name: "empty", + req: &http.Request{URL: toURL(t, "https://example.com")}, + assertErr: assert.NoError, + }, + { + name: "all params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")}, + assertErr: assert.NoError, + }, + { + name: "extra params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")}, + assertErr: assert.Error, + }, + } { + t.Run("case="+tt.name, func(t *testing.T) { + err := validate.All(tt.req, validate.NoExtraQueryParams("foo", "bar")) + tt.assertErr(t, err) + }) + } +} + +func TestQueryParamsContainsOneOf(t *testing.T) { + for _, tt := range []struct { + name string + req *http.Request + assertErr assert.ErrorAssertionFunc + }{ + { + name: "empty", + req: &http.Request{URL: toURL(t, "https://example.com")}, + assertErr: assert.Error, + }, + { + name: "other", + req: &http.Request{URL: toURL(t, "https://example.com?a=1&b=2&c=3")}, + assertErr: assert.Error, + }, + { + name: "one", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1")}, + assertErr: assert.NoError, + }, + { + name: "all params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")}, + assertErr: assert.NoError, + }, + { + name: "extra params", + req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")}, + assertErr: assert.NoError, + }, + } { + t.Run("case="+tt.name, func(t *testing.T) { + err := validate.All(tt.req, validate.QueryParamsContainsOneOf("foo", "bar")) + tt.assertErr(t, err) + }) + } +} + +func TestValidateHasEmptyBody(t *testing.T) { + for _, tt := range []struct { + name string + req *http.Request + assertErr assert.ErrorAssertionFunc + }{ + { + name: "empty body", + req: &http.Request{Body: io.NopCloser(strings.NewReader(""))}, + assertErr: assert.NoError, + }, + { + name: "non-empty body", + req: &http.Request{Body: io.NopCloser(strings.NewReader("content"))}, + assertErr: assert.Error, + }, + } { + t.Run("case="+tt.name, func(t *testing.T) { + err := validate.All(tt.req, validate.HasEmptyBody()) + tt.assertErr(t, err) + }) + } + +} diff --git a/ketoapi/enc_url_query.go b/ketoapi/enc_url_query.go index 0001c391a..98e397a24 100644 --- a/ketoapi/enc_url_query.go +++ b/ketoapi/enc_url_query.go @@ -40,14 +40,14 @@ func (q *RelationQuery) FromURLQuery(query url.Values) (*RelationQuery, error) { return nil, ErrIncompleteSubject } - if query.Has("namespace") { - q.Namespace = pointerx.Ptr(query.Get("namespace")) + if query.Has(NamespaceKey) { + q.Namespace = pointerx.Ptr(query.Get(NamespaceKey)) } - if query.Has("object") { - q.Object = pointerx.Ptr(query.Get("object")) + if query.Has(ObjectKey) { + q.Object = pointerx.Ptr(query.Get(ObjectKey)) } - if query.Has("relation") { - q.Relation = pointerx.Ptr(query.Get("relation")) + if query.Has(RelationKey) { + q.Relation = pointerx.Ptr(query.Get(RelationKey)) } return q, nil @@ -57,13 +57,13 @@ func (q *RelationQuery) ToURLQuery() url.Values { v := make(url.Values, 7) if q.Namespace != nil { - v.Add("namespace", *q.Namespace) + v.Add(NamespaceKey, *q.Namespace) } if q.Relation != nil { - v.Add("relation", *q.Relation) + v.Add(RelationKey, *q.Relation) } if q.Object != nil { - v.Add("object", *q.Object) + v.Add(ObjectKey, *q.Object) } if q.SubjectID != nil { v.Add(SubjectIDKey, *q.SubjectID) @@ -112,17 +112,17 @@ func (s *SubjectSet) FromURLQuery(values url.Values) *SubjectSet { s = &SubjectSet{} } - s.Namespace = values.Get("namespace") - s.Relation = values.Get("relation") - s.Object = values.Get("object") + s.Namespace = values.Get(NamespaceKey) + s.Relation = values.Get(RelationKey) + s.Object = values.Get(ObjectKey) return s } func (s *SubjectSet) ToURLQuery() url.Values { return url.Values{ - "namespace": []string{s.Namespace}, - "object": []string{s.Object}, - "relation": []string{s.Relation}, + NamespaceKey: []string{s.Namespace}, + ObjectKey: []string{s.Object}, + RelationKey: []string{s.Relation}, } } diff --git a/ketoapi/public_api_definitions.go b/ketoapi/public_api_definitions.go index 85636f8d2..1e96eac56 100644 --- a/ketoapi/public_api_definitions.go +++ b/ketoapi/public_api_definitions.go @@ -121,12 +121,26 @@ const ( ) const ( + NamespaceKey = "namespace" + ObjectKey = "object" + RelationKey = "relation" SubjectIDKey = "subject_id" SubjectSetNamespaceKey = "subject_set.namespace" SubjectSetObjectKey = "subject_set.object" SubjectSetRelationKey = "subject_set.relation" ) +var RelationQueryKeys = []string{ + NamespaceKey, + ObjectKey, + RelationKey, + SubjectIDKey, + SubjectSetNamespaceKey, + SubjectSetObjectKey, + SubjectSetRelationKey, + "subject", // We have a more specific error message for this key. +} + // Paginated Relationship List // // swagger:model relationships