From 7a0fcfc4fe83776fa09cf78ee11f407610554d04 Mon Sep 17 00:00:00 2001 From: Patrik Date: Tue, 16 Feb 2021 14:23:54 +0100 Subject: [PATCH] fix: ensure nil subject is not allowed (#449) The nodejs gRPC client was a great fuzzer and pointed me to some nil pointer dereference panics. This adds some input validation to prevent panics. --- cmd/expand/root.go | 8 +- internal/check/engine.go | 5 - internal/check/handler.go | 5 +- internal/expand/handler.go | 8 +- internal/expand/tree.go | 32 ++-- internal/persistence/sql/relationtuples.go | 4 + internal/relationtuple/definitions.go | 86 +++++++---- internal/relationtuple/definitions_test.go | 146 ++++++++++++++++++ .../relationtuple/manager_requirements.go | 3 +- internal/relationtuple/read_server.go | 14 +- internal/relationtuple/write_server.go | 18 ++- 11 files changed, 264 insertions(+), 65 deletions(-) diff --git a/cmd/expand/root.go b/cmd/expand/root.go index f074d7d70..af3887718 100644 --- a/cmd/expand/root.go +++ b/cmd/expand/root.go @@ -49,7 +49,13 @@ func NewExpandCmd() *cobra.Command { return cmdx.FailSilently(cmd) } - cmdx.PrintJSONAble(cmd, expand.TreeFromProto(resp.Tree)) + tree, err := expand.TreeFromProto(resp.Tree) + if err != nil { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Error building the tree: %s\n", err.Error()) + return cmdx.FailSilently(cmd) + } + + cmdx.PrintJSONAble(cmd, tree) switch flagx.MustGetString(cmd, cmdx.FlagFormat) { case string(cmdx.FormatDefault), "": _, _ = fmt.Fprintln(cmd.OutOrStdout()) diff --git a/internal/check/engine.go b/internal/check/engine.go index ea6b0d81e..fffecec02 100644 --- a/internal/check/engine.go +++ b/internal/check/engine.go @@ -3,7 +3,6 @@ package check import ( "context" "errors" - "fmt" "github.com/ory/herodot" @@ -37,10 +36,6 @@ func (e *Engine) subjectIsAllowed(ctx context.Context, requested *relationtuple. var allowed bool for _, sr := range rels { - // TODO move this to input validation - if requested.Subject == nil { - return false, fmt.Errorf("subject is unexpectedly nil for %+v", requested) - } // we only have to check Subject here as we know that sr was reached from requested.ObjectID, requested.Relation through 0...n indirections if requested.Subject.Equals(sr.Subject) { // found the requested relation diff --git a/internal/check/handler.go b/internal/check/handler.go index 02ddc23b6..80acad57f 100644 --- a/internal/check/handler.go +++ b/internal/check/handler.go @@ -67,7 +67,10 @@ func (h *Handler) getCheck(w http.ResponseWriter, r *http.Request, _ httprouter. } func (h *Handler) Check(ctx context.Context, req *acl.CheckRequest) (*acl.CheckResponse, error) { - tuple := (&relationtuple.InternalRelationTuple{}).FromDataProvider(req) + tuple, err := (&relationtuple.InternalRelationTuple{}).FromDataProvider(req) + if err != nil { + return nil, err + } allowed, err := h.d.PermissionEngine().SubjectIsAllowed(ctx, tuple) // TODO add content change handling diff --git a/internal/expand/handler.go b/internal/expand/handler.go index 5d33da227..a83cc7450 100644 --- a/internal/expand/handler.go +++ b/internal/expand/handler.go @@ -64,9 +64,11 @@ func (h *handler) getExpand(w http.ResponseWriter, r *http.Request, _ httprouter } func (h *handler) Expand(ctx context.Context, req *acl.ExpandRequest) (*acl.ExpandResponse, error) { - tree, err := h.d.ExpandEngine().BuildTree(ctx, - relationtuple.SubjectFromProto(req.Subject), - int(req.MaxDepth)) + sub, err := relationtuple.SubjectFromProto(req.Subject) + if err != nil { + return nil, err + } + tree, err := h.d.ExpandEngine().BuildTree(ctx, sub, int(req.MaxDepth)) if err != nil { return nil, err } diff --git a/internal/expand/tree.go b/internal/expand/tree.go index cfdf67551..94ebf0a4d 100644 --- a/internal/expand/tree.go +++ b/internal/expand/tree.go @@ -138,24 +138,28 @@ func (t *Tree) ToProto() *acl.SubjectTree { } } -func TreeFromProto(t *acl.SubjectTree) *Tree { - if t.NodeType == acl.NodeType_NODE_TYPE_LEAF { - return &Tree{ - Type: Leaf, - Subject: relationtuple.SubjectFromProto(t.Subject), - } +func TreeFromProto(t *acl.SubjectTree) (*Tree, error) { + sub, err := relationtuple.SubjectFromProto(t.Subject) + if err != nil { + return nil, err } - - children := make([]*Tree, len(t.Children)) - for i, c := range t.Children { - children[i] = TreeFromProto(c) + self := &Tree{ + Type: NodeTypeFromProto(t.NodeType), + Subject: sub, } - return &Tree{ - Type: NodeTypeFromProto(t.NodeType), - Subject: relationtuple.SubjectFromProto(t.Subject), - Children: children, + if t.NodeType != acl.NodeType_NODE_TYPE_LEAF { + self.Children = make([]*Tree, len(t.Children)) + for i, c := range t.Children { + var err error + self.Children[i], err = TreeFromProto(c) + if err != nil { + return nil, err + } + } } + + return self, nil } func (t *Tree) String() string { diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 38b5c06c0..80f7b8fc9 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -72,6 +72,10 @@ func (r *relationTuple) toInternal() (*relationtuple.InternalRelationTuple, erro } func (p *Persister) insertRelationTuple(ctx context.Context, rel *relationtuple.InternalRelationTuple) error { + if rel.Subject == nil { + return errors.New("subject is not allowed to be nil") + } + n, err := p.namespaces.GetNamespace(ctx, rel.Namespace) if err != nil { return err diff --git a/internal/relationtuple/definitions.go b/internal/relationtuple/definitions.go index 63f603b3d..898b5ac1a 100644 --- a/internal/relationtuple/definitions.go +++ b/internal/relationtuple/definitions.go @@ -19,8 +19,6 @@ import ( "github.com/ory/keto/internal/x" "github.com/tidwall/gjson" - - "github.com/ory/x/cmdx" ) type ( @@ -76,6 +74,7 @@ var ( _, _ Subject = &SubjectID{}, &SubjectSet{} ErrMalformedInput = errors.New("malformed string input") + ErrNilSubject = errors.New("subject is nil") ) func SubjectFromString(s string) (Subject, error) { @@ -85,20 +84,22 @@ func SubjectFromString(s string) (Subject, error) { return (&SubjectID{}).FromString(s) } -func SubjectFromProto(gs *acl.Subject) Subject { +func SubjectFromProto(gs *acl.Subject) (Subject, error) { switch s := gs.GetRef().(type) { + case nil: + return nil, errors.WithStack(ErrNilSubject) case *acl.Subject_Id: return &SubjectID{ ID: s.Id, - } + }, nil case *acl.Subject_Set: return &SubjectSet{ Namespace: s.Set.Namespace, Object: s.Set.Object, Relation: s.Set.Relation, - } + }, nil } - return nil + return nil, errors.WithStack(ErrNilSubject) } func (s *SubjectID) String() string { @@ -235,13 +236,18 @@ func (r *InternalRelationTuple) MarshalJSON() ([]byte, error) { return sjson.SetBytes(enc, "subject", r.Subject.String()) } -func (r *InternalRelationTuple) FromDataProvider(d TupleData) *InternalRelationTuple { - r.Subject = SubjectFromProto(d.GetSubject()) +func (r *InternalRelationTuple) FromDataProvider(d TupleData) (*InternalRelationTuple, error) { + var err error + r.Subject, err = SubjectFromProto(d.GetSubject()) + if err != nil { + return nil, err + } + r.Object = d.GetObject() r.Namespace = d.GetNamespace() r.Relation = d.GetRelation() - return r + return r, nil } func (r *InternalRelationTuple) ToProto() *acl.RelationTuple { @@ -290,13 +296,14 @@ func (r *InternalRelationTuple) ToLoggerFields() logrus.Fields { } } -func (q *RelationQuery) FromProto(query *acl.ListRelationTuplesRequest_Query) *RelationQuery { - return &RelationQuery{ - Namespace: query.Namespace, - Object: query.Object, - Relation: query.Relation, - Subject: SubjectFromProto(query.Subject), +func (q *RelationQuery) FromProto(query *acl.ListRelationTuplesRequest_Query) (*RelationQuery, error) { + r, err := (&InternalRelationTuple{}).FromDataProvider(query) + if err != nil { + return nil, err } + + *q = RelationQuery(*r) + return q, nil } func (q *RelationQuery) FromURLQuery(query url.Values) (*RelationQuery, error) { @@ -386,39 +393,47 @@ func (r *RelationCollection) Header() []string { } func (r *RelationCollection) Table() [][]string { - if r.internalRelations == nil { - for _, rel := range r.protoRelations { - r.internalRelations = append(r.internalRelations, (&InternalRelationTuple{}).FromDataProvider(rel)) - } + ir, err := r.ToInternal() + if err != nil { + return [][]string{{fmt.Sprintf("%+v", err)}} } - data := make([][]string, len(r.internalRelations)) - for i, rel := range r.internalRelations { - data[i] = []string{rel.Namespace, rel.Object, rel.Relation, cmdx.None} - if rel.Subject != nil { - data[i][3] = rel.Subject.String() - } + data := make([][]string, len(ir)) + for i, rel := range ir { + data[i] = []string{rel.Namespace, rel.Object, rel.Relation, rel.Subject.String()} } return data } -func (r *RelationCollection) ToInternal() []*InternalRelationTuple { +func (r *RelationCollection) ToInternal() ([]*InternalRelationTuple, error) { if r.internalRelations == nil { r.internalRelations = make([]*InternalRelationTuple, len(r.protoRelations)) for i, rel := range r.protoRelations { - r.internalRelations[i] = (&InternalRelationTuple{}).FromDataProvider(rel) + ir, err := (&InternalRelationTuple{}).FromDataProvider(rel) + if err != nil { + return nil, err + } + r.internalRelations[i] = ir } } - return r.internalRelations + return r.internalRelations, nil } func (r *RelationCollection) Interface() interface{} { - return r.ToInternal() + ir, err := r.ToInternal() + if err != nil { + return err + } + return ir } func (r *RelationCollection) MarshalJSON() ([]byte, error) { - return json.Marshal(r.ToInternal()) + ir, err := r.ToInternal() + if err != nil { + return nil, err + } + return json.Marshal(ir) } func (r *RelationCollection) UnmarshalJSON(raw []byte) error { @@ -426,11 +441,18 @@ func (r *RelationCollection) UnmarshalJSON(raw []byte) error { } func (r *RelationCollection) Len() int { - return len(r.ToInternal()) + if ir := len(r.internalRelations); ir > 0 { + return ir + } + return len(r.protoRelations) } func (r *RelationCollection) IDs() []string { - rts := r.ToInternal() + rts, err := r.ToInternal() + if err != nil { + // fmt.Sprintf to include the stacktrace + return []string{fmt.Sprintf("%+v", err)} + } ids := make([]string, len(rts)) for i, rt := range rts { ids[i] = rt.String() diff --git a/internal/relationtuple/definitions_test.go b/internal/relationtuple/definitions_test.go index 09165cfdb..c8b6d9a47 100644 --- a/internal/relationtuple/definitions_test.go +++ b/internal/relationtuple/definitions_test.go @@ -8,6 +8,8 @@ import ( "strconv" "testing" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -68,6 +70,47 @@ func TestSubject(t *testing.T) { } }) + t.Run("case=proto decoding", func(t *testing.T) { + for i, tc := range []struct { + proto *acl.Subject + expected Subject + err error + }{ + { + proto: &acl.Subject{ + Ref: &acl.Subject_Id{Id: "foo"}, + }, + expected: &SubjectID{ID: "foo"}, + }, + { + proto: nil, + err: ErrNilSubject, + }, + { + proto: &acl.Subject{ + Ref: &acl.Subject_Set{ + Set: &acl.SubjectSet{ + Namespace: "n", + Object: "o", + Relation: "r", + }, + }, + }, + expected: &SubjectSet{ + Namespace: "n", + Object: "o", + Relation: "r", + }, + }, + } { + t.Run(fmt.Sprintf("case=%d", i), func(t *testing.T) { + actual, err := SubjectFromProto(tc.proto) + require.True(t, errors.Is(err, tc.err)) + assert.Equal(t, tc.expected, actual) + }) + } + }) + t.Run("method=equals", func(t *testing.T) { for i, tc := range []struct { a, b Subject @@ -269,6 +312,76 @@ func TestInternalRelationTuple(t *testing.T) { }) } }) + + t.Run("case=proto decoding", func(t *testing.T) { + for i, tc := range []struct { + proto TupleData + expected *InternalRelationTuple + err error + }{ + { + proto: &acl.RelationTuple{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: nil, + }, + err: ErrNilSubject, + }, + { + proto: &acl.RelationTuple{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: &acl.Subject{ + Ref: &acl.Subject_Set{ + Set: &acl.SubjectSet{ + Namespace: "n", + Object: "o", + Relation: "r", + }, + }, + }, + }, + expected: &InternalRelationTuple{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: &SubjectSet{ + Namespace: "n", + Object: "o", + Relation: "r", + }, + }, + }, + { + proto: &acl.RelationTuple{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: &acl.Subject{ + Ref: &acl.Subject_Id{ + Id: "user", + }, + }, + }, + expected: &InternalRelationTuple{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: &SubjectID{ + ID: "user", + }, + }, + }, + } { + t.Run(fmt.Sprintf("case=%d", i), func(t *testing.T) { + actual, err := (&InternalRelationTuple{}).FromDataProvider(tc.proto) + require.True(t, errors.Is(err, tc.err)) + assert.Equal(t, tc.expected, actual) + }) + } + }) } func TestRelationQuery(t *testing.T) { @@ -400,4 +513,37 @@ func TestRelationCollection(t *testing.T) { }) } }) + + t.Run("func=toInternal", func(t *testing.T) { + for i, tc := range []struct { + collection *RelationCollection + expected []*InternalRelationTuple + err error + }{ + { + collection: NewProtoRelationCollection([]*acl.RelationTuple{{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: (&SubjectID{ID: "s"}).ToProto(), + }}), + expected: []*InternalRelationTuple{{ + Namespace: "n", + Object: "o", + Relation: "r", + Subject: &SubjectID{ID: "s"}, + }}, + }, + { + collection: NewProtoRelationCollection([]*acl.RelationTuple{{ /*subject is nil*/ }}), + err: ErrNilSubject, + }, + } { + t.Run(fmt.Sprintf("case=%d", i), func(t *testing.T) { + actual, err := tc.collection.ToInternal() + require.True(t, errors.Is(err, tc.err)) + assert.Equal(t, tc.expected, actual) + }) + } + }) } diff --git a/internal/relationtuple/manager_requirements.go b/internal/relationtuple/manager_requirements.go index 6b09d08d8..de117ea2f 100644 --- a/internal/relationtuple/manager_requirements.go +++ b/internal/relationtuple/manager_requirements.go @@ -54,9 +54,10 @@ func ManagerTest(t *testing.T, m Manager, addNamespace func(context.Context, *te t.Run("case=unknown namespace", func(t *testing.T) { err := m.WriteRelationTuples(context.Background(), &InternalRelationTuple{ Namespace: "unknown namespace", + Subject: &SubjectID{}, }) assert.NotNil(t, err) - assert.True(t, errors.Is(err, herodot.ErrNotFound)) + assert.True(t, errors.Is(err, herodot.ErrNotFound), "actual error: %+v", err) }) }) diff --git a/internal/relationtuple/read_server.go b/internal/relationtuple/read_server.go index 3011a826b..8b5fe473b 100644 --- a/internal/relationtuple/read_server.go +++ b/internal/relationtuple/read_server.go @@ -5,6 +5,8 @@ import ( "net/http" "strconv" + "github.com/pkg/errors" + acl "github.com/ory/keto/proto/ory/keto/acl/v1alpha1" "github.com/julienschmidt/httprouter" @@ -15,12 +17,22 @@ import ( var _ acl.ReadServiceServer = (*handler)(nil) func (h *handler) ListRelationTuples(ctx context.Context, req *acl.ListRelationTuplesRequest) (*acl.ListRelationTuplesResponse, error) { + if req.Query == nil { + return nil, errors.New("invalid request") + } + + sub, err := SubjectFromProto(req.Query.Subject) + if err != nil { + // this means we are not querying by subject + sub = nil + } + rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(ctx, &RelationQuery{ Namespace: req.Query.Namespace, Object: req.Query.Object, Relation: req.Query.Relation, - Subject: SubjectFromProto(req.Query.Subject), + Subject: sub, }, x.WithSize(int(req.PageSize)), x.WithToken(req.PageToken), diff --git a/internal/relationtuple/write_server.go b/internal/relationtuple/write_server.go index ee4ce2ce4..451deb7b3 100644 --- a/internal/relationtuple/write_server.go +++ b/internal/relationtuple/write_server.go @@ -14,22 +14,26 @@ import ( var _ acl.WriteServiceServer = (*handler)(nil) -func tuplesWithAction(deltas []*acl.RelationTupleDelta, action acl.RelationTupleDelta_Action) (filtered []*InternalRelationTuple) { +func tuplesWithAction(deltas []*acl.RelationTupleDelta, action acl.RelationTupleDelta_Action) (filtered []*InternalRelationTuple, err error) { for _, d := range deltas { if d.Action == action { - filtered = append( - filtered, - (&InternalRelationTuple{}).FromDataProvider(d.RelationTuple), - ) + it, err := (&InternalRelationTuple{}).FromDataProvider(d.RelationTuple) + if err != nil { + return nil, err + } + filtered = append(filtered, it) } } return } func (h *handler) TransactRelationTuples(ctx context.Context, req *acl.TransactRelationTuplesRequest) (*acl.TransactRelationTuplesResponse, error) { - insertTuples := tuplesWithAction(req.RelationTupleDeltas, acl.RelationTupleDelta_INSERT) + insertTuples, err := tuplesWithAction(req.RelationTupleDeltas, acl.RelationTupleDelta_INSERT) + if err != nil { + return nil, err + } - err := h.d.RelationTupleManager().WriteRelationTuples(ctx, insertTuples...) + err = h.d.RelationTupleManager().WriteRelationTuples(ctx, insertTuples...) if err != nil { return nil, err }