diff --git a/internal/e2e/transaction_cases_test.go b/internal/e2e/transaction_cases_test.go index 4e4e7e5b0..01d96ae51 100644 --- a/internal/e2e/transaction_cases_test.go +++ b/internal/e2e/transaction_cases_test.go @@ -58,6 +58,51 @@ func runTransactionCases(c transactClient, m *namespaceTestManager) func(*testin assert.Len(t, resp.RelationTuples, 0) }) + t.Run("case=large inserts and deletes", func(t *testing.T) { + ns := []*namespace.Namespace{ + {Name: t.Name() + "1"}, + {Name: t.Name() + "2"}, + } + m.add(t, ns...) + + var tuples []*ketoapi.RelationTuple + for range 5001 { + tuples = append(tuples, &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o", + Relation: "rel", + SubjectSet: &ketoapi.SubjectSet{ + Namespace: ns[1].Name, + Object: "o", + Relation: "rel", + }, + }, + &ketoapi.RelationTuple{ + Namespace: ns[0].Name, + Object: "o", + Relation: "rel", + SubjectID: pointerx.Ptr("sid"), + }, + ) + } + + c.transactTuples(t, tuples, nil) + + resp := c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }) + for i := range tuples { + assert.Contains(t, resp.RelationTuples, tuples[i]) + } + + c.transactTuples(t, nil, tuples) + + resp = c.queryTuple(t, &ketoapi.RelationQuery{ + Namespace: &ns[0].Name, + }) + assert.Len(t, resp.RelationTuples, 0) + }) + t.Run("case=expand-api-display-access docs code sample", func(t *testing.T) { files := &namespace.Namespace{Name: t.Name() + "files"} directories := &namespace.Namespace{Name: t.Name() + "directories"} diff --git a/internal/persistence/sql/query_test.go b/internal/persistence/sql/query_test.go index 0fedea73f..c418def24 100644 --- a/internal/persistence/sql/query_test.go +++ b/internal/persistence/sql/query_test.go @@ -107,3 +107,23 @@ func TestBuildInsert(t *testing.T) { now, }, args) } + +func TestBuildInsertUUIDs(t *testing.T) { + t.Parallel() + + nid := uuidx.NewV4() + foo, bar, baz := uuidx.NewV4(), uuidx.NewV4(), uuidx.NewV4() + uuids := []UUIDMapping{ + {foo, "foo"}, + {bar, "bar"}, + {baz, "baz"}, + } + + q, args := buildInsertUUIDs(nid, uuids, "mysql") + assert.Equal(t, "INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?)", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) + + q, args = buildInsertUUIDs(nid, uuids, "anything else") + assert.Equal(t, "INSERT INTO keto_uuid_mappings (id, string_representation) VALUES (?,?),(?,?),(?,?) ON CONFLICT (id) DO NOTHING", q) + assert.Equal(t, []any{foo, "foo", bar, "bar", baz, "baz"}, args) +} diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index efe0f4a25..a3e4b8e9b 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -177,7 +177,7 @@ func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtup } return p.Transaction(ctx, func(ctx context.Context) error { - for chunk := range slices.Chunk(rs, 500) { + for chunk := range slices.Chunk(rs, 250) { q, args, err := buildDelete(p.NetworkID(ctx), chunk) if err != nil { return err @@ -314,7 +314,7 @@ func (p *Persister) WriteRelationTuples(ctx context.Context, rs ...*relationtupl commitTime := time.Now() return p.Transaction(ctx, func(ctx context.Context) error { - for chunk := range slices.Chunk(rs, 500) { + for chunk := range slices.Chunk(rs, 250) { q, args, err := buildInsert(commitTime, p.NetworkID(ctx), chunk) if err != nil { return err diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 9b48b1259..22844716d 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -7,6 +7,7 @@ import ( "context" "iter" "maps" + "slices" "strings" "github.com/gofrs/uuid" @@ -32,47 +33,6 @@ func (UUIDMapping) TableName() string { return "keto_uuid_mappings" } -func (p *Persister) batchToUUIDs(ctx context.Context, values []string, readOnly bool) (uuids []uuid.UUID, err error) { - if len(values) == 0 { - return - } - - uuids = make([]uuid.UUID, len(values)) - placeholderArray := make([]string, len(values)) - args := make([]interface{}, 0, len(values)*2) - for i, val := range values { - uuids[i] = uuid.NewV5(p.NetworkID(ctx), val) - placeholderArray[i] = "(?, ?)" - args = append(args, uuids[i], val) - } - placeholders := strings.Join(placeholderArray, ", ") - - p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") - - if !readOnly { - // We need to write manual SQL here because the INSERT should not fail if - // the UUID already exists, but we still want to return an error if anything - // else goes wrong. - var query string - switch d := p.Connection(ctx).Dialect.Name(); d { - case "mysql": - query = ` - INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ` + placeholders - default: - query = ` - INSERT INTO keto_uuid_mappings (id, string_representation) - VALUES ` + placeholders + ` - ON CONFLICT (id) DO NOTHING` - } - - return uuids, sqlcon.HandleError( - p.Connection(ctx).RawQuery(query, args...).Exec(), - ) - } else { - return uuids, nil - } -} - func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { if len(ids) == 0 { return @@ -128,18 +88,52 @@ func (p *Persister) batchFromUUIDs(ctx context.Context, ids []uuid.UUID, opts .. return } -func (p *Persister) MapStringsToUUIDs(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { +func (p *Persister) MapStringsToUUIDs(ctx context.Context, values ...string) (uuids []uuid.UUID, err error) { ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDs") defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, false) + if len(values) == 0 { + return + } + + uuids, err = p.MapStringsToUUIDsReadOnly(ctx, values...) + if err != nil { + return nil, err + } + + p.d.Logger().WithField("values", values).WithField("UUIDs", uuids).Trace("adding UUID mappings") + + mappings := make([]UUIDMapping, len(values)) + for i := range len(values) { + mappings[i] = UUIDMapping{ + ID: uuids[i], + StringRepresentation: values[i], + } + } + + err = p.Transaction(ctx, func(ctx context.Context) error { + for chunk := range slices.Chunk(mappings, 1000) { + query, args := buildInsertUUIDs(p.NetworkID(ctx), chunk, p.conn.Dialect.Name()) + if err := p.Connection(ctx).RawQuery(query, args...).Exec(); err != nil { + return sqlcon.HandleError(err) + } + } + return nil + }) + + return uuids, err } -func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, s ...string) (_ []uuid.UUID, err error) { - ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") - defer otelx.End(span, &err) +func (p *Persister) MapStringsToUUIDsReadOnly(ctx context.Context, ss ...string) (uuids []uuid.UUID, err error) { + // This function doesn't talk to the database or do anything interesting, so we don't need to trace it. + // ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MapStringsToUUIDsReadOnly") + // defer otelx.End(span, &err) - return p.batchToUUIDs(ctx, s, true) + uuids = make([]uuid.UUID, len(ss)) + for i := range ss { + uuids[i] = uuid.NewV5(p.NetworkID(ctx), ss[i]) + } + return uuids, nil } func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ []string, err error) { @@ -148,3 +142,39 @@ func (p *Persister) MapUUIDsToStrings(ctx context.Context, u ...uuid.UUID) (_ [] return p.batchFromUUIDs(ctx, u) } + +func buildInsertUUIDs(nid uuid.UUID, values []UUIDMapping, dialect string) (query string, args []any) { + if len(values) == 0 { + return "", nil + } + + const placeholder = "(?,?)" + const separator = "," + + var q strings.Builder + args = make([]any, 0, len(values)*2) + + if dialect == "mysql" { + q.WriteString("INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) VALUES ") + } else { + q.WriteString("INSERT INTO keto_uuid_mappings (id, string_representation) VALUES ") + } + + q.Grow(len(values)*(len(placeholder)+len(separator)) + 100) + + for i, val := range values { + if i > 0 { + q.WriteString(separator) + } + q.WriteString(placeholder) + args = append(args, val.ID, val.StringRepresentation) + } + + if dialect == "mysql" { + // nothing + } else { + q.WriteString(" ON CONFLICT (id) DO NOTHING") + } + + return q.String(), args +}