From fe507bb4ea719780e732d098291aa190d6b1c441 Mon Sep 17 00:00:00 2001 From: Patrik Date: Thu, 11 Feb 2021 10:18:37 +0100 Subject: [PATCH] fix: insert relation tuples without fmt.Sprintf (#443) --- internal/persistence/sql/persister.go | 6 +-- internal/persistence/sql/relationtuples.go | 60 +++++++++++++--------- 2 files changed, 40 insertions(+), 26 deletions(-) diff --git a/internal/persistence/sql/persister.go b/internal/persistence/sql/persister.go index 54d860687..12f6aa2aa 100644 --- a/internal/persistence/sql/persister.go +++ b/internal/persistence/sql/persister.go @@ -110,15 +110,15 @@ func (p *Persister) newConnection(options map[string]string) (c *pop.Connection, } func (p *Persister) MigrateUp(_ context.Context) error { - return p.mb.Up() + return errors.WithStack(p.mb.Up()) } func (p *Persister) MigrateDown(_ context.Context, steps int) error { - return p.mb.Down(steps) + return errors.WithStack(p.mb.Down(steps)) } func (p *Persister) MigrationStatus(_ context.Context, w io.Writer) error { - return p.mb.Status(w) + return errors.WithStack(p.mb.Status(w)) } func (p *Persister) connection(ctx context.Context) *pop.Connection { diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 2d8d69ecc..38b5c06c0 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -2,9 +2,10 @@ package sql import ( "context" - "fmt" "time" + "github.com/ory/x/sqlcon" + "github.com/gobuffalo/pop/v5" "github.com/pkg/errors" @@ -17,7 +18,8 @@ import ( type ( relationTuple struct { - ShardID string `db:"shard_id"` + // An ID field is required to make pop happy. The actual ID is a composite primary key. + ID string `db:"shard_id"` Object string `db:"object"` Relation string `db:"relation"` Subject string `db:"subject"` @@ -35,7 +37,7 @@ const ( namespaceContextKey contextKeys = "namespace" ) -func (relationTuples) TableName(ctx context.Context) string { +func namespaceTableFromContext(ctx context.Context) string { n, ok := ctx.Value(namespaceContextKey).(*namespace.Namespace) if n == nil || !ok { panic("namespace context key not set") @@ -43,23 +45,12 @@ func (relationTuples) TableName(ctx context.Context) string { return tableFromNamespace(n) } -func (p *Persister) insertRelationTuple(ctx context.Context, rel *relationtuple.InternalRelationTuple) error { - commitTime := time.Now() - - n, err := p.namespaces.GetNamespace(ctx, rel.Namespace) - if err != nil { - return err - } - - // TODO sharding - shardID := "default" - - p.l.WithFields(rel.ToLoggerFields()).Trace("creating in database") +func (relationTuples) TableName(ctx context.Context) string { + return namespaceTableFromContext(ctx) +} - return p.connection(ctx).RawQuery(fmt.Sprintf( - "INSERT INTO %s (shard_id, object, relation, subject, commit_time) VALUES (?, ?, ?, ?, ?)", tableFromNamespace(n)), - shardID, rel.Object, rel.Relation, rel.Subject.String(), commitTime, - ).Exec() +func (relationTuple) TableName(ctx context.Context) string { + return namespaceTableFromContext(ctx) } func (r *relationTuple) toInternal() (*relationtuple.InternalRelationTuple, error) { @@ -68,12 +59,38 @@ func (r *relationTuple) toInternal() (*relationtuple.InternalRelationTuple, erro } sub, err := relationtuple.SubjectFromString(r.Subject) + if err != nil { + return nil, err + } + return &relationtuple.InternalRelationTuple{ Relation: r.Relation, Object: r.Object, Namespace: r.Namespace.Name, Subject: sub, - }, err + }, nil +} + +func (p *Persister) insertRelationTuple(ctx context.Context, rel *relationtuple.InternalRelationTuple) error { + n, err := p.namespaces.GetNamespace(ctx, rel.Namespace) + if err != nil { + return err + } + + // TODO sharding + shardID := "default" + + p.l.WithFields(rel.ToLoggerFields()).Trace("creating in database") + + return sqlcon.HandleError( + p.connection(context.WithValue(ctx, namespaceContextKey, n)).Create(&relationTuple{ + ID: shardID, + Object: rel.Object, + Relation: rel.Relation, + Subject: rel.Subject.String(), + CommitTime: time.Now(), + }), + ) } func (p *Persister) GetRelationTuples(ctx context.Context, query *relationtuple.RelationQuery, options ...x.PaginationOptionSetter) ([]*relationtuple.InternalRelationTuple, string, error) { @@ -83,15 +100,12 @@ func (p *Persister) GetRelationTuples(ctx context.Context, query *relationtuple. } var wheres []whereStmts - if query.Relation != "" { wheres = append(wheres, whereStmts{stmt: "relation = ?", arg: query.Relation}) } - if query.Object != "" { wheres = append(wheres, whereStmts{stmt: "object = ?", arg: query.Object}) } - if query.Subject != nil { wheres = append(wheres, whereStmts{stmt: "subject = ?", arg: query.Subject.String()}) }