From f49d3cb5533b19a48679de7273ff198c9f50a6de Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Tue, 3 May 2022 13:09:59 +0200 Subject: [PATCH] feat: map subject/object strings to UUIDs (#792) In relation tuples and related API objects such as queries, subjects (including subject sets) and objects are now automatically mapped to and from UUIDs. The mapping is done in each handler and for each protocol (HTTP, gRPC) separately. Below the handler, all objects and subjects are now UUIDs, but still handled as strings, for the following reasons: * No duplication of datastructures (one for type `string`, one for type `uuid.UUID`). * No unnecessary copying, the mapping is done in-place and batched across multiple objects. Additionally, the migration (adding the string values to the mapping table and replacing the values with UUIDs) is now performed automatically as part of the migrations. Tests are in `migratest`, which now also handles `GoMigrations` properly. This method replaces the dedicated command in `cmd/migrate/...` for the UUID mappings. XXX WIP XXX WIP XXX WIP migrations_test package XXX WIP migrations_test package XXX WIP migrations_test package XXX WIP migrations_test package XXX WIP migrations_test package XXX WIP migrations_test package WIP --- cmd/expand/root_test.go | 9 +- cmd/migrate/migrate_test.go | 4 +- cmd/migrate/migrate_uuid_mapping.go | 49 ----- cmd/migrate/root.go | 1 - internal/check/handler.go | 17 +- internal/check/handler_test.go | 4 +- internal/driver/config/namespace_memory.go | 2 +- internal/driver/registry.go | 15 +- internal/driver/registry_default.go | 14 +- internal/e2e/cases_test.go | 9 +- internal/expand/handler.go | 29 +-- internal/expand/handler_test.go | 48 ++++- internal/expand/tree.go | 15 +- internal/persistence/definitions.go | 3 +- internal/persistence/sql/full_test.go | 5 +- .../migrations/migratest/migration_test.go | 79 ++++++-- .../sql/migrations/single_table.go | 14 +- .../sql/migrations/single_table_test.go | 31 +-- .../sql/migrations/uuid_mapping.go | 114 ------------ .../sql/migrations/uuid_mapping_test.go | 136 -------------- .../uuidmapping/uuid_mapping_migrator.go | 176 ++++++++++++++++++ .../uuidmapping/uuid_mapping_migrator_test.go | 120 ++++++++++++ internal/persistence/sql/persister.go | 10 +- internal/persistence/sql/relationtuples.go | 4 +- internal/persistence/sql/uuid_mapping.go | 98 +++++++++- internal/relationtuple/definitions.go | 41 ++++ internal/relationtuple/read_server.go | 10 +- internal/relationtuple/read_server_test.go | 6 +- internal/relationtuple/swagger_definitions.go | 19 +- internal/relationtuple/test_helper.go | 18 ++ internal/relationtuple/transact_server.go | 21 ++- .../relationtuple/transact_server_test.go | 20 +- internal/relationtuple/uuid_mapping.go | 102 ++++++++++ internal/uuidmapping/definitions.go | 48 ----- 34 files changed, 816 insertions(+), 475 deletions(-) delete mode 100644 cmd/migrate/migrate_uuid_mapping.go delete mode 100644 internal/persistence/sql/migrations/uuid_mapping.go delete mode 100644 internal/persistence/sql/migrations/uuid_mapping_test.go create mode 100644 internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go create mode 100644 internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator_test.go create mode 100644 internal/relationtuple/test_helper.go create mode 100644 internal/relationtuple/uuid_mapping.go delete mode 100644 internal/uuidmapping/definitions.go diff --git a/cmd/expand/root_test.go b/cmd/expand/root_test.go index 4e87ab6a3..0be0a22bc 100644 --- a/cmd/expand/root_test.go +++ b/cmd/expand/root_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/ory/x/cmdx" - "github.com/stretchr/testify/assert" "github.com/ory/keto/cmd/client" @@ -18,12 +17,16 @@ func TestExpandCommand(t *testing.T) { t.Run("case=unknown tuple", func(t *testing.T) { t.Run("format=JSON", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) assert.Equal(t, "null\n", stdOut) }) t.Run("format=default", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) assert.Contains(t, stdOut, "empty tree") }) }) diff --git a/cmd/migrate/migrate_test.go b/cmd/migrate/migrate_test.go index d3bde2b16..849ce1026 100644 --- a/cmd/migrate/migrate_test.go +++ b/cmd/migrate/migrate_test.go @@ -105,7 +105,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(func() { // migrate all down - t.Log(cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) + t.Logf("cleanup:\n%s\n", cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) }) parts := strings.Split(stdOut, "Are you sure that you want to apply this migration?") @@ -120,7 +120,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(func() { // migrate all down - t.Log(cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) + t.Logf("cleanup:\n%s\n", cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) }) parts := strings.Split(out, "Applying migrations...") diff --git a/cmd/migrate/migrate_uuid_mapping.go b/cmd/migrate/migrate_uuid_mapping.go deleted file mode 100644 index 8c15c60c4..000000000 --- a/cmd/migrate/migrate_uuid_mapping.go +++ /dev/null @@ -1,49 +0,0 @@ -package migrate - -import ( - "fmt" - - "github.com/ory/x/cmdx" - "github.com/ory/x/flagx" - "github.com/pkg/errors" - "github.com/spf13/cobra" - - "github.com/ory/keto/internal/driver" - "github.com/ory/keto/internal/persistence" - "github.com/ory/keto/internal/persistence/sql/migrations" -) - -func newMigrateUUIDMappingCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "uuid-mapping", - Short: "Migrate the non-UUID subject and object names to UUIDs.", - Long: `Migrate the non-UUID subject and object names to UUIDs. -This step only has to be executed once. -Please ensure that you have a backup in case something goes wrong!"`, - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, _ []string) error { - reg, err := driver.NewDefaultRegistry(cmd.Context(), cmd.Flags(), false) - if errors.Is(err, persistence.ErrNetworkMigrationsMissing) { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), - "Migrations were not applied yet, please apply them first using `keto migrate up`.") - return cmdx.FailSilently(cmd) - } else if err != nil { - return err - } - - if !flagx.MustGetBool(cmd, FlagYes) && - !cmdx.AskForConfirmation( - "Are you sure you want to migrate the subject and object names to UUIDs?", - cmd.InOrStdin(), cmd.OutOrStdout()) { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "OK, aborting.") - return nil - } - - migrator := migrations.NewToUUIDMappingMigrator(reg) - return migrator.MigrateUUIDMappings(cmd.Context()) - }, - } - RegisterYesFlag(cmd.Flags()) - - return cmd -} diff --git a/cmd/migrate/root.go b/cmd/migrate/root.go index ddabf0f3d..f29f9d3f6 100644 --- a/cmd/migrate/root.go +++ b/cmd/migrate/root.go @@ -17,7 +17,6 @@ func newMigrateCmd(opts []ketoctx.Option) *cobra.Command { newStatusCmd(opts), newUpCmd(opts), newDownCmd(opts), - newMigrateUUIDMappingCmd(), ) return cmd } diff --git a/internal/check/handler.go b/internal/check/handler.go index 66d509cd1..8dcc7707e 100644 --- a/internal/check/handler.go +++ b/internal/check/handler.go @@ -5,18 +5,14 @@ import ( "encoding/json" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" "github.com/pkg/errors" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( @@ -30,7 +26,10 @@ type ( } ) -var _ rts.CheckServiceServer = (*Handler)(nil) +var ( + _ rts.CheckServiceServer = (*Handler)(nil) + _ *getCheckRequest = nil +) func NewHandler(d handlerDependencies) *Handler { return &Handler{d: d} @@ -64,7 +63,6 @@ type RESTResponse struct { } // swagger:parameters getCheck postCheck -// nolint:deadcode,unused type getCheckRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -105,6 +103,7 @@ func (h *Handler) getCheck(w http.ResponseWriter, r *http.Request, _ httprouter. return } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -151,6 +150,7 @@ func (h *Handler) postCheck(w http.ResponseWriter, r *http.Request, _ httprouter return } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), &tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -171,6 +171,7 @@ func (h *Handler) Check(ctx context.Context, req *rts.CheckRequest) (*rts.CheckR return nil, err } + h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(ctx, tuple) allowed, err := h.d.PermissionEngine().SubjectIsAllowed(ctx, tuple, int(req.MaxDepth)) // TODO add content change handling if err != nil { diff --git a/internal/check/handler_test.go b/internal/check/handler_test.go index 5f81bad24..19548dbf9 100644 --- a/internal/check/handler_test.go +++ b/internal/check/handler_test.go @@ -23,6 +23,8 @@ import ( ) func assertAllowed(t *testing.T, resp *http.Response) { + t.Helper() + body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -97,7 +99,7 @@ func TestRESTHandler(t *testing.T) { Relation: "r", Subject: &relationtuple.SubjectID{ID: "s"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) diff --git a/internal/driver/config/namespace_memory.go b/internal/driver/config/namespace_memory.go index 7f451c35f..2bd1e184b 100644 --- a/internal/driver/config/namespace_memory.go +++ b/internal/driver/config/namespace_memory.go @@ -34,7 +34,7 @@ func (s *memoryNamespaceManager) GetNamespaceByName(_ context.Context, name stri } } - return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %s.", name)) + return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %q.", name)) } func (s *memoryNamespaceManager) GetNamespaceByConfigID(_ context.Context, id int32) (*namespace.Namespace, error) { diff --git a/internal/driver/registry.go b/internal/driver/registry.go index 62e8a6231..eb6f914b0 100644 --- a/internal/driver/registry.go +++ b/internal/driver/registry.go @@ -4,21 +4,15 @@ import ( "context" "net/http" - prometheus "github.com/ory/x/prometheusx" - "github.com/gobuffalo/pop/v6" - - "github.com/ory/keto/internal/driver/config" - "github.com/ory/keto/internal/uuidmapping" - - "github.com/spf13/cobra" - - "google.golang.org/grpc" - "github.com/ory/x/healthx" + prometheus "github.com/ory/x/prometheusx" "github.com/ory/x/tracing" + "github.com/spf13/cobra" + "google.golang.org/grpc" "github.com/ory/keto/internal/check" + "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/relationtuple" @@ -34,7 +28,6 @@ type ( x.WriterProvider relationtuple.ManagerProvider - uuidmapping.ManagerProvider expand.EngineProvider check.EngineProvider persistence.Migrator diff --git a/internal/driver/registry_default.go b/internal/driver/registry_default.go index 6e912ac6c..1ec705794 100644 --- a/internal/driver/registry_default.go +++ b/internal/driver/registry_default.go @@ -8,6 +8,7 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/ory/herodot" "github.com/ory/x/dbal" + "github.com/ory/x/fsx" "github.com/ory/x/healthx" "github.com/ory/x/logrusx" "github.com/ory/x/metricsx" @@ -24,8 +25,9 @@ import ( "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/persistence/sql" + _ "github.com/ory/keto/internal/persistence/sql/migrations" + "github.com/ory/keto/internal/persistence/sql/migrations/uuidmapping" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x" "github.com/ory/keto/ketoctx" rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" @@ -152,7 +154,7 @@ func (r *RegistryDefault) RelationTupleManager() relationtuple.Manager { return r.p } -func (r *RegistryDefault) UUIDMappingManager() uuidmapping.Manager { +func (r *RegistryDefault) UUIDMappingManager() relationtuple.UUIDMappingManager { if r.p == nil { panic("no relation tuple manager, but expected to have one") } @@ -186,10 +188,16 @@ func (r *RegistryDefault) MigrationBox(ctx context.Context) (*popx.MigrationBox, if err != nil { return nil, err } - mb, err := sql.NewMigrationBox(c, r.Logger(), r.Tracer(ctx)) + + mb, err := popx.NewMigrationBox( + fsx.Merge(sql.Migrations, networkx.Migrations), + popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0), + popx.WithGoMigrations(uuidmapping.Migrations), + ) if err != nil { return nil, err } + r.mb = mb } return r.mb, nil diff --git a/internal/e2e/cases_test.go b/internal/e2e/cases_test.go index c71d9a684..01e357930 100644 --- a/internal/e2e/cases_test.go +++ b/internal/e2e/cases_test.go @@ -5,17 +5,14 @@ import ( "strconv" "testing" - "github.com/stretchr/testify/require" - "github.com/ory/herodot" - - "github.com/ory/keto/internal/x" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) func(*testing.T) { @@ -188,7 +185,7 @@ func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) Relation: "rel", } resp := c.queryTuple(t, q) - require.Equal(t, resp.RelationTuples, rts) + require.ElementsMatch(t, resp.RelationTuples, rts) c.deleteAllTuples(t, q) resp = c.queryTuple(t, q) diff --git a/internal/expand/handler.go b/internal/expand/handler.go index abf32803c..d0d818a76 100644 --- a/internal/expand/handler.go +++ b/internal/expand/handler.go @@ -4,17 +4,13 @@ import ( "context" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( @@ -28,7 +24,10 @@ type ( } ) -var _ rts.ExpandServiceServer = (*handler)(nil) +var ( + _ rts.ExpandServiceServer = (*handler)(nil) + _ *getExpandRequest = nil +) const RouteBase = "/relation-tuples/expand" @@ -49,7 +48,6 @@ func (h *handler) RegisterReadGRPC(s *grpc.Server) { func (h *handler) RegisterWriteGRPC(s *grpc.Server) {} // swagger:parameters getExpand -// nolint:deadcode,unused type getExpandRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -81,24 +79,31 @@ func (h *handler) getExpand(w http.ResponseWriter, r *http.Request, _ httprouter return } - res, err := h.d.ExpandEngine().BuildTree(r.Context(), (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()), maxDepth) + subject := (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()) + + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), subject) + res, err := h.d.ExpandEngine().BuildTree(r.Context(), subject, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) return } + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), res) h.d.Writer().Write(w, r, res) } func (h *handler) Expand(ctx context.Context, req *rts.ExpandRequest) (*rts.ExpandResponse, error) { - sub, err := relationtuple.SubjectFromProto(req.Subject) + subject, err := relationtuple.SubjectFromProto(req.Subject) if err != nil { return nil, err } - tree, err := h.d.ExpandEngine().BuildTree(ctx, sub, int(req.MaxDepth)) + + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsToUUID(ctx, subject) + res, err := h.d.ExpandEngine().BuildTree(ctx, subject, int(req.MaxDepth)) if err != nil { return nil, err } + h.d.ExpandEngine().d.UUIDMappingManager().MapFieldsFromUUID(ctx, res) - return &rts.ExpandResponse{Tree: tree.ToProto()}, nil + return &rts.ExpandResponse{Tree: res.ToProto()}, nil } diff --git a/internal/expand/handler_test.go b/internal/expand/handler_test.go index 95ae81875..6ba553bf3 100644 --- a/internal/expand/handler_test.go +++ b/internal/expand/handler_test.go @@ -79,20 +79,20 @@ func TestRESTHandler(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), []*relationtuple.InternalRelationTuple{ - { + relationtuple.MapAndWriteTuples(t, reg, + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[0].Subject, }, - { + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[1].Subject, }, - }...)) + ) qs := rootSub.ToURLQuery() qs.Set("max-depth", "2") @@ -103,6 +103,44 @@ func TestRESTHandler(t *testing.T) { actualTree := expand.Tree{} require.NoError(t, json.NewDecoder(resp.Body).Decode(&actualTree)) - assert.Equal(t, expectedTree, &actualTree) + assertEqualTrees(t, expectedTree, &actualTree) }) } + +func assertEqualTrees(t *testing.T, expected, actual *expand.Tree) { + t.Helper() + assert.Truef(t, treesAreEqual(t, expected, actual), + "expected:\n%s\n\nactual:\n%s", expected.String(), actual.String()) +} + +func treesAreEqual(t *testing.T, expected, actual *expand.Tree) bool { + if expected == nil || actual == nil { + return expected == actual + } + + if expected.Type != actual.Type { + t.Logf("expected type %q, actual type %q", expected.Type, actual.Type) + return false + } + if expected.Subject.String() != actual.Subject.String() { + t.Logf("expected subject: %q, actual subject: %q", expected.Subject.String(), actual.Subject.String()) + return false + } + if len(expected.Children) != len(actual.Children) { + t.Logf("expected len(children)=%d, actual len(children)=%d", len(expected.Children), len(actual.Children)) + return false + } + + // For children, we check for equality disregarding the order +outer: + for _, expectedChild := range expected.Children { + for _, actualChild := range actual.Children { + if treesAreEqual(t, expectedChild, actualChild) { + continue outer + } + } + t.Logf("expected child:\n%s\n\nactual child:\n%s", expectedChild.String(), actual.String()) + return false + } + return true +} diff --git a/internal/expand/tree.go b/internal/expand/tree.go index 4340ce3e3..3f2c8ec27 100644 --- a/internal/expand/tree.go +++ b/internal/expand/tree.go @@ -217,7 +217,7 @@ func TreeFromProto(t *rts.SubjectTree) (*Tree, error) { func (t *Tree) String() string { if t == nil { - return "" + return "(nil)" } sub := t.Subject.String() @@ -233,3 +233,16 @@ func (t *Tree) String() string { return fmt.Sprintf("∪ %s\n├─ %s", sub, strings.Join(children, "\n├─ ")) } + +func (t *Tree) UUIDMappableFields() (res []*string) { + if t == nil { + return + } + if t.Subject != nil { + res = append(res, t.Subject.UUIDMappableFields()...) + } + for _, c := range t.Children { + res = append(res, c.UUIDMappableFields()...) + } + return +} diff --git a/internal/persistence/definitions.go b/internal/persistence/definitions.go index 46e87d937..0b1d267f3 100644 --- a/internal/persistence/definitions.go +++ b/internal/persistence/definitions.go @@ -9,13 +9,12 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" ) type ( Persister interface { relationtuple.Manager - uuidmapping.Manager + relationtuple.UUIDMappingManager Connection(ctx context.Context) *pop.Connection } diff --git a/internal/persistence/sql/full_test.go b/internal/persistence/sql/full_test.go index 595f9d051..ff772c9d9 100644 --- a/internal/persistence/sql/full_test.go +++ b/internal/persistence/sql/full_test.go @@ -14,7 +14,6 @@ import ( "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/persistence/sql" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x/dbx" ) @@ -69,9 +68,9 @@ func TestPersister(t *testing.T) { relationtuple.IsolationTest(t, p0, p1, addNamespace(r, nspaces)) }) - t.Run("uuidmapping.ManagerTest", func(t *testing.T) { + t.Run("relationtuple.UUIDMappingManagerTest", func(t *testing.T) { p, _, _ := setup(t, dsn) - uuidmapping.ManagerTest(t, p) + relationtuple.UUIDMappingManagerTest(t, p) }) }) } diff --git a/internal/persistence/sql/migrations/migratest/migration_test.go b/internal/persistence/sql/migrations/migratest/migration_test.go index 8c7e7a4db..0a21c479d 100644 --- a/internal/persistence/sql/migrations/migratest/migration_test.go +++ b/internal/persistence/sql/migrations/migratest/migration_test.go @@ -2,15 +2,14 @@ package migratest import ( "context" + "io/fs" "os" + "regexp" + "strings" "testing" "time" "github.com/gobuffalo/pop/v6" - - "github.com/ory/keto/internal/driver/config" - "github.com/ory/keto/internal/namespace" - "github.com/gofrs/uuid" "github.com/ory/x/fsx" "github.com/ory/x/logrusx" @@ -22,11 +21,51 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/keto/internal/driver" + "github.com/ory/keto/internal/driver/config" + "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/persistence/sql" + "github.com/ory/keto/internal/persistence/sql/migrations/uuidmapping" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x/dbx" ) +func withTestdata(testdata fs.FS) func(*popx.MigrationBox) *popx.MigrationBox { + return func(m *popx.MigrationBox) *popx.MigrationBox { + fs.WalkDir(testdata, ".", func(path string, info fs.DirEntry, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if m, _ := regexp.MatchString(`\d+_testdata.sql`, info.Name()); !m { + return nil + } + version := strings.TrimSuffix(info.Name(), "_testdata.sql") + m.Migrations["up"] = append(m.Migrations["up"], popx.Migration{ + Version: version + "9", // run testdata after version + Path: path, + Name: "testdata", + DBType: "all", + Direction: "up", + Type: "sql", + Runner: func(m popx.Migration, _ *pop.Connection, tx *pop.Tx) error { + b, err := fs.ReadFile(testdata, m.Path) + if err != nil { + return err + } + _, err = tx.Exec(string(b)) + return err + }, + }) + + return nil + }) + + return m + } +} + func TestMigrations(t *testing.T) { const debugOnDisk = false @@ -50,7 +89,14 @@ func TestMigrations(t *testing.T) { } require.NoError(t, c.Store.(interface{ Ping() error }).Ping()) - tm := popx.NewTestMigrator(t, c, fsx.Merge(networkx.Migrations, os.DirFS("../sql")), os.DirFS("./testdata"), l) + tm, err := popx.NewMigrationBox( + fsx.Merge(sql.Migrations, networkx.Migrations, os.DirFS("./testdata")), + popx.NewMigrator(c, l, nil, 1*time.Minute), + popx.WithGoMigrations(uuidmapping.Migrations), + withTestdata(os.DirFS("./testdata")), + ) + require.NoError(t, err) + // cleanup first require.NoError(t, tm.Down(ctx, -1)) @@ -59,6 +105,12 @@ func TestMigrations(t *testing.T) { }) t.Run("suite=fixtures", func(t *testing.T) { + reg := driver.NewTestRegistry(t, db) + require.NoError(t, + reg.Config(ctx).Set(config.KeyNamespaces, []*namespace.Namespace{ + {ID: 1, Name: "foo"}})) + p, err := sql.NewPersister(ctx, reg, uuid.Must(uuid.FromString("77fdc5e0-2260-49da-8aae-c36ba255d05b"))) + t.Run("table=legacy namespaces", func(t *testing.T) { // as they are legacy, we expect them to be actually dropped @@ -66,17 +118,12 @@ func TestMigrations(t *testing.T) { }) t.Run("table=relation tuples", func(t *testing.T) { - reg := driver.NewTestRegistry(t, db) - require.NoError(t, - reg.Config(ctx).Set(config.KeyNamespaces, []*namespace.Namespace{{ID: 1, Name: "foo"}})) - - p, err := sql.NewPersister(ctx, reg, uuid.Must(uuid.FromString("77fdc5e0-2260-49da-8aae-c36ba255d05b"))) require.NoError(t, err) - rts, next, err := p.GetRelationTuples(context.Background(), &relationtuple.RelationQuery{Namespace: "foo"}) + actualRts, next, err := p.GetRelationTuples(ctx, &relationtuple.RelationQuery{Namespace: "foo"}) require.NoError(t, err) assert.Equal(t, "", next) - for _, rt := range []*relationtuple.InternalRelationTuple{ + expectedRts := []*relationtuple.InternalRelationTuple{ { Namespace: "foo", Object: "object", @@ -93,9 +140,13 @@ func TestMigrations(t *testing.T) { Relation: "s_relation", }, }, - } { - assert.Contains(t, rts, rt) } + + // The relationship tuples in the db have a UUID mapping, so + // we need to convert our expectations to that. + assert.NoError(t, p.MapFieldsToUUID( + ctx, relationtuple.InternalRelationTuples(expectedRts))) + assert.ElementsMatch(t, expectedRts, actualRts) }) }) diff --git a/internal/persistence/sql/migrations/single_table.go b/internal/persistence/sql/migrations/single_table.go index 62349d893..5b9cd033d 100644 --- a/internal/persistence/sql/migrations/single_table.go +++ b/internal/persistence/sql/migrations/single_table.go @@ -35,7 +35,7 @@ type ( } toSingleTableMigrator struct { d dependencies - perPage int + PerPage int } relationTuple struct { @@ -126,11 +126,11 @@ func (relationTuple) TableName(ctx context.Context) string { func NewToSingleTableMigrator(d dependencies) *toSingleTableMigrator { return &toSingleTableMigrator{ d: d, - perPage: 100, + PerPage: 100, } } -func (m *toSingleTableMigrator) namespaceMigrationBox(ctx context.Context, n *namespace.Namespace) (*popx.MigrationBox, error) { +func (m *toSingleTableMigrator) NamespaceMigrationBox(ctx context.Context, n *namespace.Namespace) (*popx.MigrationBox, error) { c, err := m.d.PopConnectionWithOpts(ctx, func(d *pop.ConnectionDetails) { d.Options = map[string]string{ "migration_table_name": migrationTableFromNamespace(n), @@ -149,7 +149,7 @@ func (m *toSingleTableMigrator) namespaceMigrationBox(ctx context.Context, n *na ) } -func (m *toSingleTableMigrator) getOldRelationTuples(ctx context.Context, n *namespace.Namespace, page, perPage int) (relationTuples, bool, error) { +func (m *toSingleTableMigrator) GetOldRelationTuples(ctx context.Context, n *namespace.Namespace, page, perPage int) (relationTuples, bool, error) { q := m.d.Persister().Connection(ctx). WithContext(context.WithValue(ctx, namespaceCtxKey, n)). Order("object, relation, subject, commit_time"). @@ -165,7 +165,7 @@ func (m *toSingleTableMigrator) getOldRelationTuples(ctx context.Context, n *nam return res, q.Paginator.Page < q.Paginator.TotalPages, nil } -func (m *toSingleTableMigrator) insertOldRelationTuples(ctx context.Context, n *namespace.Namespace, rs ...*relationtuple.InternalRelationTuple) error { +func (m *toSingleTableMigrator) InsertOldRelationTuples(ctx context.Context, n *namespace.Namespace, rs ...*relationtuple.InternalRelationTuple) error { for _, r := range rs { if r.Subject == nil { return errors.New("subject is not allowed to be nil") @@ -196,7 +196,7 @@ func (m *toSingleTableMigrator) MigrateNamespace(ctx context.Context, n *namespa if err := p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { for page := 1; ; page++ { - rs, hasNext, err := m.getOldRelationTuples(ctx, n, page, m.perPage) + rs, hasNext, err := m.GetOldRelationTuples(ctx, n, page, m.PerPage) if err != nil { return err } @@ -284,7 +284,7 @@ func (m *toSingleTableMigrator) LegacyNamespaces(ctx context.Context) ([]*namesp } func (m *toSingleTableMigrator) MigrateDown(ctx context.Context, n *namespace.Namespace) error { - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) if err != nil { return err } diff --git a/internal/persistence/sql/migrations/single_table_test.go b/internal/persistence/sql/migrations/single_table_test.go index 47de0622b..44f370a37 100644 --- a/internal/persistence/sql/migrations/single_table_test.go +++ b/internal/persistence/sql/migrations/single_table_test.go @@ -1,4 +1,4 @@ -package migrations +package migrations_test import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/ory/keto/internal/driver" "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/namespace" + "github.com/ory/keto/internal/persistence/sql/migrations" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" "github.com/ory/keto/internal/x/dbx" @@ -24,7 +25,7 @@ func TestToSingleTableMigrator(t *testing.T) { r := driver.NewTestRegistry(t, dsn) ctx := context.Background() var nn []*namespace.Namespace - m := NewToSingleTableMigrator(r) + m := migrations.NewToSingleTableMigrator(r) setup := func(t *testing.T) *namespace.Namespace { n := &namespace.Namespace{ @@ -34,7 +35,7 @@ func TestToSingleTableMigrator(t *testing.T) { nn = append(nn, n) - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) @@ -70,10 +71,10 @@ func TestToSingleTableMigrator(t *testing.T) { Relation: "b", }, } - require.NoError(t, m.insertOldRelationTuples(ctx, n, sID, sSet)) + require.NoError(t, m.InsertOldRelationTuples(ctx, n, sID, sSet)) // get the tuple from the old table - oldRts, next, err := m.getOldRelationTuples(ctx, n, 0, 100) + oldRts, next, err := m.GetOldRelationTuples(ctx, n, 0, 100) require.NoError(t, err) assert.False(t, next) require.Len(t, oldRts, 2) @@ -100,9 +101,9 @@ func TestToSingleTableMigrator(t *testing.T) { n := setup(t) defer func(old int) { - m.perPage = old - }(m.perPage) - m.perPage = 1 + m.PerPage = old + }(m.PerPage) + m.PerPage = 1 rts := make([]*relationtuple.InternalRelationTuple, 10) for i := range rts { @@ -114,7 +115,7 @@ func TestToSingleTableMigrator(t *testing.T) { } } - require.NoError(t, m.insertOldRelationTuples(ctx, n, rts...)) + require.NoError(t, m.InsertOldRelationTuples(ctx, n, rts...)) require.NoError(t, m.MigrateNamespace(ctx, n)) migrated, nextToken, err := r.RelationTupleManager().GetRelationTuples(ctx, &relationtuple.RelationQuery{Namespace: n.Name}, x.WithSize(len(rts))) @@ -132,7 +133,7 @@ func TestToSingleTableMigrator(t *testing.T) { Relation: "r", Subject: &relationtuple.SubjectID{ID: "s"}, } - require.NoError(t, m.insertOldRelationTuples(ctx, n, &relationtuple.InternalRelationTuple{ + require.NoError(t, m.InsertOldRelationTuples(ctx, n, &relationtuple.InternalRelationTuple{ Namespace: n.Name, Object: "o0", Relation: "r", @@ -145,7 +146,7 @@ func TestToSingleTableMigrator(t *testing.T) { }, valid)) err := m.MigrateNamespace(ctx, n) require.Error(t, err) - invalid, ok := err.(ErrInvalidTuples) + invalid, ok := err.(migrations.ErrInvalidTuples) require.True(t, ok) assert.Len(t, invalid, 2) @@ -168,7 +169,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { t.Run("case=simple detection", func(t *testing.T) { ctx := context.Background() reg := driver.NewTestRegistry(t, dsn) - m := NewToSingleTableMigrator(reg) + m := migrations.NewToSingleTableMigrator(reg) nspaces := []*namespace.Namespace{{ ID: 3, @@ -182,7 +183,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { assert.Len(t, legacyNamespaces, 0) // migrate legacy table up - mb, err := m.namespaceMigrationBox(ctx, nspaces[0]) + mb, err := m.NamespaceMigrationBox(ctx, nspaces[0]) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) @@ -203,7 +204,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { t.Run("case=multiple namespaces", func(t *testing.T) { ctx := context.Background() reg := driver.NewTestRegistry(t, dsn) - m := NewToSingleTableMigrator(reg) + m := migrations.NewToSingleTableMigrator(reg) nspaces := []*namespace.Namespace{{ ID: 0, @@ -219,7 +220,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { for _, n := range nspaces { // migrate legacy table up - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) } diff --git a/internal/persistence/sql/migrations/uuid_mapping.go b/internal/persistence/sql/migrations/uuid_mapping.go deleted file mode 100644 index 960256ab4..000000000 --- a/internal/persistence/sql/migrations/uuid_mapping.go +++ /dev/null @@ -1,114 +0,0 @@ -package migrations - -import ( - "context" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/ory/x/sqlcon" - - "github.com/ory/keto/internal/persistence/sql" -) - -type ( - toUUIDMappingMigrator struct { - d dependencies - perPage int - requestedPages int // track the number of pages we requested for testing. - } -) - -// NewToUUIDMappingMigrator creates a new UUID mapping migrator. -func NewToUUIDMappingMigrator(d dependencies) *toUUIDMappingMigrator { - return &toUUIDMappingMigrator{d: d, perPage: 100} -} - -// MigrateUUIDMappings migrates to UUID-mapped subject IDs for all relation -// tuples in the database. -func (m *toUUIDMappingMigrator) MigrateUUIDMappings(ctx context.Context) error { - p, ok := m.d.Persister().(*sql.Persister) - if !ok { - panic("got unexpected persister") - } - - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - m.requestedPages = 0 - for page := 1; ; page++ { - m.requestedPages++ - relationTuples, hasNext, err := m.getRelationTuples(ctx, page) - if err != nil { - return err - } - - if err := p.Connection(ctx).All(&relationTuples); err != nil { - return err - } - - for _, rt := range relationTuples { - if err := m.migrateSubjectID(ctx, rt); err != nil { - return err - } - if err := m.migrateSubjectSetObject(ctx, rt); err != nil { - return err - } - if err := m.migrateObject(ctx, rt); err != nil { - return err - } - if err := c.Update(rt); err != nil { - return err - } - } - - if !hasNext { - break - } - } - return nil - }) -} - -func (m *toUUIDMappingMigrator) migrateSubjectID(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectID.String) - if err == nil || !rt.SubjectID.Valid || rt.SubjectID.String == "" { - return nil - } - - rt.SubjectID.String, err = m.addUUIDMapping(ctx, rt.SubjectID.String) - return err -} - -func (m *toUUIDMappingMigrator) migrateSubjectSetObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectSetObject.String) - if err == nil || !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { - return nil - } - - rt.SubjectSetObject.String, err = m.addUUIDMapping(ctx, rt.SubjectSetObject.String) - return err -} - -func (m *toUUIDMappingMigrator) migrateObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.Object) - if err == nil || rt.Object == "" { - return nil - } - - rt.Object, err = m.addUUIDMapping(ctx, rt.Object) - return err -} - -func (m *toUUIDMappingMigrator) addUUIDMapping(ctx context.Context, value string) (id string, err error) { - uid, err := m.d.Persister().ToUUID(ctx, value) - return uid.String(), err -} - -func (m *toUUIDMappingMigrator) getRelationTuples(ctx context.Context, page int) (res []*sql.RelationTuple, hasNext bool, err error) { - q := m.d.Persister().Connection(ctx). - Order("nid, shard_id"). - Paginate(page, m.perPage) - - if err := q.All(&res); err != nil { - return nil, false, sqlcon.HandleError(err) - } - return res, q.Paginator.Page < q.Paginator.TotalPages, nil -} diff --git a/internal/persistence/sql/migrations/uuid_mapping_test.go b/internal/persistence/sql/migrations/uuid_mapping_test.go deleted file mode 100644 index 9756c5563..000000000 --- a/internal/persistence/sql/migrations/uuid_mapping_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package migrations - -import ( - "context" - dbsql "database/sql" - "testing" - "time" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/ory/keto/internal/driver" - "github.com/ory/keto/internal/persistence/sql" - "github.com/ory/keto/internal/x/dbx" -) - -func TestToUUIDMappingMigrator(t *testing.T) { - const debugOnDisk = false - - for _, dsn := range dbx.GetDSNs(t, debugOnDisk) { - t.Run("db="+dsn.Name, func(t *testing.T) { - ctx := context.Background() - r := driver.NewTestRegistry(t, dsn) - m := NewToUUIDMappingMigrator(r) - p := m.d.Persister().(*sql.Persister) - conn := p.Connection(ctx) - nw, err := r.DetermineNetwork(ctx) - require.NoError(t, err) - - testCases := []struct { - name string - rt *sql.RelationTuple - expectMapping bool - }{{ - name: "with string subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - Object: "object", - SubjectID: dbsql.NullString{String: "subject", Valid: true}, - CommitTime: time.Now(), - }, - expectMapping: true, - }, { - name: "with null subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - SubjectID: dbsql.NullString{String: "", Valid: false}, - SubjectSetNamespaceID: dbsql.NullInt32{Int32: 0, Valid: true}, - SubjectSetObject: dbsql.NullString{String: "object", Valid: true}, - SubjectSetRelation: dbsql.NullString{String: "subject_set_relation", Valid: true}, - Object: "object", - CommitTime: time.Now(), - }, - expectMapping: true, - }, { - name: "with UUID subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - SubjectID: dbsql.NullString{String: uuid.Must(uuid.NewV4()).String(), Valid: true}, - CommitTime: time.Now(), - }, - expectMapping: false, - }} - - for _, tc := range testCases { - t.Run("case="+tc.name, func(t *testing.T) { - require.NoError(t, conn.Create(tc.rt)) - require.NoError(t, m.MigrateUUIDMappings(ctx)) - - newRt := &sql.RelationTuple{} - require.NoError(t, conn.Find(newRt, tc.rt.ID)) - - if tc.expectMapping { - // Check subject mapping - if tc.rt.SubjectID.Valid { - assertHasMapping(t, conn, tc.rt.SubjectID.String, newRt.SubjectID.String) - } else { - assertHasMapping(t, conn, tc.rt.SubjectSetObject.String, newRt.SubjectSetObject.String) - // check that both "OBJECT" strings are mapped to same UUID. - assert.Equal(t, newRt.Object, newRt.SubjectSetObject.String) - } - assertHasMapping(t, conn, tc.rt.Object, newRt.Object) - } else { - // Nothing should have changed (ignoring commit time) - newRt.CommitTime = tc.rt.CommitTime - assert.Equal(t, tc.rt, newRt) - } - }) - } - }) - } -} - -func TestToUUIDMappingMigrator_paginates(t *testing.T) { - numTuples := 10 - perPage := 2 - - dsn := dbx.GetSqlite(t, dbx.SQLiteMemory) - ctx := context.Background() - r := driver.NewTestRegistry(t, dsn) - m := &toUUIDMappingMigrator{d: r, perPage: perPage} - p := m.d.Persister().(*sql.Persister) - conn := p.Connection(ctx) - nw, err := r.DetermineNetwork(ctx) - require.NoError(t, err) - - // Create a bunch of relation tuples - for i := 0; i < numTuples; i++ { - rt := &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - Object: "object", - SubjectID: dbsql.NullString{String: "subject", Valid: true}, - CommitTime: time.Now(), - } - require.NoError(t, conn.Create(rt)) - } - - require.NoError(t, m.MigrateUUIDMappings(ctx)) - assert.Equal(t, numTuples/perPage, m.requestedPages) -} - -// assertHasMapping checks that there is a mapping from the given string (value) -// to the given UUID (uid). -func assertHasMapping(t *testing.T, conn *pop.Connection, value, uid string) { - t.Helper() - mapping := &sql.UUIDMapping{} - require.NoError(t, conn.Find(mapping, uid), "Could not find mapping for %q", uid) - assert.NotEqual(t, value, uid, "value was not replaced by UUID") - assert.Equal(t, value, mapping.StringRepresentation) -} diff --git a/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go new file mode 100644 index 000000000..843419ac5 --- /dev/null +++ b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go @@ -0,0 +1,176 @@ +package uuidmapping + +import ( + "database/sql" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/ory/x/popx" + "github.com/ory/x/sqlcon" +) + +type ( + // We copy the definitions of RelationTuple and UUIDMapping here so that the + // migration will always work on the same definitions. + RelationTuple struct { + // An ID field is required to make pop happy. The actual ID is a + // composite primary key. + ID uuid.UUID `db:"shard_id"` + NetworkID uuid.UUID `db:"nid"` + NamespaceID int32 `db:"namespace_id"` + Object string `db:"object"` + Relation string `db:"relation"` + SubjectID sql.NullString `db:"subject_id"` + SubjectSetNamespaceID sql.NullInt32 `db:"subject_set_namespace_id"` + SubjectSetObject sql.NullString `db:"subject_set_object"` + SubjectSetRelation sql.NullString `db:"subject_set_relation"` + CommitTime time.Time `db:"commit_time"` + } + UUIDMapping struct { + ID uuid.UUID `db:"id"` + StringRepresentation string `db:"string_representation"` + } + UUIDMappings []*UUIDMapping +) + +func (RelationTuple) TableName() string { return "keto_relation_tuples" } +func (UUIDMappings) TableName() string { return "keto_uuid_mappings" } +func (UUIDMapping) TableName() string { return "keto_uuid_mappings" } + +var ( + name = "migrate-strings-to-uuids" + version = "20220509000000000000" + Migrations = popx.Migrations{ + { + Version: version, + Name: name, + Path: name, + Direction: "up", + DBType: "all", + Type: "go", + Runner: func(_ popx.Migration, conn *pop.Connection, _ *pop.Tx) error { + for page := 1; ; page++ { + relationTuples, hasNext, err := getRelationTuples(conn, page) + if err != nil { + return err + } + + if err := conn.All(&relationTuples); err != nil { + return err + } + + for _, rt := range relationTuples { + if err := migrateSubjectID(conn, &rt); err != nil { + return err + } + if err := migrateSubjectSetObject(conn, &rt); err != nil { + return err + } + if err := migrateObject(conn, &rt); err != nil { + return err + } + if err := conn.Update(&rt); err != nil { + return err + } + } + + if !hasNext { + break + } + } + + return nil + }, + }, + // We have to specify the "down" migration even if it is a no-op, since + // the migrator will still manipulate the version table in the database. + { + Version: version, + Name: name, + Path: name, + Direction: "down", + DBType: "all", + Type: "go", + Runner: func(_ popx.Migration, _ *pop.Connection, _ *pop.Tx) error { + return nil + }, + }, + } +) + +func hasMapping(conn *pop.Connection, id string) (bool, error) { + return conn.Where("id = ?", id).Exists(&UUIDMapping{}) +} + +func migrateSubjectID(conn *pop.Connection, rt *RelationTuple) error { + skip, err := hasMapping(conn, rt.SubjectID.String) + if err != nil { + return err + } + if skip || !rt.SubjectID.Valid || rt.SubjectID.String == "" { + return nil + } + + rt.SubjectID.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectID.String) + return err +} + +func migrateSubjectSetObject(conn *pop.Connection, rt *RelationTuple) error { + skip, err := hasMapping(conn, rt.SubjectSetObject.String) + if err != nil { + return err + } + if skip || !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { + return nil + } + + rt.SubjectSetObject.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectSetObject.String) + return err +} + +func migrateObject(conn *pop.Connection, rt *RelationTuple) error { + skip, err := hasMapping(conn, rt.Object) + if err != nil { + return err + } + if skip || rt.Object == "" { + return nil + } + + rt.Object, err = addUUIDMapping(conn, rt.NetworkID, rt.Object) + return err +} + +func addUUIDMapping(conn *pop.Connection, networkID uuid.UUID, value string) (uid string, err error) { + uid = uuid.NewV5(networkID, value).String() + + // We need to write manual SQL here because the INSERT should not fail if + // the UUID already exists, but we still want to return an error if anything + // else goes wrong. + var query string + switch d := conn.Dialect.Name(); d { + case "mysql": + query = ` + INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) + VALUES (?, ?)` + default: + query = ` + INSERT INTO keto_uuid_mappings (id, string_representation) + VALUES (?, ?) + ON CONFLICT (id) DO NOTHING` + } + + return uid, sqlcon.HandleError(conn.RawQuery(query, uid, value).Exec()) +} + +func getRelationTuples(conn *pop.Connection, page int) ( + res []RelationTuple, hasNext bool, err error, +) { + q := conn.Order("nid, shard_id").Paginate(page, 100) + + if err := q.All(&res); err != nil { + return nil, false, sqlcon.HandleError(err) + } + return res, q.Paginator.Page < q.Paginator.TotalPages, nil +} diff --git a/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator_test.go b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator_test.go new file mode 100644 index 000000000..4989a033e --- /dev/null +++ b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator_test.go @@ -0,0 +1,120 @@ +package uuidmapping_test + +// func TestToUUIDMappingMigrator(t *testing.T) { +// const debugOnDisk = false + +// for _, dsn := range dbx.GetDSNs(t, debugOnDisk) { +// t.Run("db="+dsn.Name, func(t *testing.T) { +// ctx := context.Background() +// r := driver.NewTestRegistry(t, dsn) +// m := migrations.NewToUUIDMappingMigrator(r) +// p := r.Persister() +// conn := p.Connection(ctx) +// nw, err := r.DetermineNetwork(ctx) +// require.NoError(t, err) + +// testCases := []struct { +// name string +// rt *sql.RelationTuple +// }{{ +// name: "with string subject", +// rt: &sql.relationtuple{ +// id: uuid.must(uuid.newv4()), +// networkid: nw.id, +// object: "object", +// subjectid: dbsql.nullstring{string: "subject", valid: true}, +// committime: time.now(), +// }, +// }, { +// name: "with null subject", +// rt: &sql.relationtuple{ +// id: uuid.must(uuid.newv4()), +// networkid: nw.id, +// subjectid: dbsql.nullstring{string: "", valid: false}, +// subjectsetnamespaceid: dbsql.nullint32{int32: 0, valid: true}, +// subjectsetobject: dbsql.nullstring{string: "object", valid: true}, +// subjectsetrelation: dbsql.nullstring{string: "subject_set_relation", valid: true}, +// object: "object", +// committime: time.now(), +// }, +// }, { +// name: "with uuid subject", +// rt: &sql.relationtuple{ +// id: uuid.must(uuid.newv4()), +// networkid: nw.id, +// subjectid: dbsql.nullstring{string: uuid.must(uuid.newv4()).string(), valid: true}, +// object: "object", +// committime: time.now(), +// }, +// }} + +// for _, tc := range testcases { +// t.run("case="+tc.name, func(t *testing.t) { +// require.noerror(t, conn.create(tc.rt)) +// require.noerror(t, m.migrateuuidmappings(ctx)) + +// newrt := &sql.relationtuple{} +// require.noerror(t, conn.find(newrt, tc.rt.id)) + +// // check subject mapping +// if tc.rt.subjectid.valid { +// asserthasmapping(t, conn, tc.rt.subjectid.string, newrt.subjectid.string) +// } else { +// asserthasmapping(t, conn, tc.rt.subjectsetobject.string, newrt.subjectsetobject.string) +// // check that both "object" strings are mapped to same uuid. +// assert.equal(t, newrt.object, newrt.subjectsetobject.string) +// } +// asserthasmapping(t, conn, tc.rt.object, newrt.object) + +// t.run("idempotency", func(t *testing.t) { +// // check that running the migration again doesn't change anything +// require.noerror(t, m.migrateuuidmappings(ctx)) +// twicemigratedrt := &sql.relationtuple{} +// require.noerror(t, conn.find(twicemigratedrt, tc.rt.id)) +// assert.equal(t, newrt, twicemigratedrt) +// }) +// }) +// } +// }) +// } +// } + +// func testtouuidmappingmigrator_paginates(t *testing.t) { +// numtuples := 10 +// perpage := 2 + +// dsn := dbx.getsqlite(t, dbx.sqlitememory) +// ctx := context.background() +// r := driver.newtestregistry(t, dsn) +// m := migrations.newtouuidmappingmigrator(r) +// m.perpage = perpage +// p := r.persister() +// conn := p.connection(ctx) +// nw, err := r.determinenetwork(ctx) +// require.noerror(t, err) + +// // create a bunch of relation tuples +// for i := 0; i < numtuples; i++ { +// rt := &sql.relationtuple{ +// id: uuid.must(uuid.newv4()), +// networkid: nw.id, +// object: "object", +// subjectid: dbsql.nullstring{string: "subject", valid: true}, +// committime: time.now(), +// } +// require.noerror(t, conn.create(rt)) +// } + +// require.noerror(t, m.migrateuuidmappings(ctx)) +// assert.equal(t, numtuples/perpage, m.requestedpages) +// } + +// // asserthasmapping checks that there is a mapping from the given string (value) +// // to the given uuid (uid). +// func asserthasmapping(t *testing.t, conn *pop.connection, value, uid string) { +// t.helper() +// mapping := &sql.uuidmapping{} +// require.noerror(t, conn.find(mapping, uid), "could not find mapping for %q", uid) +// assert.notequal(t, value, uid, "value was not replaced by uuid") +// assert.equal(t, value, mapping.stringrepresentation) +// } diff --git a/internal/persistence/sql/persister.go b/internal/persistence/sql/persister.go index e73fe0027..e4c836c12 100644 --- a/internal/persistence/sql/persister.go +++ b/internal/persistence/sql/persister.go @@ -9,11 +9,7 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/ory/x/fsx" - "github.com/ory/x/logrusx" - "github.com/ory/x/networkx" "github.com/ory/x/popx" - "github.com/ory/x/tracing" "github.com/pkg/errors" "github.com/ory/keto/internal/driver/config" @@ -46,7 +42,7 @@ const ( var ( //go:embed migrations/sql/*.sql - migrations embed.FS + Migrations embed.FS _ persistence.Persister = &Persister{} ) @@ -66,10 +62,6 @@ func NewPersister(ctx context.Context, reg dependencies, nid uuid.UUID) (*Persis return p, nil } -func NewMigrationBox(c *pop.Connection, logger *logrusx.Logger, tracer *tracing.Tracer) (*popx.MigrationBox, error) { - return popx.NewMigrationBox(fsx.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, logger, tracer, 0)) -} - func (p *Persister) Connection(ctx context.Context) *pop.Connection { return popx.GetConnection(ctx, p.conn.WithContext(ctx)) } diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 98aee6a14..47b94713b 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -198,7 +198,7 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu } func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.InternalRelationTuple) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { for _, r := range rs { n, err := p.GetNamespaceByName(ctx, r.Namespace) if err != nil { @@ -223,7 +223,7 @@ func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtup } func (p *Persister) DeleteAllRelationTuples(ctx context.Context, query *relationtuple.RelationQuery) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { sqlQuery := p.QueryWithNetwork(ctx) err := p.whereQuery(ctx, sqlQuery, query) if err != nil { diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 95ed4b5e5..1f1b5030a 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -2,9 +2,13 @@ package sql import ( "context" + "fmt" "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" + + "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) type ( @@ -48,13 +52,95 @@ func (p *Persister) ToUUID(ctx context.Context, text string) (uuid.UUID, error) ) } -func (p *Persister) FromUUID(ctx context.Context, id uuid.UUID) (rep string, err error) { - p.d.Logger().Trace("looking up UUID") +func (p *Persister) FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { + p.d.Logger().Trace("looking up UUIDs") + + // We need to paginate on the ids, because we want to get the exact chunk of + // string representations for the given ids. + pagination, _ := internalPaginationFromOptions(opts...) + pageSize := pagination.PerPage + + // Build a map from UUID -> indices in the result. + idIdx := make(map[uuid.UUID][]int) + for i, id := range ids { + if ids, ok := idIdx[id]; ok { + idIdx[id] = append(ids, i) + } else { + idIdx[id] = []int{i} + } + } + + res = make([]string, len(ids)) - m := &UUIDMapping{} - if err := sqlcon.HandleError(p.Connection(ctx).Find(m, id)); err != nil { - return "", err + for i := 0; i < len(ids); i += pageSize { + end := i + pageSize + if end > len(ids) { + end = len(ids) + } + idsToLookup := ids[i:end] + mappings := &[]UUIDMapping{} + query := p.Connection(ctx).Where("id in (?)", idsToLookup) + if err := sqlcon.HandleError(query.All(mappings)); err != nil { + return []string{}, err + } + + // Write the representation to the correct index. + for _, m := range *mappings { + for _, idx := range idIdx[m.ID] { + res[idx] = m.StringRepresentation + } + } } - return m.StringRepresentation, nil + return +} + +func (p *Persister) replaceWithUUID(ctx context.Context, s *string) error { + if s == nil { + return nil + } + uuid, err := p.ToUUID(ctx, *s) + if err != nil { + return err + } + *s = uuid.String() + + return nil +} + +func (p *Persister) MapFieldsToUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + for _, s := range m.UUIDMappableFields() { + if err := p.replaceWithUUID(ctx, s); err != nil { + return err + } + } + return nil +} + +func (p *Persister) MapFieldsFromUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + ids := make([]uuid.UUID, len(m.UUIDMappableFields())) + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + id, err := uuid.FromString(*field) + if err != nil { + return err + } + ids[i] = id + } + reps, err := p.FromUUID(ctx, ids) + if err != nil { + return err + } + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + if reps[i] == "" { + return fmt.Errorf("failed to map %s", ids[i]) + } + *field = reps[i] + } + return nil } diff --git a/internal/relationtuple/definitions.go b/internal/relationtuple/definitions.go index 9fc1c63c9..a3db8f77a 100644 --- a/internal/relationtuple/definitions.go +++ b/internal/relationtuple/definitions.go @@ -24,6 +24,7 @@ import ( type ( ManagerProvider interface { RelationTupleManager() Manager + UUIDMappingManager() UUIDMappingManager } Manager interface { GetRelationTuples(ctx context.Context, query *RelationQuery, options ...x.PaginationOptionSetter) ([]*InternalRelationTuple, string, error) @@ -88,6 +89,9 @@ type Subject interface { // swagger:ignore ToProto() *rts.Subject + + // swagger:ignore + UUIDMappable } // swagger:ignore @@ -97,6 +101,18 @@ type InternalRelationTuple struct { Relation string `json:"relation"` Subject Subject `json:"subject"` } +type InternalRelationTuples []*InternalRelationTuple + +func (rt *InternalRelationTuple) UUIDMappableFields() []*string { + return append([]*string{&rt.Object}, rt.Subject.UUIDMappableFields()...) +} + +func (rtt InternalRelationTuples) UUIDMappableFields() (fields []*string) { + for _, rt := range rtt { + fields = append(fields, rt.UUIDMappableFields()...) + } + return fields +} // swagger:parameters getExpand type SubjectSet struct { @@ -141,6 +157,16 @@ func SubjectFromString(s string) (Subject, error) { return (&SubjectID{}).FromString(s) } +// swagger:ignore +func (s *SubjectID) UUIDMappableFields() []*string { + return []*string{&s.ID} +} + +// swagger:ignore +func (s *SubjectSet) UUIDMappableFields() []*string { + return []*string{&s.Object} +} + // swagger:ignore func SubjectFromProto(gs *rts.Subject) (Subject, error) { switch s := gs.GetRef().(type) { @@ -447,6 +473,17 @@ func (q *RelationQuery) FromProto(query TupleData) (*RelationQuery, error) { return q, nil } +func (q *RelationQuery) UUIDMappableFields() []*string { + res := []*string{&q.Object} + if q.SubjectID != nil { + res = append(res, q.SubjectID) + } + if q.SubjectSet != nil { + res = append(res, q.SubjectSet.UUIDMappableFields()...) + } + return res +} + const ( subjectIDKey = "subject_id" subjectSetNamespaceKey = "subject_set.namespace" @@ -684,3 +721,7 @@ func (t *ManagerWrapper) TransactRelationTuples(ctx context.Context, insert []*I func (t *ManagerWrapper) RelationTupleManager() Manager { return t } + +func (t *ManagerWrapper) UUIDMappingManager() UUIDMappingManager { + return t.Reg.UUIDMappingManager() +} diff --git a/internal/relationtuple/read_server.go b/internal/relationtuple/read_server.go index cc36f682e..f36117ab4 100644 --- a/internal/relationtuple/read_server.go +++ b/internal/relationtuple/read_server.go @@ -16,7 +16,10 @@ import ( "github.com/ory/keto/internal/x" ) -var _ rts.ReadServiceServer = (*handler)(nil) +var ( + _ rts.ReadServiceServer = (*handler)(nil) + _ = (*getRelationsParams)(nil) +) func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationTuplesRequest) (*rts.ListRelationTuplesResponse, error) { if req.Query == nil { @@ -28,6 +31,7 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT return nil, err } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q) rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(ctx, q, x.WithSize(int(req.PageSize)), x.WithToken(req.PageToken), @@ -35,6 +39,7 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT if err != nil { return nil, err } + h.d.UUIDMappingManager().MapFieldsFromUUID(ctx, InternalRelationTuples(rels)) resp := &rts.ListRelationTuplesResponse{ RelationTuples: make([]*rts.RelationTuple, len(rels)), @@ -48,7 +53,6 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT } // swagger:parameters getRelationTuples -// nolint:deadcode,unused type getRelationsParams struct { // Namespace of the Relation Tuple // @@ -139,11 +143,13 @@ func (h *handler) getRelations(w http.ResponseWriter, r *http.Request, _ httprou paginationOpts = append(paginationOpts, x.WithSize(int(s))) } + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query) rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(r.Context(), query, paginationOpts...) if err != nil { h.d.Writer().WriteError(w, r, err) return } + h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), InternalRelationTuples(rels)) resp := &GetResponse{ RelationTuples: rels, diff --git a/internal/relationtuple/read_server_test.go b/internal/relationtuple/read_server_test.go index e772ea019..fb474fc1a 100644 --- a/internal/relationtuple/read_server_test.go +++ b/internal/relationtuple/read_server_test.go @@ -103,7 +103,7 @@ func TestReadHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{ "object": {obj}, @@ -114,7 +114,7 @@ func TestReadHandlers(t *testing.T) { var respMsg relationtuple.GetResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&respMsg)) assert.Equal(t, 1, len(respMsg.RelationTuples)) - assert.Contains(t, rts, respMsg.RelationTuples[0]) + assert.Containsf(t, rts, respMsg.RelationTuples[0], "expected to find %q in %q", respMsg.RelationTuples[0].String(), rts) assert.Equal(t, "", respMsg.NextPageToken) }) @@ -145,7 +145,7 @@ func TestReadHandlers(t *testing.T) { Subject: &relationtuple.SubjectID{ID: "s2"}, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) var firstResp relationtuple.GetResponse t.Run("case=first page", func(t *testing.T) { diff --git a/internal/relationtuple/swagger_definitions.go b/internal/relationtuple/swagger_definitions.go index 7c32e2a38..5e215a002 100644 --- a/internal/relationtuple/swagger_definitions.go +++ b/internal/relationtuple/swagger_definitions.go @@ -1,7 +1,11 @@ package relationtuple +var ( + _ = (*relationTupleWithRequired)(nil) + _ = (*patchPayload)(nil) +) + // swagger:model InternalRelationTuple -// nolint:deadcode,unused type relationTupleWithRequired struct { // Namespace of the Relation Tuple // @@ -31,7 +35,6 @@ type relationTupleWithRequired struct { // The patch request payload // // swagger:parameters patchRelationTuples -// nolint:deadcode,unused type patchPayload struct { // in:body Payload []*PatchDelta @@ -41,3 +44,15 @@ type PatchDelta struct { Action patchAction `json:"action"` RelationTuple *InternalRelationTuple `json:"relation_tuple"` } + +// swagger:ignore +type PatchDeltas []*PatchDelta + +func (p PatchDeltas) UUIDMappableFields() (res []*string) { + for _, pd := range p { + if pd.RelationTuple != nil { + res = append(res, pd.RelationTuple.UUIDMappableFields()...) + } + } + return +} diff --git a/internal/relationtuple/test_helper.go b/internal/relationtuple/test_helper.go new file mode 100644 index 000000000..b51ed2cfd --- /dev/null +++ b/internal/relationtuple/test_helper.go @@ -0,0 +1,18 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// MapAndWriteTuples is a test helper to write relation tuples to the database +// while mapping all strings to UUIDs. +func MapAndWriteTuples(t *testing.T, m ManagerProvider, tuples ...*InternalRelationTuple) { + t.Helper() + + m.UUIDMappingManager().MapFieldsToUUID(context.Background(), InternalRelationTuples(tuples)) + require.NoError(t, m.RelationTupleManager().WriteRelationTuples(context.Background(), tuples...)) + m.UUIDMappingManager().MapFieldsFromUUID(context.Background(), InternalRelationTuples(tuples)) +} diff --git a/internal/relationtuple/transact_server.go b/internal/relationtuple/transact_server.go index 5bdd02e72..a42769b8c 100644 --- a/internal/relationtuple/transact_server.go +++ b/internal/relationtuple/transact_server.go @@ -12,7 +12,11 @@ import ( "github.com/pkg/errors" ) -var _ rts.WriteServiceServer = (*handler)(nil) +var ( + _ rts.WriteServiceServer = (*handler)(nil) + _ = (*bodyRelationTuple)(nil) + _ = (*queryRelationTuple)(nil) +) func protoTuplesWithAction(deltas []*rts.RelationTupleDelta, action rts.RelationTupleDelta_Action) (filtered []*InternalRelationTuple, err error) { for _, d := range deltas { @@ -38,6 +42,7 @@ func (h *handler) TransactRelationTuples(ctx context.Context, req *rts.TransactR return nil, err } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, InternalRelationTuples(append(insertTuples, deleteTuples...))) err = h.d.RelationTupleManager().TransactRelationTuples(ctx, insertTuples, deleteTuples) if err != nil { return nil, err @@ -62,6 +67,7 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat return nil, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error())) } + h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q) if err := h.d.RelationTupleManager().DeleteAllRelationTuples(ctx, q); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithError(err.Error())) } @@ -72,7 +78,6 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat // The basic ACL relation tuple // // swagger:parameters postCheck createRelationTuple -// nolint:deadcode,unused type bodyRelationTuple struct { // in: body Payload RelationQuery @@ -81,7 +86,6 @@ type bodyRelationTuple struct { // The basic ACL relation tuple // // swagger:parameters getCheck deleteRelationTuples -// nolint:deadcode,unused type queryRelationTuple struct { // Namespace of the Relation Tuple // @@ -151,11 +155,13 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr h.d.Logger().WithFields(rel.ToLoggerFields()).Debug("creating relation tuple") + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &rel) if err := h.d.RelationTupleManager().WriteRelationTuples(r.Context(), &rel); err != nil { h.d.Logger().WithError(err).WithFields(rel.ToLoggerFields()).Errorf("got an error while creating the relation tuple") h.d.Writer().WriteError(w, r, err) return } + h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), &rel) q, err := rel.ToURLQuery() if err != nil { @@ -198,6 +204,7 @@ func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ http } l.Debug("deleting relation tuples") + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query) if err := h.d.RelationTupleManager().DeleteAllRelationTuples(r.Context(), query); err != nil { l.WithError(err).Errorf("got an error while deleting relation tuples") h.d.Writer().WriteError(w, r, herodot.ErrInternalServerError.WithError(err.Error())) @@ -254,7 +261,13 @@ func (h *handler) patchRelations(w http.ResponseWriter, r *http.Request, _ httpr } } - if err := h.d.RelationTupleManager().TransactRelationTuples(r.Context(), internalTuplesWithAction(deltas, ActionInsert), internalTuplesWithAction(deltas, ActionDelete)); err != nil { + h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), PatchDeltas(deltas)) + if err := h.d.RelationTupleManager(). + TransactRelationTuples( + r.Context(), + internalTuplesWithAction(deltas, ActionInsert), + internalTuplesWithAction(deltas, ActionDelete)); err != nil { + h.d.Writer().WriteError(w, r, err) return } diff --git a/internal/relationtuple/transact_server_test.go b/internal/relationtuple/transact_server_test.go index 921e509c6..845379561 100644 --- a/internal/relationtuple/transact_server_test.go +++ b/internal/relationtuple/transact_server_test.go @@ -80,10 +80,14 @@ func TestWriteHandlers(t *testing.T) { assert.JSONEq(t, string(payload), string(body)) t.Run("check=is contained in the manager", func(t *testing.T) { + query := rt.ToQuery() + reg.UUIDMappingManager().MapFieldsToUUID(context.Background(), query) // set a size > 1 just to make sure it gets all - actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) + actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) - assert.Equal(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), query) + assert.Equalf(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs, "want: %s\ngot: %s", rt.String(), actualRTs[0].String()) }) t.Run("check=is gettable with the returned URL", func(t *testing.T) { @@ -136,6 +140,7 @@ func TestWriteHandlers(t *testing.T) { Namespace: nspace.Name, }) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actual)) assert.Equal(t, "", next) assert.Len(t, actual, 2) for _, rt := range rts { @@ -154,7 +159,7 @@ func TestWriteHandlers(t *testing.T) { Relation: "deleted rel", Subject: &relationtuple.SubjectID{ID: "deleted subj"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) @@ -167,6 +172,7 @@ func TestWriteHandlers(t *testing.T) { // set a size > 1 just to make sure it gets all actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) @@ -188,7 +194,7 @@ func TestWriteHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) q := url.Values{ "namespace": {nspace.Name}, @@ -206,6 +212,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) }) @@ -234,7 +241,7 @@ func TestWriteHandlers(t *testing.T) { }, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), deltas[1].RelationTuple)) + relationtuple.MapAndWriteTuples(t, reg, deltas[1].RelationTuple) body, err := json.Marshal(deltas) require.NoError(t, err) @@ -249,6 +256,8 @@ func TestWriteHandlers(t *testing.T) { Relation: t.Name(), }) require.NoError(t, err) + err = reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) + require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) @@ -316,6 +325,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), &relationtuple.RelationQuery{ Namespace: nspace.Name, }) + reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) diff --git a/internal/relationtuple/uuid_mapping.go b/internal/relationtuple/uuid_mapping.go new file mode 100644 index 000000000..54aea61c2 --- /dev/null +++ b/internal/relationtuple/uuid_mapping.go @@ -0,0 +1,102 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/x" +) + +type ( + // UUIDMappable is an interface for objects that have fields that can be + // mapped to and from UUIDs. + UUIDMappable interface{ UUIDMappableFields() []*string } + + UUIDMappingManager interface { + // ToUUID returns the mapped UUID for the given string representation. + // If the string representation is not mapped, a new UUID will be + // created automatically. + ToUUID(ctx context.Context, representation string) (uuid.UUID, error) + + // MapFields maps all fields of the given object to UUIDs. + MapFieldsToUUID(ctx context.Context, m UUIDMappable) error + + // MapFieldsFromUUID maps all fields of the given object from UUIDs to + // their string value. + MapFieldsFromUUID(ctx context.Context, m UUIDMappable) error + + // FromUUID returns the text representations for the given UUIDs, such + // that ids[i] is mapped to reps[i]. + // + // Of the pagination options, only the page size is considered and used + // as a batch size. + FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (reps []string, err error) + } +) + +func UUIDMappingManagerTest(t *testing.T, m UUIDMappingManager) { + ctx := context.Background() + + t.Run("case=ToUUID_FromUUID", func(t *testing.T) { + rep1 := "foo" + id, err := m.ToUUID(ctx, rep1) + require.NoError(t, err) + + rep2, err := m.FromUUID(ctx, []uuid.UUID{id}) + assert.NoError(t, err) + assert.Equal(t, rep1, rep2[0]) + }) + + t.Run("case=Idempotent_ToUUID", func(t *testing.T) { + id1, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + id2, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + assert.Equal(t, id1, id2) + }) + + // Test that the batch mapping preserves ordering, i.e. id[i] is mapped to + // rep[i]. + t.Run("case=Batch_ToUUID_Paginates", func(t *testing.T) { + expected := []string{"foo", "foo", "bar", "baz"} + ids := make([]uuid.UUID, len(expected)) + for i, s := range expected { + var err error + ids[i], err = m.ToUUID(ctx, s) + assert.NoError(t, err) + } + + actual, err := m.FromUUID(ctx, ids, x.WithSize(1)) + assert.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("case=IdempotentMapFieldsToAndFromUUIDs", func(t *testing.T) { + tc := []struct { + name string + obj UUIDMappable + copy UUIDMappable + }{ + { + name: "RelationTuple", + obj: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + copy: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + }, { + name: "SubjectID", + obj: &SubjectID{ID: "sub"}, + copy: &SubjectID{ID: "sub"}, + }, + } + for _, tt := range tc { + t.Run("type="+tt.name, func(t *testing.T) { + assert.NoError(t, m.MapFieldsToUUID(ctx, tt.obj)) + assert.NoError(t, m.MapFieldsFromUUID(ctx, tt.obj)) + assert.Equal(t, tt.copy, tt.obj) + }) + } + }) +} diff --git a/internal/uuidmapping/definitions.go b/internal/uuidmapping/definitions.go deleted file mode 100644 index 9ddeac751..000000000 --- a/internal/uuidmapping/definitions.go +++ /dev/null @@ -1,48 +0,0 @@ -package uuidmapping - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/gofrs/uuid" -) - -type ( - ManagerProvider interface { - UUIDMappingManager() Manager - } - Manager interface { - // ToUUID returns the mapped UUID for the given string representation. - // If the string representation is not mapped, a new UUID will be - // created automatically. - ToUUID(ctx context.Context, representation string) (uuid.UUID, error) - - // FromUUID returns the text representation for the given UUID. - FromUUID(ctx context.Context, id uuid.UUID) (text string, err error) - } -) - -func ManagerTest(t *testing.T, m Manager) { - ctx := context.Background() - - t.Run("case=ToUUID_FromUUID", func(t *testing.T) { - rep1 := "foo" - id, err := m.ToUUID(ctx, rep1) - require.NoError(t, err) - - rep2, err := m.FromUUID(ctx, id) - assert.NoError(t, err) - assert.Equal(t, rep1, rep2) - }) - - t.Run("case=FromUUID", func(t *testing.T) { - id1, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - id2, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - assert.Equal(t, id1, id2) - }) -}