From 44f442fff9716cc22fc803ce26e8a09d6d61fa1d Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Tue, 3 May 2022 13:09:59 +0200 Subject: [PATCH] feat: map subject/object strings to UUIDs (#792) In relation tuples and related API objects such as queries, subjects (including subject sets) and objects are now automatically mapped to and from UUIDs. The mapping is done in each handler and for each protocol (HTTP, gRPC) separately. Below the handler, all objects and subjects are now UUIDs, but still handled as strings, for the following reasons: * No duplication of datastructures (one for type `string`, one for type `uuid.UUID`). * No unnecessary copying, the mapping is done in-place and batched across multiple objects. --- cmd/expand/root_test.go | 9 +- internal/check/handler.go | 17 +-- internal/check/handler_test.go | 4 +- internal/driver/config/namespace_memory.go | 2 +- internal/driver/registry.go | 15 +-- internal/driver/registry_default.go | 3 +- internal/e2e/cases_test.go | 9 +- internal/expand/handler.go | 29 ++--- internal/expand/handler_test.go | 48 ++++++++- internal/expand/tree.go | 15 ++- internal/persistence/definitions.go | 3 +- internal/persistence/sql/full_test.go | 5 +- ...id_mapping.go => uuid_mapping_migrator.go} | 35 ++++-- ..._test.go => uuid_mapping_migrator_test.go} | 37 ++++--- internal/persistence/sql/relationtuples.go | 4 +- internal/persistence/sql/uuid_mapping.go | 98 +++++++++++++++-- internal/relationtuple/definitions.go | 41 +++++++ internal/relationtuple/read_server.go | 10 +- internal/relationtuple/read_server_test.go | 6 +- internal/relationtuple/swagger_definitions.go | 19 +++- internal/relationtuple/test_helper.go | 18 ++++ internal/relationtuple/transact_server.go | 21 +++- .../relationtuple/transact_server_test.go | 20 +++- internal/relationtuple/uuid_mapping.go | 102 ++++++++++++++++++ internal/uuidmapping/definitions.go | 48 --------- 25 files changed, 465 insertions(+), 153 deletions(-) rename internal/persistence/sql/migrations/{uuid_mapping.go => uuid_mapping_migrator.go} (78%) rename internal/persistence/sql/migrations/{uuid_mapping_test.go => uuid_mapping_migrator_test.go} (80%) create mode 100644 internal/relationtuple/test_helper.go create mode 100644 internal/relationtuple/uuid_mapping.go delete mode 100644 internal/uuidmapping/definitions.go diff --git a/cmd/expand/root_test.go b/cmd/expand/root_test.go index 4e87ab6a3..0be0a22bc 100644 --- a/cmd/expand/root_test.go +++ b/cmd/expand/root_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/ory/x/cmdx" - "github.com/stretchr/testify/assert" "github.com/ory/keto/cmd/client" @@ -18,12 +17,16 @@ func TestExpandCommand(t *testing.T) { t.Run("case=unknown tuple", func(t *testing.T) { t.Run("format=JSON", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) assert.Equal(t, "null\n", stdOut) }) t.Run("format=default", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) assert.Contains(t, stdOut, "empty tree") }) }) diff --git a/internal/check/handler.go b/internal/check/handler.go index 66d509cd1..8dcc7707e 100644 --- a/internal/check/handler.go +++ b/internal/check/handler.go @@ -5,18 +5,14 @@ import ( "encoding/json" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" "github.com/pkg/errors" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( @@ -30,7 +26,10 @@ type ( } ) -var _ rts.CheckServiceServer = (*Handler)(nil) +var ( + _ rts.CheckServiceServer = (*Handler)(nil) + _ *getCheckRequest = nil +) func NewHandler(d handlerDependencies) *Handler { return &Handler{d: d} @@ -64,7 +63,6 @@ type RESTResponse struct { } // swagger:parameters getCheck postCheck -// nolint:deadcode,unused type getCheckRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -105,6 +103,7 @@ func (h *Handler) getCheck(w http.ResponseWriter, r *http.Request, _ httprouter. return } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -151,6 +150,7 @@ func (h *Handler) postCheck(w http.ResponseWriter, r *http.Request, _ httprouter return } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), &tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -171,6 +171,7 @@ func (h *Handler) Check(ctx context.Context, req *rts.CheckRequest) (*rts.CheckR return nil, err } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(ctx, tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(ctx, tuple, int(req.MaxDepth)) // TODO add content change handling if err != nil { diff --git a/internal/check/handler_test.go b/internal/check/handler_test.go index 5f81bad24..19548dbf9 100644 --- a/internal/check/handler_test.go +++ b/internal/check/handler_test.go @@ -23,6 +23,8 @@ import ( ) func assertAllowed(t *testing.T, resp *http.Response) { + t.Helper() + body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -97,7 +99,7 @@ func TestRESTHandler(t *testing.T) { Relation: "r", Subject: &relationtuple.SubjectID{ID: "s"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) diff --git a/internal/driver/config/namespace_memory.go b/internal/driver/config/namespace_memory.go index 7f451c35f..2bd1e184b 100644 --- a/internal/driver/config/namespace_memory.go +++ b/internal/driver/config/namespace_memory.go @@ -34,7 +34,7 @@ func (s *memoryNamespaceManager) GetNamespaceByName(_ context.Context, name stri } } - return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %s.", name)) + return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %q.", name)) } func (s *memoryNamespaceManager) GetNamespaceByConfigID(_ context.Context, id int32) (*namespace.Namespace, error) { diff --git a/internal/driver/registry.go b/internal/driver/registry.go index 62e8a6231..eb6f914b0 100644 --- a/internal/driver/registry.go +++ b/internal/driver/registry.go @@ -4,21 +4,15 @@ import ( "context" "net/http" - prometheus "github.com/ory/x/prometheusx" - "github.com/gobuffalo/pop/v6" - - "github.com/ory/keto/internal/driver/config" - "github.com/ory/keto/internal/uuidmapping" - - "github.com/spf13/cobra" - - "google.golang.org/grpc" - "github.com/ory/x/healthx" + prometheus "github.com/ory/x/prometheusx" "github.com/ory/x/tracing" + "github.com/spf13/cobra" + "google.golang.org/grpc" "github.com/ory/keto/internal/check" + "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/relationtuple" @@ -34,7 +28,6 @@ type ( x.WriterProvider relationtuple.ManagerProvider - uuidmapping.ManagerProvider expand.EngineProvider check.EngineProvider persistence.Migrator diff --git a/internal/driver/registry_default.go b/internal/driver/registry_default.go index 6e912ac6c..3cc50dbb0 100644 --- a/internal/driver/registry_default.go +++ b/internal/driver/registry_default.go @@ -25,7 +25,6 @@ import ( "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/persistence/sql" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x" "github.com/ory/keto/ketoctx" rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" @@ -152,7 +151,7 @@ func (r *RegistryDefault) RelationTupleManager() relationtuple.Manager { return r.p } -func (r *RegistryDefault) UUIDMappingManager() uuidmapping.Manager { +func (r *RegistryDefault) UUIDMappingManager() relationtuple.UUIDMappingManager { if r.p == nil { panic("no relation tuple manager, but expected to have one") } diff --git a/internal/e2e/cases_test.go b/internal/e2e/cases_test.go index c71d9a684..01e357930 100644 --- a/internal/e2e/cases_test.go +++ b/internal/e2e/cases_test.go @@ -5,17 +5,14 @@ import ( "strconv" "testing" - "github.com/stretchr/testify/require" - "github.com/ory/herodot" - - "github.com/ory/keto/internal/x" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) func(*testing.T) { @@ -188,7 +185,7 @@ func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) Relation: "rel", } resp := c.queryTuple(t, q) - require.Equal(t, resp.RelationTuples, rts) + require.ElementsMatch(t, resp.RelationTuples, rts) c.deleteAllTuples(t, q) resp = c.queryTuple(t, q) diff --git a/internal/expand/handler.go b/internal/expand/handler.go index abf32803c..d0d818a76 100644 --- a/internal/expand/handler.go +++ b/internal/expand/handler.go @@ -4,17 +4,13 @@ import ( "context" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( @@ -28,7 +24,10 @@ type ( } ) -var _ rts.ExpandServiceServer = (*handler)(nil) +var ( + _ rts.ExpandServiceServer = (*handler)(nil) + _ *getExpandRequest = nil +) const RouteBase = "/relation-tuples/expand" @@ -49,7 +48,6 @@ func (h *handler) RegisterReadGRPC(s *grpc.Server) { func (h *handler) RegisterWriteGRPC(s *grpc.Server) {} // swagger:parameters getExpand -// nolint:deadcode,unused type getExpandRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -81,24 +79,31 @@ func (h *handler) getExpand(w http.ResponseWriter, r *http.Request, _ httprouter return } - res, err := h.d.ExpandEngine().BuildTree(r.Context(), (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()), maxDepth) + subject := (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()) + + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), subject) + res, err := h.d.ExpandEngine().BuildTree(r.Context(), subject, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) return } + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), res) h.d.Writer().Write(w, r, res) } func (h *handler) Expand(ctx context.Context, req *rts.ExpandRequest) (*rts.ExpandResponse, error) { - sub, err := relationtuple.SubjectFromProto(req.Subject) + subject, err := relationtuple.SubjectFromProto(req.Subject) if err != nil { return nil, err } - tree, err := h.d.ExpandEngine().BuildTree(ctx, sub, int(req.MaxDepth)) + + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsToUUID(ctx, subject) + res, err := h.d.ExpandEngine().BuildTree(ctx, subject, int(req.MaxDepth)) if err != nil { return nil, err } + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsFromUUID(ctx, res) - return &rts.ExpandResponse{Tree: tree.ToProto()}, nil + return &rts.ExpandResponse{Tree: res.ToProto()}, nil } diff --git a/internal/expand/handler_test.go b/internal/expand/handler_test.go index 95ae81875..6ba553bf3 100644 --- a/internal/expand/handler_test.go +++ b/internal/expand/handler_test.go @@ -79,20 +79,20 @@ func TestRESTHandler(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), []*relationtuple.InternalRelationTuple{ - { + relationtuple.MapAndWriteTuples(t, reg, + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[0].Subject, }, - { + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[1].Subject, }, - }...)) + ) qs := rootSub.ToURLQuery() qs.Set("max-depth", "2") @@ -103,6 +103,44 @@ func TestRESTHandler(t *testing.T) { actualTree := expand.Tree{} require.NoError(t, json.NewDecoder(resp.Body).Decode(&actualTree)) - assert.Equal(t, expectedTree, &actualTree) + assertEqualTrees(t, expectedTree, &actualTree) }) } + +func assertEqualTrees(t *testing.T, expected, actual *expand.Tree) { + t.Helper() + assert.Truef(t, treesAreEqual(t, expected, actual), + "expected:\n%s\n\nactual:\n%s", expected.String(), actual.String()) +} + +func treesAreEqual(t *testing.T, expected, actual *expand.Tree) bool { + if expected == nil || actual == nil { + return expected == actual + } + + if expected.Type != actual.Type { + t.Logf("expected type %q, actual type %q", expected.Type, actual.Type) + return false + } + if expected.Subject.String() != actual.Subject.String() { + t.Logf("expected subject: %q, actual subject: %q", expected.Subject.String(), actual.Subject.String()) + return false + } + if len(expected.Children) != len(actual.Children) { + t.Logf("expected len(children)=%d, actual len(children)=%d", len(expected.Children), len(actual.Children)) + return false + } + + // For children, we check for equality disregarding the order +outer: + for _, expectedChild := range expected.Children { + for _, actualChild := range actual.Children { + if treesAreEqual(t, expectedChild, actualChild) { + continue outer + } + } + t.Logf("expected child:\n%s\n\nactual child:\n%s", expectedChild.String(), actual.String()) + return false + } + return true +} diff --git a/internal/expand/tree.go b/internal/expand/tree.go index 4340ce3e3..3f2c8ec27 100644 --- a/internal/expand/tree.go +++ b/internal/expand/tree.go @@ -217,7 +217,7 @@ func TreeFromProto(t *rts.SubjectTree) (*Tree, error) { func (t *Tree) String() string { if t == nil { - return "" + return "(nil)" } sub := t.Subject.String() @@ -233,3 +233,16 @@ func (t *Tree) String() string { return fmt.Sprintf("∪ %s\n├─ %s", sub, strings.Join(children, "\n├─ ")) } + +func (t *Tree) UUIDMappableFields() (res []*string) { + if t == nil { + return + } + if t.Subject != nil { + res = append(res, t.Subject.UUIDMappableFields()...) + } + for _, c := range t.Children { + res = append(res, c.UUIDMappableFields()...) + } + return +} diff --git a/internal/persistence/definitions.go b/internal/persistence/definitions.go index 46e87d937..0b1d267f3 100644 --- a/internal/persistence/definitions.go +++ b/internal/persistence/definitions.go @@ -9,13 +9,12 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" ) type ( Persister interface { relationtuple.Manager - uuidmapping.Manager + relationtuple.UUIDMappingManager Connection(ctx context.Context) *pop.Connection } diff --git a/internal/persistence/sql/full_test.go b/internal/persistence/sql/full_test.go index 595f9d051..ff772c9d9 100644 --- a/internal/persistence/sql/full_test.go +++ b/internal/persistence/sql/full_test.go @@ -14,7 +14,6 @@ import ( "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/persistence/sql" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x/dbx" ) @@ -69,9 +68,9 @@ func TestPersister(t *testing.T) { relationtuple.IsolationTest(t, p0, p1, addNamespace(r, nspaces)) }) - t.Run("uuidmapping.ManagerTest", func(t *testing.T) { + t.Run("relationtuple.UUIDMappingManagerTest", func(t *testing.T) { p, _, _ := setup(t, dsn) - uuidmapping.ManagerTest(t, p) + relationtuple.UUIDMappingManagerTest(t, p) }) }) } diff --git a/internal/persistence/sql/migrations/uuid_mapping.go b/internal/persistence/sql/migrations/uuid_mapping_migrator.go similarity index 78% rename from internal/persistence/sql/migrations/uuid_mapping.go rename to internal/persistence/sql/migrations/uuid_mapping_migrator.go index 960256ab4..15dc884db 100644 --- a/internal/persistence/sql/migrations/uuid_mapping.go +++ b/internal/persistence/sql/migrations/uuid_mapping_migrator.go @@ -2,9 +2,10 @@ package migrations import ( "context" + dbSql "database/sql" + "errors" "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" "github.com/ory/keto/internal/persistence/sql" @@ -67,9 +68,23 @@ func (m *toUUIDMappingMigrator) MigrateUUIDMappings(ctx context.Context) error { }) } +func (m *toUUIDMappingMigrator) hasMapping(ctx context.Context, id string) (bool, error) { + err := m.d.Persister().Connection(ctx).Find(&sql.UUIDMapping{}, id) + if err == nil { + return true, nil + } + if errors.Is(err, dbSql.ErrNoRows) { + return false, nil + } + return false, err +} + func (m *toUUIDMappingMigrator) migrateSubjectID(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectID.String) - if err == nil || !rt.SubjectID.Valid || rt.SubjectID.String == "" { + skip, err := m.hasMapping(ctx, rt.SubjectID.String) + if err != nil { + return err + } + if skip || !rt.SubjectID.Valid || rt.SubjectID.String == "" { return nil } @@ -78,8 +93,11 @@ func (m *toUUIDMappingMigrator) migrateSubjectID(ctx context.Context, rt *sql.Re } func (m *toUUIDMappingMigrator) migrateSubjectSetObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectSetObject.String) - if err == nil || !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { + skip, err := m.hasMapping(ctx, rt.SubjectSetObject.String) + if err != nil { + return err + } + if skip || !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { return nil } @@ -88,8 +106,11 @@ func (m *toUUIDMappingMigrator) migrateSubjectSetObject(ctx context.Context, rt } func (m *toUUIDMappingMigrator) migrateObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.Object) - if err == nil || rt.Object == "" { + skip, err := m.hasMapping(ctx, rt.Object) + if err != nil { + return err + } + if skip || rt.Object == "" { return nil } diff --git a/internal/persistence/sql/migrations/uuid_mapping_test.go b/internal/persistence/sql/migrations/uuid_mapping_migrator_test.go similarity index 80% rename from internal/persistence/sql/migrations/uuid_mapping_test.go rename to internal/persistence/sql/migrations/uuid_mapping_migrator_test.go index 9756c5563..98208ca1c 100644 --- a/internal/persistence/sql/migrations/uuid_mapping_test.go +++ b/internal/persistence/sql/migrations/uuid_mapping_migrator_test.go @@ -30,9 +30,8 @@ func TestToUUIDMappingMigrator(t *testing.T) { require.NoError(t, err) testCases := []struct { - name string - rt *sql.RelationTuple - expectMapping bool + name string + rt *sql.RelationTuple }{{ name: "with string subject", rt: &sql.RelationTuple{ @@ -42,7 +41,6 @@ func TestToUUIDMappingMigrator(t *testing.T) { SubjectID: dbsql.NullString{String: "subject", Valid: true}, CommitTime: time.Now(), }, - expectMapping: true, }, { name: "with null subject", rt: &sql.RelationTuple{ @@ -55,16 +53,15 @@ func TestToUUIDMappingMigrator(t *testing.T) { Object: "object", CommitTime: time.Now(), }, - expectMapping: true, }, { name: "with UUID subject", rt: &sql.RelationTuple{ ID: uuid.Must(uuid.NewV4()), NetworkID: nw.ID, SubjectID: dbsql.NullString{String: uuid.Must(uuid.NewV4()).String(), Valid: true}, + Object: "object", CommitTime: time.Now(), }, - expectMapping: false, }} for _, tc := range testCases { @@ -75,21 +72,23 @@ func TestToUUIDMappingMigrator(t *testing.T) { newRt := &sql.RelationTuple{} require.NoError(t, conn.Find(newRt, tc.rt.ID)) - if tc.expectMapping { - // Check subject mapping - if tc.rt.SubjectID.Valid { - assertHasMapping(t, conn, tc.rt.SubjectID.String, newRt.SubjectID.String) - } else { - assertHasMapping(t, conn, tc.rt.SubjectSetObject.String, newRt.SubjectSetObject.String) - // check that both "OBJECT" strings are mapped to same UUID. - assert.Equal(t, newRt.Object, newRt.SubjectSetObject.String) - } - assertHasMapping(t, conn, tc.rt.Object, newRt.Object) + // Check subject mapping + if tc.rt.SubjectID.Valid { + assertHasMapping(t, conn, tc.rt.SubjectID.String, newRt.SubjectID.String) } else { - // Nothing should have changed (ignoring commit time) - newRt.CommitTime = tc.rt.CommitTime - assert.Equal(t, tc.rt, newRt) + assertHasMapping(t, conn, tc.rt.SubjectSetObject.String, newRt.SubjectSetObject.String) + // check that both "OBJECT" strings are mapped to same UUID. + assert.Equal(t, newRt.Object, newRt.SubjectSetObject.String) } + assertHasMapping(t, conn, tc.rt.Object, newRt.Object) + + t.Run("idempotency", func(t *testing.T) { + // Check that running the migration again doesn't change anything + require.NoError(t, m.MigrateUUIDMappings(ctx)) + twiceMigratedRt := &sql.RelationTuple{} + require.NoError(t, conn.Find(twiceMigratedRt, tc.rt.ID)) + assert.Equal(t, newRt, twiceMigratedRt) + }) }) } }) diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 98aee6a14..47b94713b 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -198,7 +198,7 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu } func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.InternalRelationTuple) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { for _, r := range rs { n, err := p.GetNamespaceByName(ctx, r.Namespace) if err != nil { @@ -223,7 +223,7 @@ func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtup } func (p *Persister) DeleteAllRelationTuples(ctx context.Context, query *relationtuple.RelationQuery) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { sqlQuery := p.QueryWithNetwork(ctx) err := p.whereQuery(ctx, sqlQuery, query) if err != nil { diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 95ed4b5e5..1f1b5030a 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -2,9 +2,13 @@ package sql import ( "context" + "fmt" "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" + + "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) type ( @@ -48,13 +52,95 @@ func (p *Persister) ToUUID(ctx context.Context, text string) (uuid.UUID, error) ) } -func (p *Persister) FromUUID(ctx context.Context, id uuid.UUID) (rep string, err error) { - p.d.Logger().Trace("looking up UUID") +func (p *Persister) FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { + p.d.Logger().Trace("looking up UUIDs") + + // We need to paginate on the ids, because we want to get the exact chunk of + // string representations for the given ids. + pagination, _ := internalPaginationFromOptions(opts...) + pageSize := pagination.PerPage + + // Build a map from UUID -> indices in the result. + idIdx := make(map[uuid.UUID][]int) + for i, id := range ids { + if ids, ok := idIdx[id]; ok { + idIdx[id] = append(ids, i) + } else { + idIdx[id] = []int{i} + } + } + + res = make([]string, len(ids)) - m := &UUIDMapping{} - if err := sqlcon.HandleError(p.Connection(ctx).Find(m, id)); err != nil { - return "", err + for i := 0; i < len(ids); i += pageSize { + end := i + pageSize + if end > len(ids) { + end = len(ids) + } + idsToLookup := ids[i:end] + mappings := &[]UUIDMapping{} + query := p.Connection(ctx).Where("id in (?)", idsToLookup) + if err := sqlcon.HandleError(query.All(mappings)); err != nil { + return []string{}, err + } + + // Write the representation to the correct index. + for _, m := range *mappings { + for _, idx := range idIdx[m.ID] { + res[idx] = m.StringRepresentation + } + } } - return m.StringRepresentation, nil + return +} + +func (p *Persister) replaceWithUUID(ctx context.Context, s *string) error { + if s == nil { + return nil + } + uuid, err := p.ToUUID(ctx, *s) + if err != nil { + return err + } + *s = uuid.String() + + return nil +} + +func (p *Persister) MapFieldsToUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + for _, s := range m.UUIDMappableFields() { + if err := p.replaceWithUUID(ctx, s); err != nil { + return err + } + } + return nil +} + +func (p *Persister) MapFieldsFromUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + ids := make([]uuid.UUID, len(m.UUIDMappableFields())) + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + id, err := uuid.FromString(*field) + if err != nil { + return err + } + ids[i] = id + } + reps, err := p.FromUUID(ctx, ids) + if err != nil { + return err + } + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + if reps[i] == "" { + return fmt.Errorf("failed to map %s", ids[i]) + } + *field = reps[i] + } + return nil } diff --git a/internal/relationtuple/definitions.go b/internal/relationtuple/definitions.go index 9fc1c63c9..a3db8f77a 100644 --- a/internal/relationtuple/definitions.go +++ b/internal/relationtuple/definitions.go @@ -24,6 +24,7 @@ import ( type ( ManagerProvider interface { RelationTupleManager() Manager + UUIDMappingManager() UUIDMappingManager } Manager interface { GetRelationTuples(ctx context.Context, query *RelationQuery, options ...x.PaginationOptionSetter) ([]*InternalRelationTuple, string, error) @@ -88,6 +89,9 @@ type Subject interface { // swagger:ignore ToProto() *rts.Subject + + // swagger:ignore + UUIDMappable } // swagger:ignore @@ -97,6 +101,18 @@ type InternalRelationTuple struct { Relation string `json:"relation"` Subject Subject `json:"subject"` } +type InternalRelationTuples []*InternalRelationTuple + +func (rt *InternalRelationTuple) UUIDMappableFields() []*string { + return append([]*string{&rt.Object}, rt.Subject.UUIDMappableFields()...) +} + +func (rtt InternalRelationTuples) UUIDMappableFields() (fields []*string) { + for _, rt := range rtt { + fields = append(fields, rt.UUIDMappableFields()...) + } + return fields +} // swagger:parameters getExpand type SubjectSet struct { @@ -141,6 +157,16 @@ func SubjectFromString(s string) (Subject, error) { return (&SubjectID{}).FromString(s) } +// swagger:ignore +func (s *SubjectID) UUIDMappableFields() []*string { + return []*string{&s.ID} +} + +// swagger:ignore +func (s *SubjectSet) UUIDMappableFields() []*string { + return []*string{&s.Object} +} + // swagger:ignore func SubjectFromProto(gs *rts.Subject) (Subject, error) { switch s := gs.GetRef().(type) { @@ -447,6 +473,17 @@ func (q *RelationQuery) FromProto(query TupleData) (*RelationQuery, error) { return q, nil } +func (q *RelationQuery) UUIDMappableFields() []*string { + res := []*string{&q.Object} + if q.SubjectID != nil { + res = append(res, q.SubjectID) + } + if q.SubjectSet != nil { + res = append(res, q.SubjectSet.UUIDMappableFields()...) + } + return res +} + const ( subjectIDKey = "subject_id" subjectSetNamespaceKey = "subject_set.namespace" @@ -684,3 +721,7 @@ func (t *ManagerWrapper) TransactRelationTuples(ctx context.Context, insert []*I func (t *ManagerWrapper) RelationTupleManager() Manager { return t } + +func (t *ManagerWrapper) UUIDMappingManager() UUIDMappingManager { + return t.Reg.UUIDMappingManager() +} diff --git a/internal/relationtuple/read_server.go b/internal/relationtuple/read_server.go index cc36f682e..f36117ab4 100644 --- a/internal/relationtuple/read_server.go +++ b/internal/relationtuple/read_server.go @@ -16,7 +16,10 @@ import ( "github.com/ory/keto/internal/x" ) -var _ rts.ReadServiceServer = (*handler)(nil) +var ( + _ rts.ReadServiceServer = (*handler)(nil) + _ = (*getRelationsParams)(nil) +) func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationTuplesRequest) (*rts.ListRelationTuplesResponse, error) { if req.Query == nil { @@ -28,6 +31,7 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT return nil, err } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q) rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(ctx, q, x.WithSize(int(req.PageSize)), x.WithToken(req.PageToken), @@ -35,6 +39,7 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT if err != nil { return nil, err } + h.d.UUIDMappingManager().MapFieldsFromUUID(ctx, InternalRelationTuples(rels)) resp := &rts.ListRelationTuplesResponse{ RelationTuples: make([]*rts.RelationTuple, len(rels)), @@ -48,7 +53,6 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT } // swagger:parameters getRelationTuples -// nolint:deadcode,unused type getRelationsParams struct { // Namespace of the Relation Tuple // @@ -139,11 +143,13 @@ func (h *handler) getRelations(w http.ResponseWriter, r *http.Request, _ httprou paginationOpts = append(paginationOpts, x.WithSize(int(s))) } + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query) rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(r.Context(), query, paginationOpts...) if err != nil { h.d.Writer().WriteError(w, r, err) return } + h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), InternalRelationTuples(rels)) resp := &GetResponse{ RelationTuples: rels, diff --git a/internal/relationtuple/read_server_test.go b/internal/relationtuple/read_server_test.go index e772ea019..fb474fc1a 100644 --- a/internal/relationtuple/read_server_test.go +++ b/internal/relationtuple/read_server_test.go @@ -103,7 +103,7 @@ func TestReadHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{ "object": {obj}, @@ -114,7 +114,7 @@ func TestReadHandlers(t *testing.T) { var respMsg relationtuple.GetResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&respMsg)) assert.Equal(t, 1, len(respMsg.RelationTuples)) - assert.Contains(t, rts, respMsg.RelationTuples[0]) + assert.Containsf(t, rts, respMsg.RelationTuples[0], "expected to find %q in %q", respMsg.RelationTuples[0].String(), rts) assert.Equal(t, "", respMsg.NextPageToken) }) @@ -145,7 +145,7 @@ func TestReadHandlers(t *testing.T) { Subject: &relationtuple.SubjectID{ID: "s2"}, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) var firstResp relationtuple.GetResponse t.Run("case=first page", func(t *testing.T) { diff --git a/internal/relationtuple/swagger_definitions.go b/internal/relationtuple/swagger_definitions.go index 7c32e2a38..5e215a002 100644 --- a/internal/relationtuple/swagger_definitions.go +++ b/internal/relationtuple/swagger_definitions.go @@ -1,7 +1,11 @@ package relationtuple +var ( + _ = (*relationTupleWithRequired)(nil) + _ = (*patchPayload)(nil) +) + // swagger:model InternalRelationTuple -// nolint:deadcode,unused type relationTupleWithRequired struct { // Namespace of the Relation Tuple // @@ -31,7 +35,6 @@ type relationTupleWithRequired struct { // The patch request payload // // swagger:parameters patchRelationTuples -// nolint:deadcode,unused type patchPayload struct { // in:body Payload []*PatchDelta @@ -41,3 +44,15 @@ type PatchDelta struct { Action patchAction `json:"action"` RelationTuple *InternalRelationTuple `json:"relation_tuple"` } + +// swagger:ignore +type PatchDeltas []*PatchDelta + +func (p PatchDeltas) UUIDMappableFields() (res []*string) { + for _, pd := range p { + if pd.RelationTuple != nil { + res = append(res, pd.RelationTuple.UUIDMappableFields()...) + } + } + return +} diff --git a/internal/relationtuple/test_helper.go b/internal/relationtuple/test_helper.go new file mode 100644 index 000000000..b51ed2cfd --- /dev/null +++ b/internal/relationtuple/test_helper.go @@ -0,0 +1,18 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// MapAndWriteTuples is a test helper to write relation tuples to the database +// while mapping all strings to UUIDs. +func MapAndWriteTuples(t *testing.T, m ManagerProvider, tuples ...*InternalRelationTuple) { + t.Helper() + + m.UUIDMappingManager().MapFieldsToUUID(context.Background(), InternalRelationTuples(tuples)) + require.NoError(t, m.RelationTupleManager().WriteRelationTuples(context.Background(), tuples...)) + m.UUIDMappingManager().MapFieldsFromUUID(context.Background(), InternalRelationTuples(tuples)) +} diff --git a/internal/relationtuple/transact_server.go b/internal/relationtuple/transact_server.go index 5bdd02e72..a42769b8c 100644 --- a/internal/relationtuple/transact_server.go +++ b/internal/relationtuple/transact_server.go @@ -12,7 +12,11 @@ import ( "github.com/pkg/errors" ) -var _ rts.WriteServiceServer = (*handler)(nil) +var ( + _ rts.WriteServiceServer = (*handler)(nil) + _ = (*bodyRelationTuple)(nil) + _ = (*queryRelationTuple)(nil) +) func protoTuplesWithAction(deltas []*rts.RelationTupleDelta, action rts.RelationTupleDelta_Action) (filtered []*InternalRelationTuple, err error) { for _, d := range deltas { @@ -38,6 +42,7 @@ func (h *handler) TransactRelationTuples(ctx context.Context, req *rts.TransactR return nil, err } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, InternalRelationTuples(append(insertTuples, deleteTuples...))) err = h.d.RelationTupleManager().TransactRelationTuples(ctx, insertTuples, deleteTuples) if err != nil { return nil, err @@ -62,6 +67,7 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat return nil, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error())) } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q) if err := h.d.RelationTupleManager().DeleteAllRelationTuples(ctx, q); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithError(err.Error())) } @@ -72,7 +78,6 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat // The basic ACL relation tuple // // swagger:parameters postCheck createRelationTuple -// nolint:deadcode,unused type bodyRelationTuple struct { // in: body Payload RelationQuery @@ -81,7 +86,6 @@ type bodyRelationTuple struct { // The basic ACL relation tuple // // swagger:parameters getCheck deleteRelationTuples -// nolint:deadcode,unused type queryRelationTuple struct { // Namespace of the Relation Tuple // @@ -151,11 +155,13 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr h.d.Logger().WithFields(rel.ToLoggerFields()).Debug("creating relation tuple") + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &rel) if err := h.d.RelationTupleManager().WriteRelationTuples(r.Context(), &rel); err != nil { h.d.Logger().WithError(err).WithFields(rel.ToLoggerFields()).Errorf("got an error while creating the relation tuple") h.d.Writer().WriteError(w, r, err) return } + h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), &rel) q, err := rel.ToURLQuery() if err != nil { @@ -198,6 +204,7 @@ func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ http } l.Debug("deleting relation tuples") + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query) if err := h.d.RelationTupleManager().DeleteAllRelationTuples(r.Context(), query); err != nil { l.WithError(err).Errorf("got an error while deleting relation tuples") h.d.Writer().WriteError(w, r, herodot.ErrInternalServerError.WithError(err.Error())) @@ -254,7 +261,13 @@ func (h *handler) patchRelations(w http.ResponseWriter, r *http.Request, _ httpr } } - if err := h.d.RelationTupleManager().TransactRelationTuples(r.Context(), internalTuplesWithAction(deltas, ActionInsert), internalTuplesWithAction(deltas, ActionDelete)); err != nil { + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), PatchDeltas(deltas)) + if err := h.d.RelationTupleManager(). + TransactRelationTuples( + r.Context(), + internalTuplesWithAction(deltas, ActionInsert), + internalTuplesWithAction(deltas, ActionDelete)); err != nil { + h.d.Writer().WriteError(w, r, err) return } diff --git a/internal/relationtuple/transact_server_test.go b/internal/relationtuple/transact_server_test.go index 921e509c6..845379561 100644 --- a/internal/relationtuple/transact_server_test.go +++ b/internal/relationtuple/transact_server_test.go @@ -80,10 +80,14 @@ func TestWriteHandlers(t *testing.T) { assert.JSONEq(t, string(payload), string(body)) t.Run("check=is contained in the manager", func(t *testing.T) { + query := rt.ToQuery() + reg.UUIDMappingManager().MapFieldsToUUID(context.Background(), query) // set a size > 1 just to make sure it gets all - actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) + actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) - assert.Equal(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), query) + assert.Equalf(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs, "want: %s\ngot: %s", rt.String(), actualRTs[0].String()) }) t.Run("check=is gettable with the returned URL", func(t *testing.T) { @@ -136,6 +140,7 @@ func TestWriteHandlers(t *testing.T) { Namespace: nspace.Name, }) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actual)) assert.Equal(t, "", next) assert.Len(t, actual, 2) for _, rt := range rts { @@ -154,7 +159,7 @@ func TestWriteHandlers(t *testing.T) { Relation: "deleted rel", Subject: &relationtuple.SubjectID{ID: "deleted subj"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) @@ -167,6 +172,7 @@ func TestWriteHandlers(t *testing.T) { // set a size > 1 just to make sure it gets all actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) @@ -188,7 +194,7 @@ func TestWriteHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) q := url.Values{ "namespace": {nspace.Name}, @@ -206,6 +212,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) }) @@ -234,7 +241,7 @@ func TestWriteHandlers(t *testing.T) { }, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), deltas[1].RelationTuple)) + relationtuple.MapAndWriteTuples(t, reg, deltas[1].RelationTuple) body, err := json.Marshal(deltas) require.NoError(t, err) @@ -249,6 +256,8 @@ func TestWriteHandlers(t *testing.T) { Relation: t.Name(), }) require.NoError(t, err) + err = reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) + require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) @@ -316,6 +325,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), &relationtuple.RelationQuery{ Namespace: nspace.Name, }) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) diff --git a/internal/relationtuple/uuid_mapping.go b/internal/relationtuple/uuid_mapping.go new file mode 100644 index 000000000..54aea61c2 --- /dev/null +++ b/internal/relationtuple/uuid_mapping.go @@ -0,0 +1,102 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/x" +) + +type ( + // UUIDMappable is an interface for objects that have fields that can be + // mapped to and from UUIDs. + UUIDMappable interface{ UUIDMappableFields() []*string } + + UUIDMappingManager interface { + // ToUUID returns the mapped UUID for the given string representation. + // If the string representation is not mapped, a new UUID will be + // created automatically. + ToUUID(ctx context.Context, representation string) (uuid.UUID, error) + + // MapFields maps all fields of the given object to UUIDs. + MapFieldsToUUID(ctx context.Context, m UUIDMappable) error + + // MapFieldsFromUUID maps all fields of the given object from UUIDs to + // their string value. + MapFieldsFromUUID(ctx context.Context, m UUIDMappable) error + + // FromUUID returns the text representations for the given UUIDs, such + // that ids[i] is mapped to reps[i]. + // + // Of the pagination options, only the page size is considered and used + // as a batch size. + FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (reps []string, err error) + } +) + +func UUIDMappingManagerTest(t *testing.T, m UUIDMappingManager) { + ctx := context.Background() + + t.Run("case=ToUUID_FromUUID", func(t *testing.T) { + rep1 := "foo" + id, err := m.ToUUID(ctx, rep1) + require.NoError(t, err) + + rep2, err := m.FromUUID(ctx, []uuid.UUID{id}) + assert.NoError(t, err) + assert.Equal(t, rep1, rep2[0]) + }) + + t.Run("case=Idempotent_ToUUID", func(t *testing.T) { + id1, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + id2, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + assert.Equal(t, id1, id2) + }) + + // Test that the batch mapping preserves ordering, i.e. id[i] is mapped to + // rep[i]. + t.Run("case=Batch_ToUUID_Paginates", func(t *testing.T) { + expected := []string{"foo", "foo", "bar", "baz"} + ids := make([]uuid.UUID, len(expected)) + for i, s := range expected { + var err error + ids[i], err = m.ToUUID(ctx, s) + assert.NoError(t, err) + } + + actual, err := m.FromUUID(ctx, ids, x.WithSize(1)) + assert.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("case=IdempotentMapFieldsToAndFromUUIDs", func(t *testing.T) { + tc := []struct { + name string + obj UUIDMappable + copy UUIDMappable + }{ + { + name: "RelationTuple", + obj: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + copy: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + }, { + name: "SubjectID", + obj: &SubjectID{ID: "sub"}, + copy: &SubjectID{ID: "sub"}, + }, + } + for _, tt := range tc { + t.Run("type="+tt.name, func(t *testing.T) { + assert.NoError(t, m.MapFieldsToUUID(ctx, tt.obj)) + assert.NoError(t, m.MapFieldsFromUUID(ctx, tt.obj)) + assert.Equal(t, tt.copy, tt.obj) + }) + } + }) +} diff --git a/internal/uuidmapping/definitions.go b/internal/uuidmapping/definitions.go deleted file mode 100644 index 9ddeac751..000000000 --- a/internal/uuidmapping/definitions.go +++ /dev/null @@ -1,48 +0,0 @@ -package uuidmapping - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/gofrs/uuid" -) - -type ( - ManagerProvider interface { - UUIDMappingManager() Manager - } - Manager interface { - // ToUUID returns the mapped UUID for the given string representation. - // If the string representation is not mapped, a new UUID will be - // created automatically. - ToUUID(ctx context.Context, representation string) (uuid.UUID, error) - - // FromUUID returns the text representation for the given UUID. - FromUUID(ctx context.Context, id uuid.UUID) (text string, err error) - } -) - -func ManagerTest(t *testing.T, m Manager) { - ctx := context.Background() - - t.Run("case=ToUUID_FromUUID", func(t *testing.T) { - rep1 := "foo" - id, err := m.ToUUID(ctx, rep1) - require.NoError(t, err) - - rep2, err := m.FromUUID(ctx, id) - assert.NoError(t, err) - assert.Equal(t, rep1, rep2) - }) - - t.Run("case=FromUUID", func(t *testing.T) { - id1, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - id2, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - assert.Equal(t, id1, id2) - }) -}