Skip to content

Commit

Permalink
fix: ensure nil subject is not allowed (#449)
Browse files Browse the repository at this point in the history
The nodejs gRPC client was a great fuzzer and pointed me to some nil pointer dereference panics.
This adds some input validation to prevent panics.
  • Loading branch information
zepatrik authored Feb 16, 2021
1 parent 3b5c313 commit 7a0fcfc
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 65 deletions.
8 changes: 7 additions & 1 deletion cmd/expand/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ func NewExpandCmd() *cobra.Command {
return cmdx.FailSilently(cmd)
}

cmdx.PrintJSONAble(cmd, expand.TreeFromProto(resp.Tree))
tree, err := expand.TreeFromProto(resp.Tree)
if err != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Error building the tree: %s\n", err.Error())
return cmdx.FailSilently(cmd)
}

cmdx.PrintJSONAble(cmd, tree)
switch flagx.MustGetString(cmd, cmdx.FlagFormat) {
case string(cmdx.FormatDefault), "":
_, _ = fmt.Fprintln(cmd.OutOrStdout())
Expand Down
5 changes: 0 additions & 5 deletions internal/check/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package check
import (
"context"
"errors"
"fmt"

"github.com/ory/herodot"

Expand Down Expand Up @@ -37,10 +36,6 @@ func (e *Engine) subjectIsAllowed(ctx context.Context, requested *relationtuple.

var allowed bool
for _, sr := range rels {
// TODO move this to input validation
if requested.Subject == nil {
return false, fmt.Errorf("subject is unexpectedly nil for %+v", requested)
}
// we only have to check Subject here as we know that sr was reached from requested.ObjectID, requested.Relation through 0...n indirections
if requested.Subject.Equals(sr.Subject) {
// found the requested relation
Expand Down
5 changes: 4 additions & 1 deletion internal/check/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func (h *Handler) getCheck(w http.ResponseWriter, r *http.Request, _ httprouter.
}

func (h *Handler) Check(ctx context.Context, req *acl.CheckRequest) (*acl.CheckResponse, error) {
tuple := (&relationtuple.InternalRelationTuple{}).FromDataProvider(req)
tuple, err := (&relationtuple.InternalRelationTuple{}).FromDataProvider(req)
if err != nil {
return nil, err
}

allowed, err := h.d.PermissionEngine().SubjectIsAllowed(ctx, tuple)
// TODO add content change handling
Expand Down
8 changes: 5 additions & 3 deletions internal/expand/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ func (h *handler) getExpand(w http.ResponseWriter, r *http.Request, _ httprouter
}

func (h *handler) Expand(ctx context.Context, req *acl.ExpandRequest) (*acl.ExpandResponse, error) {
tree, err := h.d.ExpandEngine().BuildTree(ctx,
relationtuple.SubjectFromProto(req.Subject),
int(req.MaxDepth))
sub, err := relationtuple.SubjectFromProto(req.Subject)
if err != nil {
return nil, err
}
tree, err := h.d.ExpandEngine().BuildTree(ctx, sub, int(req.MaxDepth))
if err != nil {
return nil, err
}
Expand Down
32 changes: 18 additions & 14 deletions internal/expand/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,28 @@ func (t *Tree) ToProto() *acl.SubjectTree {
}
}

func TreeFromProto(t *acl.SubjectTree) *Tree {
if t.NodeType == acl.NodeType_NODE_TYPE_LEAF {
return &Tree{
Type: Leaf,
Subject: relationtuple.SubjectFromProto(t.Subject),
}
func TreeFromProto(t *acl.SubjectTree) (*Tree, error) {
sub, err := relationtuple.SubjectFromProto(t.Subject)
if err != nil {
return nil, err
}

children := make([]*Tree, len(t.Children))
for i, c := range t.Children {
children[i] = TreeFromProto(c)
self := &Tree{
Type: NodeTypeFromProto(t.NodeType),
Subject: sub,
}

return &Tree{
Type: NodeTypeFromProto(t.NodeType),
Subject: relationtuple.SubjectFromProto(t.Subject),
Children: children,
if t.NodeType != acl.NodeType_NODE_TYPE_LEAF {
self.Children = make([]*Tree, len(t.Children))
for i, c := range t.Children {
var err error
self.Children[i], err = TreeFromProto(c)
if err != nil {
return nil, err
}
}
}

return self, nil
}

func (t *Tree) String() string {
Expand Down
4 changes: 4 additions & 0 deletions internal/persistence/sql/relationtuples.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ func (r *relationTuple) toInternal() (*relationtuple.InternalRelationTuple, erro
}

func (p *Persister) insertRelationTuple(ctx context.Context, rel *relationtuple.InternalRelationTuple) error {
if rel.Subject == nil {
return errors.New("subject is not allowed to be nil")
}

n, err := p.namespaces.GetNamespace(ctx, rel.Namespace)
if err != nil {
return err
Expand Down
86 changes: 54 additions & 32 deletions internal/relationtuple/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import (
"github.com/ory/keto/internal/x"

"github.com/tidwall/gjson"

"github.com/ory/x/cmdx"
)

type (
Expand Down Expand Up @@ -76,6 +74,7 @@ var (
_, _ Subject = &SubjectID{}, &SubjectSet{}

ErrMalformedInput = errors.New("malformed string input")
ErrNilSubject = errors.New("subject is nil")
)

func SubjectFromString(s string) (Subject, error) {
Expand All @@ -85,20 +84,22 @@ func SubjectFromString(s string) (Subject, error) {
return (&SubjectID{}).FromString(s)
}

func SubjectFromProto(gs *acl.Subject) Subject {
func SubjectFromProto(gs *acl.Subject) (Subject, error) {
switch s := gs.GetRef().(type) {
case nil:
return nil, errors.WithStack(ErrNilSubject)
case *acl.Subject_Id:
return &SubjectID{
ID: s.Id,
}
}, nil
case *acl.Subject_Set:
return &SubjectSet{
Namespace: s.Set.Namespace,
Object: s.Set.Object,
Relation: s.Set.Relation,
}
}, nil
}
return nil
return nil, errors.WithStack(ErrNilSubject)
}

func (s *SubjectID) String() string {
Expand Down Expand Up @@ -235,13 +236,18 @@ func (r *InternalRelationTuple) MarshalJSON() ([]byte, error) {
return sjson.SetBytes(enc, "subject", r.Subject.String())
}

func (r *InternalRelationTuple) FromDataProvider(d TupleData) *InternalRelationTuple {
r.Subject = SubjectFromProto(d.GetSubject())
func (r *InternalRelationTuple) FromDataProvider(d TupleData) (*InternalRelationTuple, error) {
var err error
r.Subject, err = SubjectFromProto(d.GetSubject())
if err != nil {
return nil, err
}

r.Object = d.GetObject()
r.Namespace = d.GetNamespace()
r.Relation = d.GetRelation()

return r
return r, nil
}

func (r *InternalRelationTuple) ToProto() *acl.RelationTuple {
Expand Down Expand Up @@ -290,13 +296,14 @@ func (r *InternalRelationTuple) ToLoggerFields() logrus.Fields {
}
}

func (q *RelationQuery) FromProto(query *acl.ListRelationTuplesRequest_Query) *RelationQuery {
return &RelationQuery{
Namespace: query.Namespace,
Object: query.Object,
Relation: query.Relation,
Subject: SubjectFromProto(query.Subject),
func (q *RelationQuery) FromProto(query *acl.ListRelationTuplesRequest_Query) (*RelationQuery, error) {
r, err := (&InternalRelationTuple{}).FromDataProvider(query)
if err != nil {
return nil, err
}

*q = RelationQuery(*r)
return q, nil
}

func (q *RelationQuery) FromURLQuery(query url.Values) (*RelationQuery, error) {
Expand Down Expand Up @@ -386,51 +393,66 @@ func (r *RelationCollection) Header() []string {
}

func (r *RelationCollection) Table() [][]string {
if r.internalRelations == nil {
for _, rel := range r.protoRelations {
r.internalRelations = append(r.internalRelations, (&InternalRelationTuple{}).FromDataProvider(rel))
}
ir, err := r.ToInternal()
if err != nil {
return [][]string{{fmt.Sprintf("%+v", err)}}
}

data := make([][]string, len(r.internalRelations))
for i, rel := range r.internalRelations {
data[i] = []string{rel.Namespace, rel.Object, rel.Relation, cmdx.None}
if rel.Subject != nil {
data[i][3] = rel.Subject.String()
}
data := make([][]string, len(ir))
for i, rel := range ir {
data[i] = []string{rel.Namespace, rel.Object, rel.Relation, rel.Subject.String()}
}

return data
}

func (r *RelationCollection) ToInternal() []*InternalRelationTuple {
func (r *RelationCollection) ToInternal() ([]*InternalRelationTuple, error) {
if r.internalRelations == nil {
r.internalRelations = make([]*InternalRelationTuple, len(r.protoRelations))
for i, rel := range r.protoRelations {
r.internalRelations[i] = (&InternalRelationTuple{}).FromDataProvider(rel)
ir, err := (&InternalRelationTuple{}).FromDataProvider(rel)
if err != nil {
return nil, err
}
r.internalRelations[i] = ir
}
}
return r.internalRelations
return r.internalRelations, nil
}

func (r *RelationCollection) Interface() interface{} {
return r.ToInternal()
ir, err := r.ToInternal()
if err != nil {
return err
}
return ir
}

func (r *RelationCollection) MarshalJSON() ([]byte, error) {
return json.Marshal(r.ToInternal())
ir, err := r.ToInternal()
if err != nil {
return nil, err
}
return json.Marshal(ir)
}

func (r *RelationCollection) UnmarshalJSON(raw []byte) error {
return json.Unmarshal(raw, &r.internalRelations)
}

func (r *RelationCollection) Len() int {
return len(r.ToInternal())
if ir := len(r.internalRelations); ir > 0 {
return ir
}
return len(r.protoRelations)
}

func (r *RelationCollection) IDs() []string {
rts := r.ToInternal()
rts, err := r.ToInternal()
if err != nil {
// fmt.Sprintf to include the stacktrace
return []string{fmt.Sprintf("%+v", err)}
}
ids := make([]string, len(rts))
for i, rt := range rts {
ids[i] = rt.String()
Expand Down
Loading

0 comments on commit 7a0fcfc

Please sign in to comment.