diff --git a/internal/e2e/transaction_cases_test.go b/internal/e2e/transaction_cases_test.go index 4e4e7e5b0..3d5b22c64 100644 --- a/internal/e2e/transaction_cases_test.go +++ b/internal/e2e/transaction_cases_test.go @@ -58,6 +58,51 @@ func runTransactionCases(c transactClient, m *namespaceTestManager) func(*testin assert.Len(t, resp.RelationTuples, 0) }) + t.Run("case=large inserts and deletes", func(t *testing.T) { + ns := []*namespace.Namespace{ + {Name: t.Name() + "1"}, + {Name: t.Name() + "2"}, + } + m.add(t, ns...) + + var tuples []*ketoapi.RelationTuple + for range 12001 { + tuples = append(tuples, &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o", + Relation: "rel", + SubjectSet: &ketoapi.SubjectSet{ + Namespace: ns[1].Name, + Object: "o", + Relation: "rel", + }, + }, + &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o", + Relation: "rel", + SubjectID: pointerx.Ptr("sid"), + }, + ) + } + + c.transactTuples(t, tuples, nil) + + resp := c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }) + for i := range tuples { + assert.Contains(t, resp.RelationTuples, tuples[i]) + } + + c.transactTuples(t, nil, tuples) + + resp = c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }) + assert.Len(t, resp.RelationTuples, 0) + }) + t.Run("case=expand-api-display-access docs code sample", func(t *testing.T) { files := &namespace.Namespace{Name: t.Name() + "files"} directories := &namespace.Namespace{Name: t.Name() + "directories"} diff --git a/internal/persistence/sql/persister.go b/internal/persistence/sql/persister.go index 7fd5d94f6..b7fc83098 100644 --- a/internal/persistence/sql/persister.go +++ b/internal/persistence/sql/persister.go @@ -6,11 +6,9 @@ package sql import ( "context" "embed" - "reflect" "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/ory/x/otelx" "github.com/ory/x/popx" "github.com/pkg/errors" @@ -70,24 +68,6 @@ func (p *Persister) Connection(ctx context.Context) *pop.Connection { return popx.GetConnection(ctx, p.conn.WithContext(ctx)) } -func (p *Persister) createWithNetwork(ctx context.Context, v interface{}) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createWithNetwork") - defer otelx.End(span, &err) - - rv := reflect.ValueOf(v) - - if rv.Kind() != reflect.Ptr && rv.Elem().Kind() != reflect.Struct { - panic("expected to get *struct in create") - } - nID := rv.Elem().FieldByName("NetworkID") - if !nID.IsValid() || !nID.CanSet() { - panic("expected struct to have a 'NetworkID uuid.UUID' field") - } - nID.Set(reflect.ValueOf(p.NetworkID(ctx))) - - return p.Connection(ctx).Create(v) -} - func (p *Persister) queryWithNetwork(ctx context.Context) *pop.Query { return p.Connection(ctx).Where("nid = ?", p.NetworkID(ctx)) } diff --git a/internal/persistence/sql/query_test.go b/internal/persistence/sql/query_test.go new file mode 100644 index 000000000..c418def24 --- /dev/null +++ b/internal/persistence/sql/query_test.go @@ -0,0 +1,129 @@ +package sql + +import ( + "database/sql" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/ory/x/uuidx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/relationtuple" +) + +func TestBuildDelete(t *testing.T) { + t.Parallel() + nid := uuidx.NewV4() + + q, args, err := buildDelete(nid, nil) + assert.Error(t, err) + assert.Empty(t, q) + assert.Empty(t, args) + + obj1, obj2, sub1, obj3 := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + + q, args, err = buildDelete(nid, []*relationtuple.RelationTuple{ + { + Namespace: "ns1", + Object: obj1, + Relation: "rel1", + Subject: &relationtuple.SubjectID{ + ID: sub1, + }, + }, + { + Namespace: "ns2", + Object: obj2, + Relation: "rel2", + Subject: &relationtuple.SubjectSet{ + Namespace: "ns3", + Object: obj3, + Relation: "rel3", + }, + }, + }) + require.NoError(t, err) + + // parentheses are important here + assert.Equal(t, q, "DELETE FROM keto_relation_tuples WHERE ((namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL) OR (namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)) AND nid = ?") + assert.Equal(t, []any{"ns1", obj1, "rel1", sub1, "ns2", obj2, "rel2", "ns3", obj3, "rel3", nid}, args) +} + +func TestBuildInsert(t *testing.T) { + t.Parallel() + nid := uuidx.NewV4() + + q, args, err := buildInsert(time.Now(), nid, nil) + assert.Error(t, err) + assert.Empty(t, q) + assert.Empty(t, args) + + obj1, obj2, sub1, obj3 := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + + now := time.Now() + + q, args, err = buildInsert(now, nid, []*relationtuple.RelationTuple{ + { + Namespace: "ns1", + Object: obj1, + Relation: "rel1", + Subject: &relationtuple.SubjectID{ + ID: sub1, + }, + }, + { + Namespace: "ns2", + Object: obj2, + Relation: "rel2", + Subject: &relationtuple.SubjectSet{ + Namespace: "ns3", + Object: obj3, + Relation: "rel3", + }, + }, + }) + require.NoError(t, err) + + assert.Equal(t, q, "INSERT INTO keto_relation_tuples (shard_id, nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)") + assert.Equal(t, []any{ + args[0], // this is kind of cheating but we generate the shard id in the buildInsert function + nid, + "ns1", + obj1, + "rel1", + uuid.NullUUID{sub1, true}, + sql.NullString{}, uuid.NullUUID{}, sql.NullString{}, + now, + + args[10], // again, cheating + nid, + "ns2", + obj2, + "rel2", + uuid.NullUUID{}, + sql.NullString{"ns3", true}, uuid.NullUUID{obj3, true}, sql.NullString{"rel3", true}, + now, + }, args) +} + +func TestBuildInsertUUIDs(t *testing.T) { + t.Parallel() + + nid := uuidx.NewV4() + foo, bar, baz := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + uuids := []UUIDMapping{ + {foo, "foo"}, + {bar, "bar"}, + {baz, "baz"}, + } + + q, args := buildInsertUUIDs(nid, uuids, "mysql") + assert.Equal(t, "INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?)", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) + + q, args = buildInsertUUIDs(nid, uuids, "anything else") + assert.Equal(t, "INSERT INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?) ON CONFLICT (id) DO NOTHING", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) +} diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 9d82bee31..a7109128b 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -6,18 +6,29 @@ package sql import ( "context" "database/sql" + "fmt" + "slices" + "strings" "time" - "github.com/ory/keto/ketoapi" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" "github.com/pkg/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" + "github.com/ory/keto/ketoapi" +) + +// Typical database limits for placeholders/bind vars are 1<<15 (32k, MySQL, SQLite) and 1<<16 (64k, PostgreSQL, CockroachDB). +const ( + chunkSizeInsertUUIDMappings = 15000 // two placeholders per mapping + chunkSizeInsertTuple = 3000 // ten placeholders per tuple + chunkSizeDeleteTuple = 100 // the database must build an expression tree for each chunk, so we must limit more agressively ) type ( @@ -71,7 +82,7 @@ func (r *RelationTuple) ToInternal() (*relationtuple.RelationTuple, error) { return rt, nil } -func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject) error { +func (r *RelationTuple) insertSubject(s relationtuple.Subject) error { switch st := s.(type) { case *relationtuple.SubjectID: r.SubjectID = uuid.NullUUID{ @@ -90,39 +101,12 @@ func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject return nil } -func (r *RelationTuple) FromInternal(ctx context.Context, p *Persister, rt *relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FromInternal") - defer otelx.End(span, &err) - +func (r *RelationTuple) FromInternal(rt *relationtuple.RelationTuple) (err error) { r.Namespace = rt.Namespace r.Object = rt.Object r.Relation = rt.Relation - return r.insertSubject(ctx, rt.Subject) -} - -func (p *Persister) InsertRelationTuple(ctx context.Context, rel *relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InsertRelationTuple") - defer otelx.End(span, &err) - - if rel.Subject == nil { - return errors.WithStack(ketoapi.ErrNilSubject) - } - - rt := &RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - CommitTime: time.Now(), - } - if err := rt.FromInternal(ctx, p, rel); err != nil { - return err - } - - if err := sqlcon.HandleError( - p.createWithNetwork(ctx, rt), - ); err != nil { - return err - } - return nil + return r.insertSubject(rt.Subject) } func (p *Persister) whereSubject(_ context.Context, q *pop.Query, sub relationtuple.Subject) error { @@ -165,25 +149,53 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu return nil } +func buildDelete(nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) { + if len(rs) == 0 { + return "", nil, errors.WithStack(ketoapi.ErrMalformedInput) + } + + args = make([]any, 0, 6*len(rs)+1) + ors := make([]string, 0, len(rs)) + for _, rt := range rs { + switch s := rt.Subject.(type) { + case *relationtuple.SubjectID: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id = ? AND subject_set_namespace IS NULL AND subject_set_object IS NULL AND subject_set_relation IS NULL)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.ID) + case *relationtuple.SubjectSet: + ors = append(ors, "(namespace = ? AND object = ? AND relation = ? AND subject_id IS NULL AND subject_set_namespace = ? AND subject_set_object = ? AND subject_set_relation = ?)") + args = append(args, rt.Namespace, rt.Object, rt.Relation, s.Namespace, s.Object, s.Relation) + case nil: + return "", nil, errors.WithStack(ketoapi.ErrNilSubject) + } + } + + query = fmt.Sprintf("DELETE FROM %s WHERE (%s) AND nid = ?", (&RelationTuple{}).TableName(), strings.Join(ors, " OR ")) + args = append(args, nid) + return query, args, nil +} + func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples") + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteRelationTuples", + trace.WithAttributes(attribute.Int("count", len(rs)))) defer otelx.End(span, &err) + if len(rs) == 0 { + return nil + } + return p.Transaction(ctx, func(ctx context.Context) error { - for _, r := range rs { - q := p.queryWithNetwork(ctx). - Where("namespace = ?", r.Namespace). - Where("object = ?", r.Object). - Where("relation = ?", r.Relation) - if err := p.whereSubject(ctx, q, r.Subject); err != nil { + for chunk := range slices.Chunk(rs, chunkSizeDeleteTuple) { + q, args, err := buildDelete(p.NetworkID(ctx), chunk) + if err != nil { return err } - - if err := q.Delete(&RelationTuple{}); err != nil { - return err + if q == "" { + continue + } + if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil { + return sqlcon.HandleError(err) } } - return nil }) } @@ -260,15 +272,63 @@ func (p *Persister) ExistsRelationTuples(ctx context.Context, query *relationtup return exists, sqlcon.HandleError(err) } +func buildInsert(commitTime time.Time, nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) { + if len(rs) == 0 { + return "", nil, errors.WithStack(ketoapi.ErrMalformedInput) + } + + var q strings.Builder + fmt.Fprintf(&q, "INSERT INTO %s (shard_id, nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time) VALUES ", (&RelationTuple{}).TableName()) + const placeholders = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + const separator = ", " + q.Grow(len(rs) * (len(placeholders) + len(separator))) + args = make([]any, 0, 10*len(rs)) + + for i, r := range rs { + if r.Subject == nil { + return "", nil, errors.WithStack(ketoapi.ErrNilSubject) + } + + rt := &RelationTuple{ + ID: uuid.Must(uuid.NewV4()), + NetworkID: nid, + CommitTime: commitTime, + } + if err := rt.FromInternal(r); err != nil { + return "", nil, err + } + + if i > 0 { + q.WriteString(separator) + } + q.WriteString(placeholders) + args = append(args, rt.ID, rt.NetworkID, rt.Namespace, rt.Object, rt.Relation, rt.SubjectID, rt.SubjectSetNamespace, rt.SubjectSetObject, rt.SubjectSetRelation, rt.CommitTime) + } + + query = q.String() + return query, args, nil +} + func (p *Persister) WriteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples") + ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples", + trace.WithAttributes(attribute.Int("count", len(rs)))) defer otelx.End(span, &err) + if len(rs) == 0 { + return nil + } + + commitTime := time.Now() + return p.Transaction(ctx, func(ctx context.Context) error { - for _, r := range rs { - if err := p.InsertRelationTuple(ctx, r); err != nil { + for chunk := range slices.Chunk(rs, chunkSizeInsertTuple) { + q, args, err := buildInsert(commitTime, p.NetworkID(ctx), chunk) + if err != nil { return err } + if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil { + return sqlcon.HandleError(err) + } } return nil }) @@ -278,6 +338,10 @@ func (p *Persister) TransactRelationTuples(ctx context.Context, ins []*relationt ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.TransactRelationTuples") defer otelx.End(span, &err) + if len(ins)+len(del) == 0 { + return nil + } + return p.Transaction(ctx, func(ctx context.Context) error { if err := p.WriteRelationTuples(ctx, ins...); err != nil { return err diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 9b48b1259..07ba33fae 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -7,6 +7,7 @@ import ( "context" "iter" "maps" + "slices" "strings" "github.com/gofrs/uuid" @@ -32,47 +33,6 @@ func (UUIDMapping) TableName() string { return "keto_uuid_mappings" } -func (p *Persister) batchToUUIDs(ctx context.Context, values []string, readOnly bool) (uuids []uuid.UUID, err error) { - if len(values) == 0 { - return - } - - uuids = make([]uuid.UUID, len(values)) - placeholderArray := make([]string, len(values)) - args := make([]interface{}, 0, len(values)*2) - for i, val := range values { - uuids[i] = uuid.NewV5(p.NetworkID(ctx), val) - placeholderArray[i] = "(?, ?)" - args = append(args, uuids[i], val) - } - placeholders := strings.Join(placeholderArray, ", ") - - p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") - - if !readOnly { - // We need to write manual SQL here because the INSERT should not fail if - // the UUID already exists, but we still want to return an error if anything - // else goes wrong. - var query string - switch d := p.Connection(ctx).Dialect.Name(); d { - case "mysql": - query = ` - INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ` + placeholders - default: - query = ` - INSERT INTO keto_uuid_mappings (id, string_representation) - VALUES ` + placeholders + ` - ON CONFLICT (id) DO NOTHING` - } - - return uuids, sqlcon.HandleError( - p.Connection(ctx).RawQuery(query, args...).Exec(), - ) - } else { - return uuids, nil - } -} - func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { if len(ids) == 0 { return @@ -128,18 +88,52 @@ func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts .. return } -func (p *Persister) MapStringsToUUIDs(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { +func (p *Persister) MapStringsToUUIDs(ctx context.Context, values ...string) (uuids []uuid.UUID, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDs") defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, false) + if len(values) == 0 { + return + } + + uuids, err = p.MapStringsToUUIDsReadOnly(ctx, values...) + if err != nil { + return nil, err + } + + p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") + + mappings := make([]UUIDMapping, len(values)) + for i := range len(values) { + mappings[i] = UUIDMapping{ + ID: uuids[i], + StringRepresentation: values[i], + } + } + + err = p.Transaction(ctx, func(ctx context.Context) error { + for chunk := range slices.Chunk(mappings, chunkSizeInsertUUIDMappings) { + query, args := buildInsertUUIDs(p.NetworkID(ctx), chunk, p.conn.Dialect.Name()) + if err := p.Connection(ctx).RawQuery(query, args...).Exec(); err != nil { + return sqlcon.HandleError(err) + } + } + return nil + }) + + return uuids, err } -func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") - defer otelx.End(span, &err) +func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, ss ...string) (uuids []uuid.UUID, err error) { + // This function doesn't talk to the database or do anything interesting, so we don't need to trace it. + // ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") + // defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, true) + uuids = make([]uuid.UUID, len(ss)) + for i := range ss { + uuids[i] = uuid.NewV5(p.NetworkID(ctx), ss[i]) + } + return uuids, nil } func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ []string, err error) { @@ -148,3 +142,39 @@ func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ [] return p.batchFromUUIDs(ctx, u) } + +func buildInsertUUIDs(nid uuid.UUID, values []UUIDMapping, dialect string) (query string, args []any) { + if len(values) == 0 { + return "", nil + } + + const placeholder = "(?,?)" + const separator = "," + + var q strings.Builder + args = make([]any, 0, len(values)*2) + + if dialect == "mysql" { + q.WriteString("INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ") + } else { + q.WriteString("INSERT INTO keto_uuid_mappings (id, string_representation) VALUES ") + } + + q.Grow(len(values)*(len(placeholder)+len(separator)) + 100) + + for i, val := range values { + if i > 0 { + q.WriteString(separator) + } + q.WriteString(placeholder) + args = append(args, val.ID, val.StringRepresentation) + } + + if dialect == "mysql" { + // nothing + } else { + q.WriteString(" ON CONFLICT (id) DO NOTHING") + } + + return q.String(), args +}