diff --git a/cmd/expand/root_test.go b/cmd/expand/root_test.go index 4e87ab6a3..0be0a22bc 100644 --- a/cmd/expand/root_test.go +++ b/cmd/expand/root_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/ory/x/cmdx" - "github.com/stretchr/testify/assert" "github.com/ory/keto/cmd/client" @@ -18,12 +17,16 @@ func TestExpandCommand(t *testing.T) { t.Run("case=unknown tuple", func(t *testing.T) { t.Run("format=JSON", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatJSON)) assert.Equal(t, "null\n", stdOut) }) t.Run("format=default", func(t *testing.T) { - stdOut := ts.Cmd.ExecNoErr(t, "access", nspace.Name, "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) + stdOut := ts.Cmd.ExecNoErr(t, + "access", nspace.Name, + "object", "--"+cmdx.FlagFormat, string(cmdx.FormatDefault)) assert.Contains(t, stdOut, "empty tree") }) }) diff --git a/cmd/migrate/migrate_test.go b/cmd/migrate/migrate_test.go index d3bde2b16..849ce1026 100644 --- a/cmd/migrate/migrate_test.go +++ b/cmd/migrate/migrate_test.go @@ -105,7 +105,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(func() { // migrate all down - t.Log(cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) + t.Logf("cleanup:\n%s\n", cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) }) parts := strings.Split(stdOut, "Are you sure that you want to apply this migration?") @@ -120,7 +120,7 @@ func TestMigrate(t *testing.T) { t.Cleanup(func() { // migrate all down - t.Log(cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) + t.Logf("cleanup:\n%s\n", cmd.ExecNoErr(t, "down", "0", "--"+FlagYes)) }) parts := strings.Split(out, "Applying migrations...") diff --git a/cmd/migrate/migrate_uuid_mapping.go b/cmd/migrate/migrate_uuid_mapping.go deleted file mode 100644 index 8c15c60c4..000000000 --- a/cmd/migrate/migrate_uuid_mapping.go +++ /dev/null @@ -1,49 +0,0 @@ -package migrate - -import ( - "fmt" - - "github.com/ory/x/cmdx" - "github.com/ory/x/flagx" - "github.com/pkg/errors" - "github.com/spf13/cobra" - - "github.com/ory/keto/internal/driver" - "github.com/ory/keto/internal/persistence" - "github.com/ory/keto/internal/persistence/sql/migrations" -) - -func newMigrateUUIDMappingCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "uuid-mapping", - Short: "Migrate the non-UUID subject and object names to UUIDs.", - Long: `Migrate the non-UUID subject and object names to UUIDs. -This step only has to be executed once. -Please ensure that you have a backup in case something goes wrong!"`, - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, _ []string) error { - reg, err := driver.NewDefaultRegistry(cmd.Context(), cmd.Flags(), false) - if errors.Is(err, persistence.ErrNetworkMigrationsMissing) { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), - "Migrations were not applied yet, please apply them first using `keto migrate up`.") - return cmdx.FailSilently(cmd) - } else if err != nil { - return err - } - - if !flagx.MustGetBool(cmd, FlagYes) && - !cmdx.AskForConfirmation( - "Are you sure you want to migrate the subject and object names to UUIDs?", - cmd.InOrStdin(), cmd.OutOrStdout()) { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "OK, aborting.") - return nil - } - - migrator := migrations.NewToUUIDMappingMigrator(reg) - return migrator.MigrateUUIDMappings(cmd.Context()) - }, - } - RegisterYesFlag(cmd.Flags()) - - return cmd -} diff --git a/cmd/migrate/root.go b/cmd/migrate/root.go index ddabf0f3d..f29f9d3f6 100644 --- a/cmd/migrate/root.go +++ b/cmd/migrate/root.go @@ -17,7 +17,6 @@ func newMigrateCmd(opts []ketoctx.Option) *cobra.Command { newStatusCmd(opts), newUpCmd(opts), newDownCmd(opts), - newMigrateUUIDMappingCmd(), ) return cmd } diff --git a/contrib/docs-code-samples/list-api-display-objects/99-cleanup/index.js b/contrib/docs-code-samples/list-api-display-objects/99-cleanup/index.js index 0a50c6d61..41a63b793 100644 --- a/contrib/docs-code-samples/list-api-display-objects/99-cleanup/index.js +++ b/contrib/docs-code-samples/list-api-display-objects/99-cleanup/index.js @@ -30,7 +30,6 @@ readClient.listRelationTuples(readRequest, (err, resp) => { writeClient.transactRelationTuples(writeRequest, (err) => { if (err) { console.log('Unexpected err', err) - return 1 } }) }) diff --git a/internal/check/handler.go b/internal/check/handler.go index 66d509cd1..0485ff687 100644 --- a/internal/check/handler.go +++ b/internal/check/handler.go @@ -5,18 +5,14 @@ import ( "encoding/json" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" "github.com/pkg/errors" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( @@ -30,7 +26,10 @@ type ( } ) -var _ rts.CheckServiceServer = (*Handler)(nil) +var ( + _ rts.CheckServiceServer = (*Handler)(nil) + _ *getCheckRequest = nil +) func NewHandler(d handlerDependencies) *Handler { return &Handler{d: d} @@ -64,7 +63,6 @@ type RESTResponse struct { } // swagger:parameters getCheck postCheck -// nolint:deadcode,unused type getCheckRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -105,6 +103,10 @@ func (h *Handler) getCheck(w http.ResponseWriter, r *http.Request, _ httprouter. return } + if err := h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), tuple); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -151,6 +153,10 @@ func (h *Handler) postCheck(w http.ResponseWriter, r *http.Request, _ httprouter return } + if err := h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &tuple); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } allowed, err := h.d.PermissionEngine().SubjectIsAllowed(r.Context(), &tuple, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) @@ -171,6 +177,9 @@ func (h *Handler) Check(ctx context.Context, req *rts.CheckRequest) (*rts.CheckR return nil, err } + if err := h.d.PermissionEngine().d.UUIDMappingManager().MapFieldsToUUID(ctx, tuple); err != nil { + return nil, err + } allowed, err := h.d.PermissionEngine().SubjectIsAllowed(ctx, tuple, int(req.MaxDepth)) // TODO add content change handling if err != nil { diff --git a/internal/check/handler_test.go b/internal/check/handler_test.go index 5f81bad24..19548dbf9 100644 --- a/internal/check/handler_test.go +++ b/internal/check/handler_test.go @@ -23,6 +23,8 @@ import ( ) func assertAllowed(t *testing.T, resp *http.Response) { + t.Helper() + body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -97,7 +99,7 @@ func TestRESTHandler(t *testing.T) { Relation: "r", Subject: &relationtuple.SubjectID{ID: "s"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) diff --git a/internal/driver/config/namespace_memory.go b/internal/driver/config/namespace_memory.go index 7f451c35f..2bd1e184b 100644 --- a/internal/driver/config/namespace_memory.go +++ b/internal/driver/config/namespace_memory.go @@ -34,7 +34,7 @@ func (s *memoryNamespaceManager) GetNamespaceByName(_ context.Context, name stri } } - return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %s.", name)) + return nil, errors.WithStack(herodot.ErrNotFound.WithReasonf("Unknown namespace with name %q.", name)) } func (s *memoryNamespaceManager) GetNamespaceByConfigID(_ context.Context, id int32) (*namespace.Namespace, error) { diff --git a/internal/driver/pop_connection.go b/internal/driver/pop_connection.go index 578379742..da02b9d9b 100644 --- a/internal/driver/pop_connection.go +++ b/internal/driver/pop_connection.go @@ -63,6 +63,12 @@ func (r *RegistryDefault) PopConnectionWithOpts(ctx context.Context, popOpts ... return nil, errors.WithStack(err) } + // Close this connection when the context is closed. + go func() { + <-ctx.Done() + conn.Close() + }() + return conn.WithContext(ctx), nil } diff --git a/internal/driver/registry.go b/internal/driver/registry.go index 2d66d655e..d7aa61f22 100644 --- a/internal/driver/registry.go +++ b/internal/driver/registry.go @@ -4,21 +4,15 @@ import ( "context" "net/http" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/healthx" "github.com/ory/x/otelx" prometheus "github.com/ory/x/prometheusx" - - "github.com/gobuffalo/pop/v6" - - "github.com/ory/keto/internal/driver/config" - "github.com/ory/keto/internal/uuidmapping" - "github.com/spf13/cobra" - "google.golang.org/grpc" - "github.com/ory/x/healthx" - "github.com/ory/keto/internal/check" + "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/relationtuple" @@ -34,7 +28,6 @@ type ( x.WriterProvider relationtuple.ManagerProvider - uuidmapping.ManagerProvider expand.EngineProvider check.EngineProvider persistence.Migrator diff --git a/internal/driver/registry_default.go b/internal/driver/registry_default.go index dc2712a47..0178713b8 100644 --- a/internal/driver/registry_default.go +++ b/internal/driver/registry_default.go @@ -8,6 +8,7 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/ory/herodot" "github.com/ory/x/dbal" + "github.com/ory/x/fsx" "github.com/ory/x/healthx" "github.com/ory/x/logrusx" "github.com/ory/x/metricsx" @@ -24,8 +25,9 @@ import ( "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/persistence" "github.com/ory/keto/internal/persistence/sql" + _ "github.com/ory/keto/internal/persistence/sql/migrations" + "github.com/ory/keto/internal/persistence/sql/migrations/uuidmapping" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x" "github.com/ory/keto/ketoctx" rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" @@ -152,7 +154,7 @@ func (r *RegistryDefault) RelationTupleManager() relationtuple.Manager { return r.p } -func (r *RegistryDefault) UUIDMappingManager() uuidmapping.Manager { +func (r *RegistryDefault) UUIDMappingManager() relationtuple.UUIDMappingManager { if r.p == nil { panic("no relation tuple manager, but expected to have one") } @@ -186,10 +188,16 @@ func (r *RegistryDefault) MigrationBox(ctx context.Context) (*popx.MigrationBox, if err != nil { return nil, err } - mb, err := sql.NewMigrationBox(c, r.Logger(), r.Tracer(ctx)) + + mb, err := popx.NewMigrationBox( + fsx.Merge(sql.Migrations, networkx.Migrations), + popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0), + popx.WithGoMigrations(uuidmapping.Migrations), + ) if err != nil { return nil, err } + r.mb = mb } return r.mb, nil diff --git a/internal/e2e/cases_test.go b/internal/e2e/cases_test.go index c71d9a684..01e357930 100644 --- a/internal/e2e/cases_test.go +++ b/internal/e2e/cases_test.go @@ -5,17 +5,14 @@ import ( "strconv" "testing" - "github.com/stretchr/testify/require" - "github.com/ory/herodot" - - "github.com/ory/keto/internal/x" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/keto/internal/expand" "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) func(*testing.T) { @@ -188,7 +185,7 @@ func runCases(c client, addNamespace func(*testing.T, ...*namespace.Namespace)) Relation: "rel", } resp := c.queryTuple(t, q) - require.Equal(t, resp.RelationTuples, rts) + require.ElementsMatch(t, resp.RelationTuples, rts) c.deleteAllTuples(t, q) resp = c.queryTuple(t, q) diff --git a/internal/expand/handler.go b/internal/expand/handler.go index abf32803c..42dfc4d04 100644 --- a/internal/expand/handler.go +++ b/internal/expand/handler.go @@ -4,22 +4,19 @@ import ( "context" "net/http" + "github.com/julienschmidt/httprouter" "github.com/ory/herodot" - - rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" - "google.golang.org/grpc" "github.com/ory/keto/internal/relationtuple" - - "github.com/julienschmidt/httprouter" - "github.com/ory/keto/internal/x" + rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2" ) type ( handlerDependencies interface { EngineProvider + relationtuple.ManagerProvider x.LoggerProvider x.WriterProvider } @@ -28,7 +25,10 @@ type ( } ) -var _ rts.ExpandServiceServer = (*handler)(nil) +var ( + _ rts.ExpandServiceServer = (*handler)(nil) + _ *getExpandRequest = nil +) const RouteBase = "/relation-tuples/expand" @@ -49,7 +49,6 @@ func (h *handler) RegisterReadGRPC(s *grpc.Server) { func (h *handler) RegisterWriteGRPC(s *grpc.Server) {} // swagger:parameters getExpand -// nolint:deadcode,unused type getExpandRequest struct { // in:query MaxDepth int `json:"max-depth"` @@ -81,24 +80,41 @@ func (h *handler) getExpand(w http.ResponseWriter, r *http.Request, _ httprouter return } - res, err := h.d.ExpandEngine().BuildTree(r.Context(), (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()), maxDepth) + subject := (&relationtuple.SubjectSet{}).FromURLQuery(r.URL.Query()) + + if err := h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), subject); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } + res, err := h.d.ExpandEngine().BuildTree(r.Context(), subject, maxDepth) if err != nil { h.d.Writer().WriteError(w, r, err) return } + if err := h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), res); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } h.d.Writer().Write(w, r, res) } func (h *handler) Expand(ctx context.Context, req *rts.ExpandRequest) (*rts.ExpandResponse, error) { - sub, err := relationtuple.SubjectFromProto(req.Subject) + subject, err := relationtuple.SubjectFromProto(req.Subject) if err != nil { return nil, err } - tree, err := h.d.ExpandEngine().BuildTree(ctx, sub, int(req.MaxDepth)) + + if err := h.d.UUIDMappingManager().MapFieldsToUUID(ctx, subject); err != nil { + return nil, err + } + res, err := h.d.ExpandEngine().BuildTree(ctx, subject, int(req.MaxDepth)) if err != nil { return nil, err } + if err := h.d.UUIDMappingManager().MapFieldsFromUUID(ctx, res); err != nil { + return nil, err + } - return &rts.ExpandResponse{Tree: tree.ToProto()}, nil + return &rts.ExpandResponse{Tree: res.ToProto()}, nil } diff --git a/internal/expand/handler_test.go b/internal/expand/handler_test.go index 95ae81875..6ba553bf3 100644 --- a/internal/expand/handler_test.go +++ b/internal/expand/handler_test.go @@ -79,20 +79,20 @@ func TestRESTHandler(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), []*relationtuple.InternalRelationTuple{ - { + relationtuple.MapAndWriteTuples(t, reg, + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[0].Subject, }, - { + &relationtuple.InternalRelationTuple{ Namespace: nspace.Name, Object: rootSub.Object, Relation: rootSub.Relation, Subject: expectedTree.Children[1].Subject, }, - }...)) + ) qs := rootSub.ToURLQuery() qs.Set("max-depth", "2") @@ -103,6 +103,44 @@ func TestRESTHandler(t *testing.T) { actualTree := expand.Tree{} require.NoError(t, json.NewDecoder(resp.Body).Decode(&actualTree)) - assert.Equal(t, expectedTree, &actualTree) + assertEqualTrees(t, expectedTree, &actualTree) }) } + +func assertEqualTrees(t *testing.T, expected, actual *expand.Tree) { + t.Helper() + assert.Truef(t, treesAreEqual(t, expected, actual), + "expected:\n%s\n\nactual:\n%s", expected.String(), actual.String()) +} + +func treesAreEqual(t *testing.T, expected, actual *expand.Tree) bool { + if expected == nil || actual == nil { + return expected == actual + } + + if expected.Type != actual.Type { + t.Logf("expected type %q, actual type %q", expected.Type, actual.Type) + return false + } + if expected.Subject.String() != actual.Subject.String() { + t.Logf("expected subject: %q, actual subject: %q", expected.Subject.String(), actual.Subject.String()) + return false + } + if len(expected.Children) != len(actual.Children) { + t.Logf("expected len(children)=%d, actual len(children)=%d", len(expected.Children), len(actual.Children)) + return false + } + + // For children, we check for equality disregarding the order +outer: + for _, expectedChild := range expected.Children { + for _, actualChild := range actual.Children { + if treesAreEqual(t, expectedChild, actualChild) { + continue outer + } + } + t.Logf("expected child:\n%s\n\nactual child:\n%s", expectedChild.String(), actual.String()) + return false + } + return true +} diff --git a/internal/expand/tree.go b/internal/expand/tree.go index 4340ce3e3..3f2c8ec27 100644 --- a/internal/expand/tree.go +++ b/internal/expand/tree.go @@ -217,7 +217,7 @@ func TreeFromProto(t *rts.SubjectTree) (*Tree, error) { func (t *Tree) String() string { if t == nil { - return "" + return "(nil)" } sub := t.Subject.String() @@ -233,3 +233,16 @@ func (t *Tree) String() string { return fmt.Sprintf("∪ %s\n├─ %s", sub, strings.Join(children, "\n├─ ")) } + +func (t *Tree) UUIDMappableFields() (res []*string) { + if t == nil { + return + } + if t.Subject != nil { + res = append(res, t.Subject.UUIDMappableFields()...) + } + for _, c := range t.Children { + res = append(res, c.UUIDMappableFields()...) + } + return +} diff --git a/internal/persistence/definitions.go b/internal/persistence/definitions.go index 46e87d937..0b1d267f3 100644 --- a/internal/persistence/definitions.go +++ b/internal/persistence/definitions.go @@ -9,13 +9,12 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" ) type ( Persister interface { relationtuple.Manager - uuidmapping.Manager + relationtuple.UUIDMappingManager Connection(ctx context.Context) *pop.Connection } diff --git a/internal/persistence/sql/full_test.go b/internal/persistence/sql/full_test.go index 595f9d051..ff772c9d9 100644 --- a/internal/persistence/sql/full_test.go +++ b/internal/persistence/sql/full_test.go @@ -14,7 +14,6 @@ import ( "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/persistence/sql" "github.com/ory/keto/internal/relationtuple" - "github.com/ory/keto/internal/uuidmapping" "github.com/ory/keto/internal/x/dbx" ) @@ -69,9 +68,9 @@ func TestPersister(t *testing.T) { relationtuple.IsolationTest(t, p0, p1, addNamespace(r, nspaces)) }) - t.Run("uuidmapping.ManagerTest", func(t *testing.T) { + t.Run("relationtuple.UUIDMappingManagerTest", func(t *testing.T) { p, _, _ := setup(t, dsn) - uuidmapping.ManagerTest(t, p) + relationtuple.UUIDMappingManagerTest(t, p) }) }) } diff --git a/internal/persistence/sql/migrations/migratest/migration_test.go b/internal/persistence/sql/migrations/migratest/migration_test.go index 8c7e7a4db..9f7440040 100644 --- a/internal/persistence/sql/migrations/migratest/migration_test.go +++ b/internal/persistence/sql/migrations/migratest/migration_test.go @@ -2,15 +2,15 @@ package migratest import ( "context" + "fmt" + "io/fs" "os" + "regexp" + "strings" "testing" "time" "github.com/gobuffalo/pop/v6" - - "github.com/ory/keto/internal/driver/config" - "github.com/ory/keto/internal/namespace" - "github.com/gofrs/uuid" "github.com/ory/x/fsx" "github.com/ory/x/logrusx" @@ -22,11 +22,87 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/keto/internal/driver" + "github.com/ory/keto/internal/driver/config" + "github.com/ory/keto/internal/namespace" "github.com/ory/keto/internal/persistence/sql" + "github.com/ory/keto/internal/persistence/sql/migrations/uuidmapping" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x/dbx" ) +// TODO(hperl): move to ory/x +func withTestdata(t *testing.T, testdata fs.FS) func(*popx.MigrationBox) *popx.MigrationBox { + return func(m *popx.MigrationBox) *popx.MigrationBox { + err := fs.WalkDir(testdata, ".", func(path string, info fs.DirEntry, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + if m, _ := regexp.MatchString(`\d+_testdata.sql`, info.Name()); !m { + return nil + } + version := strings.TrimSuffix(info.Name(), "_testdata.sql") + m.Migrations["up"] = append(m.Migrations["up"], popx.Migration{ + Version: version + "9", // run testdata after version + Path: path, + Name: "testdata", + DBType: "all", + Direction: "up", + Type: "sql", + Runner: func(m popx.Migration, _ *pop.Connection, tx *pop.Tx) error { + b, err := fs.ReadFile(testdata, m.Path) + if err != nil { + return err + } + _, err = tx.Exec(string(b)) + return err + }, + }) + m.Migrations["down"] = append(m.Migrations["down"], popx.Migration{ + Version: version + "9", // run testdata after version + Path: path, + Name: "testdata", + DBType: "all", + Direction: "down", + Type: "sql", + Runner: func(m popx.Migration, _ *pop.Connection, tx *pop.Tx) error { + return nil + }, + }) + + return nil + }) + if err != nil { + t.Fatalf("could not add all testdata migrations: %v", err) + } + + return m + } +} + +func hasDownMigrationWithVersion(mb *popx.MigrationBox, version string) bool { + for _, down := range mb.Migrations["down"] { + if version == down.Version { + return true + } + } + return false +} + +// check that every "up" migration has a corresponding "down" migration in +// reverse order. +// TODO(hperl): move to ory/x +func check(mb *popx.MigrationBox) error { + for _, up := range mb.Migrations["up"] { + if !hasDownMigrationWithVersion(mb, up.Version) { + return fmt.Errorf("migration %s has no corresponding down migration", up.Version) + } + } + return nil +} + func TestMigrations(t *testing.T) { const debugOnDisk = false @@ -39,10 +115,11 @@ func TestMigrations(t *testing.T) { var c *pop.Connection var err error + c, err = pop.NewConnection(&pop.ConnectionDetails{URL: db.Conn}) + require.NoError(t, err) + require.NoError(t, c.Open()) + t.Cleanup(func() { c.Close() }) for i := 0; i < 120; i++ { - c, err = pop.NewConnection(&pop.ConnectionDetails{URL: db.Conn}) - require.NoError(t, err) - require.NoError(t, c.Open()) if err := c.Store.(interface{ Ping() error }).Ping(); err == nil { break } @@ -50,33 +127,63 @@ func TestMigrations(t *testing.T) { } require.NoError(t, c.Store.(interface{ Ping() error }).Ping()) - tm := popx.NewTestMigrator(t, c, fsx.Merge(networkx.Migrations, os.DirFS("../sql")), os.DirFS("./testdata"), l) + tm, err := popx.NewMigrationBox( + fsx.Merge(sql.Migrations, networkx.Migrations), + popx.NewMigrator(c, l, nil, 1*time.Minute), + popx.WithGoMigrations(uuidmapping.Migrations), + withTestdata(t, os.DirFS("./testdata")), + ) + require.NoError(t, err) + if err := check(tm); err != nil { + t.Log(err) + t.Log("up migrations:") + for _, m := range tm.Migrations["up"] { + t.Logf("\t%s\t%s\t%s\n", m.Name, m.Version, m.DBType) + } + t.Log("down migrations:") + for _, m := range tm.Migrations["down"] { + t.Logf("\t%s\t%s\t%s\n", m.Name, m.Version, m.DBType) + } + t.FailNow() + } + // cleanup first require.NoError(t, tm.Down(ctx, -1)) + t.Log("before migration") + logMigrationStatus(t, tm) + t.Run("suite=up", func(t *testing.T) { - require.NoError(t, tm.Up(ctx)) + if err := tm.Up(ctx); err != nil { + t.Log("migrations failed:", err) + t.Fail() + } + logMigrationStatus(t, tm) }) t.Run("suite=fixtures", func(t *testing.T) { + reg := driver.NewTestRegistry(t, db) + require.NoError(t, + reg.Config(ctx).Set(config.KeyNamespaces, []*namespace.Namespace{ + {ID: 1, Name: "foo"}})) + p, err := sql.NewPersister(ctx, reg, uuid.Must(uuid.FromString("77fdc5e0-2260-49da-8aae-c36ba255d05b"))) + t.Run("table=legacy namespaces", func(t *testing.T) { // as they are legacy, we expect them to be actually dropped - - assert.ErrorIs(t, sqlcon.HandleError(c.RawQuery("SELECT * FROM keto_namespace").Exec()), sqlcon.ErrNoSuchTable) + assert.ErrorIs(t, sqlcon.HandleError(c.RawQuery( + "SELECT * FROM keto_namespace", + ).Exec()), sqlcon.ErrNoSuchTable) }) t.Run("table=relation tuples", func(t *testing.T) { - reg := driver.NewTestRegistry(t, db) - require.NoError(t, - reg.Config(ctx).Set(config.KeyNamespaces, []*namespace.Namespace{{ID: 1, Name: "foo"}})) - - p, err := sql.NewPersister(ctx, reg, uuid.Must(uuid.FromString("77fdc5e0-2260-49da-8aae-c36ba255d05b"))) require.NoError(t, err) - rts, next, err := p.GetRelationTuples(context.Background(), &relationtuple.RelationQuery{Namespace: "foo"}) + actualRts, next, err := p.GetRelationTuples(ctx, &relationtuple.RelationQuery{Namespace: "foo"}) + require.NoError(t, err) assert.Equal(t, "", next) + t.Log("actual rts:", actualRts) - for _, rt := range []*relationtuple.InternalRelationTuple{ + expectedRts := []*relationtuple.InternalRelationTuple{ { Namespace: "foo", Object: "object", @@ -93,9 +200,14 @@ func TestMigrations(t *testing.T) { Relation: "s_relation", }, }, - } { - assert.Contains(t, rts, rt) } + + // The relationship tuples in the db have a UUID mapping, so + // we need to convert our expectations to that. + assert.NoError(t, p.MapFieldsToUUID( + ctx, relationtuple.InternalRelationTuples(expectedRts))) + assert.ElementsMatch(t, expectedRts, actualRts) + logMigrationStatus(t, tm) }) }) @@ -108,3 +220,11 @@ func TestMigrations(t *testing.T) { }) } } + +func logMigrationStatus(t *testing.T, m *popx.MigrationBox) { + status, err := m.Status(context.Background()) + require.NoError(t, err) + s := strings.Builder{} + _ = status.Write(&s) + t.Log("Migration status:\n", s.String()) +} diff --git a/internal/persistence/sql/migrations/single_table.go b/internal/persistence/sql/migrations/single_table.go index 62349d893..6ef438f1c 100644 --- a/internal/persistence/sql/migrations/single_table.go +++ b/internal/persistence/sql/migrations/single_table.go @@ -35,7 +35,7 @@ type ( } toSingleTableMigrator struct { d dependencies - perPage int + PerPage int } relationTuple struct { @@ -126,11 +126,11 @@ func (relationTuple) TableName(ctx context.Context) string { func NewToSingleTableMigrator(d dependencies) *toSingleTableMigrator { return &toSingleTableMigrator{ d: d, - perPage: 100, + PerPage: 100, } } -func (m *toSingleTableMigrator) namespaceMigrationBox(ctx context.Context, n *namespace.Namespace) (*popx.MigrationBox, error) { +func (m *toSingleTableMigrator) NamespaceMigrationBox(ctx context.Context, n *namespace.Namespace) (*popx.MigrationBox, error) { c, err := m.d.PopConnectionWithOpts(ctx, func(d *pop.ConnectionDetails) { d.Options = map[string]string{ "migration_table_name": migrationTableFromNamespace(n), @@ -149,7 +149,7 @@ func (m *toSingleTableMigrator) namespaceMigrationBox(ctx context.Context, n *na ) } -func (m *toSingleTableMigrator) getOldRelationTuples(ctx context.Context, n *namespace.Namespace, page, perPage int) (relationTuples, bool, error) { +func (m *toSingleTableMigrator) GetOldRelationTuples(ctx context.Context, n *namespace.Namespace, page, perPage int) (relationTuples, bool, error) { q := m.d.Persister().Connection(ctx). WithContext(context.WithValue(ctx, namespaceCtxKey, n)). Order("object, relation, subject, commit_time"). @@ -165,7 +165,7 @@ func (m *toSingleTableMigrator) getOldRelationTuples(ctx context.Context, n *nam return res, q.Paginator.Page < q.Paginator.TotalPages, nil } -func (m *toSingleTableMigrator) insertOldRelationTuples(ctx context.Context, n *namespace.Namespace, rs ...*relationtuple.InternalRelationTuple) error { +func (m *toSingleTableMigrator) InsertOldRelationTuples(ctx context.Context, n *namespace.Namespace, rs ...*relationtuple.InternalRelationTuple) error { for _, r := range rs { if r.Subject == nil { return errors.New("subject is not allowed to be nil") @@ -196,7 +196,7 @@ func (m *toSingleTableMigrator) MigrateNamespace(ctx context.Context, n *namespa if err := p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { for page := 1; ; page++ { - rs, hasNext, err := m.getOldRelationTuples(ctx, n, page, m.perPage) + rs, hasNext, err := m.GetOldRelationTuples(ctx, n, page, m.PerPage) if err != nil { return err } @@ -252,8 +252,10 @@ func (m *toSingleTableMigrator) LegacyNamespaces(ctx context.Context) ([]*namesp query = c.RawQuery("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'keto_%_relation_tuples'") case "postgres": query = c.RawQuery("SELECT tablename FROM pg_catalog.pg_tables WHERE tablename LIKE 'keto_%_relation_tuples'") - case "cockroach", "mysql": + case "cockroach": query = c.RawQuery("SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'keto_%_relation_tuples'") + case "mysql": + query = c.RawQuery("SELECT table_name FROM information_schema.tables WHERE table_name LIKE 'keto_%_relation_tuples' AND table_schema = DATABASE()") default: panic("got unknown database dialect " + d) } @@ -262,6 +264,7 @@ func (m *toSingleTableMigrator) LegacyNamespaces(ctx context.Context) ([]*namesp if err := sqlcon.HandleError(query.All(&tableNames)); err != nil { return nil, err } + m.d.Logger().Debugf("Found tables %v", tableNames) nm, err := m.d.Config(ctx).NamespaceManager() if err != nil { @@ -284,7 +287,7 @@ func (m *toSingleTableMigrator) LegacyNamespaces(ctx context.Context) ([]*namesp } func (m *toSingleTableMigrator) MigrateDown(ctx context.Context, n *namespace.Namespace) error { - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) if err != nil { return err } diff --git a/internal/persistence/sql/migrations/single_table_test.go b/internal/persistence/sql/migrations/single_table_test.go index 47de0622b..44f370a37 100644 --- a/internal/persistence/sql/migrations/single_table_test.go +++ b/internal/persistence/sql/migrations/single_table_test.go @@ -1,4 +1,4 @@ -package migrations +package migrations_test import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/ory/keto/internal/driver" "github.com/ory/keto/internal/driver/config" "github.com/ory/keto/internal/namespace" + "github.com/ory/keto/internal/persistence/sql/migrations" "github.com/ory/keto/internal/relationtuple" "github.com/ory/keto/internal/x" "github.com/ory/keto/internal/x/dbx" @@ -24,7 +25,7 @@ func TestToSingleTableMigrator(t *testing.T) { r := driver.NewTestRegistry(t, dsn) ctx := context.Background() var nn []*namespace.Namespace - m := NewToSingleTableMigrator(r) + m := migrations.NewToSingleTableMigrator(r) setup := func(t *testing.T) *namespace.Namespace { n := &namespace.Namespace{ @@ -34,7 +35,7 @@ func TestToSingleTableMigrator(t *testing.T) { nn = append(nn, n) - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) @@ -70,10 +71,10 @@ func TestToSingleTableMigrator(t *testing.T) { Relation: "b", }, } - require.NoError(t, m.insertOldRelationTuples(ctx, n, sID, sSet)) + require.NoError(t, m.InsertOldRelationTuples(ctx, n, sID, sSet)) // get the tuple from the old table - oldRts, next, err := m.getOldRelationTuples(ctx, n, 0, 100) + oldRts, next, err := m.GetOldRelationTuples(ctx, n, 0, 100) require.NoError(t, err) assert.False(t, next) require.Len(t, oldRts, 2) @@ -100,9 +101,9 @@ func TestToSingleTableMigrator(t *testing.T) { n := setup(t) defer func(old int) { - m.perPage = old - }(m.perPage) - m.perPage = 1 + m.PerPage = old + }(m.PerPage) + m.PerPage = 1 rts := make([]*relationtuple.InternalRelationTuple, 10) for i := range rts { @@ -114,7 +115,7 @@ func TestToSingleTableMigrator(t *testing.T) { } } - require.NoError(t, m.insertOldRelationTuples(ctx, n, rts...)) + require.NoError(t, m.InsertOldRelationTuples(ctx, n, rts...)) require.NoError(t, m.MigrateNamespace(ctx, n)) migrated, nextToken, err := r.RelationTupleManager().GetRelationTuples(ctx, &relationtuple.RelationQuery{Namespace: n.Name}, x.WithSize(len(rts))) @@ -132,7 +133,7 @@ func TestToSingleTableMigrator(t *testing.T) { Relation: "r", Subject: &relationtuple.SubjectID{ID: "s"}, } - require.NoError(t, m.insertOldRelationTuples(ctx, n, &relationtuple.InternalRelationTuple{ + require.NoError(t, m.InsertOldRelationTuples(ctx, n, &relationtuple.InternalRelationTuple{ Namespace: n.Name, Object: "o0", Relation: "r", @@ -145,7 +146,7 @@ func TestToSingleTableMigrator(t *testing.T) { }, valid)) err := m.MigrateNamespace(ctx, n) require.Error(t, err) - invalid, ok := err.(ErrInvalidTuples) + invalid, ok := err.(migrations.ErrInvalidTuples) require.True(t, ok) assert.Len(t, invalid, 2) @@ -168,7 +169,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { t.Run("case=simple detection", func(t *testing.T) { ctx := context.Background() reg := driver.NewTestRegistry(t, dsn) - m := NewToSingleTableMigrator(reg) + m := migrations.NewToSingleTableMigrator(reg) nspaces := []*namespace.Namespace{{ ID: 3, @@ -182,7 +183,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { assert.Len(t, legacyNamespaces, 0) // migrate legacy table up - mb, err := m.namespaceMigrationBox(ctx, nspaces[0]) + mb, err := m.NamespaceMigrationBox(ctx, nspaces[0]) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) @@ -203,7 +204,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { t.Run("case=multiple namespaces", func(t *testing.T) { ctx := context.Background() reg := driver.NewTestRegistry(t, dsn) - m := NewToSingleTableMigrator(reg) + m := migrations.NewToSingleTableMigrator(reg) nspaces := []*namespace.Namespace{{ ID: 0, @@ -219,7 +220,7 @@ func TestToSingleTableMigrator_HasLegacyTable(t *testing.T) { for _, n := range nspaces { // migrate legacy table up - mb, err := m.namespaceMigrationBox(ctx, n) + mb, err := m.NamespaceMigrationBox(ctx, n) require.NoError(t, err) require.NoError(t, mb.Up(ctx)) } diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.cockroach.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.cockroach.down.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.cockroach.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.cockroach.up.sql deleted file mode 100644 index 9aad41258..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.cockroach.up.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE keto_uuid_mappings -( - id UUID NOT NULL, - string_representation TEXT NOT NULL CHECK (string_representation <> ''), - - PRIMARY KEY (id) -); \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.mysql.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.mysql.down.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.postgres.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.postgres.down.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.sqlite3.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.sqlite3.down.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.sqlite3.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.sqlite3.up.sql deleted file mode 100644 index 9aad41258..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.sqlite3.up.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE keto_uuid_mappings -( - id UUID NOT NULL, - string_representation TEXT NOT NULL CHECK (string_representation <> ''), - - PRIMARY KEY (id) -); \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.cockroach.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.cockroach.down.sql deleted file mode 100644 index 1a0e9eb51..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.cockroach.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE keto_uuid_mappings; \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.cockroach.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.cockroach.up.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.mysql.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.mysql.down.sql deleted file mode 100644 index 1a0e9eb51..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.mysql.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE keto_uuid_mappings; \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.mysql.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.mysql.up.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.postgres.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.postgres.down.sql deleted file mode 100644 index 1a0e9eb51..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.postgres.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE keto_uuid_mappings; \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.postgres.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.postgres.up.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.sqlite3.down.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.sqlite3.down.sql deleted file mode 100644 index 1a0e9eb51..000000000 --- a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.sqlite3.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE keto_uuid_mappings; \ No newline at end of file diff --git a/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.sqlite3.up.sql b/internal/persistence/sql/migrations/sql/20220110200400000001_create-uuid-mapping-table.sqlite3.up.sql deleted file mode 100644 index e69de29bb..000000000 diff --git a/internal/persistence/sql/migrations/sql/20220217152313000000_nid_fk.down.sql b/internal/persistence/sql/migrations/sql/20220217152313000000_nid_fk.down.sql index afbcf5401..1deb47d87 100644 --- a/internal/persistence/sql/migrations/sql/20220217152313000000_nid_fk.down.sql +++ b/internal/persistence/sql/migrations/sql/20220217152313000000_nid_fk.down.sql @@ -1 +1 @@ -ALTER TABLE keto_relation_tuples DROP CONSTRAINT keto_relation_tuples_nid_fk; +ALTER TABLE keto_relation_tuples DROP CONSTRAINT keto_relation_tuples_nid_fk; \ No newline at end of file diff --git a/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.down.sql b/internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.down.sql similarity index 100% rename from internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.down.sql rename to internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.down.sql diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.mysql.up.sql b/internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.mysql.up.sql similarity index 100% rename from internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.mysql.up.sql rename to internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.mysql.up.sql diff --git a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.postgres.up.sql b/internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.up.sql similarity index 73% rename from internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.postgres.up.sql rename to internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.up.sql index 9aad41258..1bed91840 100644 --- a/internal/persistence/sql/migrations/sql/20220110200400000000_create-uuid-mapping-table.postgres.up.sql +++ b/internal/persistence/sql/migrations/sql/20220513200400000000_create-uuid-mapping-table.up.sql @@ -1,6 +1,6 @@ CREATE TABLE keto_uuid_mappings ( - id UUID NOT NULL, + id UUID NOT NULL UNIQUE, string_representation TEXT NOT NULL CHECK (string_representation <> ''), PRIMARY KEY (id) diff --git a/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.mysql.up.sql b/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.mysql.up.sql deleted file mode 100644 index 5db883292..000000000 --- a/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.mysql.up.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE keto_uuid_mappings -( - id VARCHAR(64) NOT NULL, - string_representation TEXT NOT NULL CHECK (string_representation <> ''), - - PRIMARY KEY (id) -); \ No newline at end of file diff --git a/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.up.sql b/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.up.sql deleted file mode 100644 index 9aad41258..000000000 --- a/internal/persistence/sql/migrations/templates/20220110200400_create-uuid-mapping-table.up.sql +++ /dev/null @@ -1,7 +0,0 @@ -CREATE TABLE keto_uuid_mappings -( - id UUID NOT NULL, - string_representation TEXT NOT NULL CHECK (string_representation <> ''), - - PRIMARY KEY (id) -); \ No newline at end of file diff --git a/internal/persistence/sql/migrations/uuid_mapping.go b/internal/persistence/sql/migrations/uuid_mapping.go deleted file mode 100644 index 960256ab4..000000000 --- a/internal/persistence/sql/migrations/uuid_mapping.go +++ /dev/null @@ -1,114 +0,0 @@ -package migrations - -import ( - "context" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/ory/x/sqlcon" - - "github.com/ory/keto/internal/persistence/sql" -) - -type ( - toUUIDMappingMigrator struct { - d dependencies - perPage int - requestedPages int // track the number of pages we requested for testing. - } -) - -// NewToUUIDMappingMigrator creates a new UUID mapping migrator. -func NewToUUIDMappingMigrator(d dependencies) *toUUIDMappingMigrator { - return &toUUIDMappingMigrator{d: d, perPage: 100} -} - -// MigrateUUIDMappings migrates to UUID-mapped subject IDs for all relation -// tuples in the database. -func (m *toUUIDMappingMigrator) MigrateUUIDMappings(ctx context.Context) error { - p, ok := m.d.Persister().(*sql.Persister) - if !ok { - panic("got unexpected persister") - } - - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - m.requestedPages = 0 - for page := 1; ; page++ { - m.requestedPages++ - relationTuples, hasNext, err := m.getRelationTuples(ctx, page) - if err != nil { - return err - } - - if err := p.Connection(ctx).All(&relationTuples); err != nil { - return err - } - - for _, rt := range relationTuples { - if err := m.migrateSubjectID(ctx, rt); err != nil { - return err - } - if err := m.migrateSubjectSetObject(ctx, rt); err != nil { - return err - } - if err := m.migrateObject(ctx, rt); err != nil { - return err - } - if err := c.Update(rt); err != nil { - return err - } - } - - if !hasNext { - break - } - } - return nil - }) -} - -func (m *toUUIDMappingMigrator) migrateSubjectID(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectID.String) - if err == nil || !rt.SubjectID.Valid || rt.SubjectID.String == "" { - return nil - } - - rt.SubjectID.String, err = m.addUUIDMapping(ctx, rt.SubjectID.String) - return err -} - -func (m *toUUIDMappingMigrator) migrateSubjectSetObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.SubjectSetObject.String) - if err == nil || !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { - return nil - } - - rt.SubjectSetObject.String, err = m.addUUIDMapping(ctx, rt.SubjectSetObject.String) - return err -} - -func (m *toUUIDMappingMigrator) migrateObject(ctx context.Context, rt *sql.RelationTuple) error { - _, err := uuid.FromString(rt.Object) - if err == nil || rt.Object == "" { - return nil - } - - rt.Object, err = m.addUUIDMapping(ctx, rt.Object) - return err -} - -func (m *toUUIDMappingMigrator) addUUIDMapping(ctx context.Context, value string) (id string, err error) { - uid, err := m.d.Persister().ToUUID(ctx, value) - return uid.String(), err -} - -func (m *toUUIDMappingMigrator) getRelationTuples(ctx context.Context, page int) (res []*sql.RelationTuple, hasNext bool, err error) { - q := m.d.Persister().Connection(ctx). - Order("nid, shard_id"). - Paginate(page, m.perPage) - - if err := q.All(&res); err != nil { - return nil, false, sqlcon.HandleError(err) - } - return res, q.Paginator.Page < q.Paginator.TotalPages, nil -} diff --git a/internal/persistence/sql/migrations/uuid_mapping_test.go b/internal/persistence/sql/migrations/uuid_mapping_test.go deleted file mode 100644 index 9756c5563..000000000 --- a/internal/persistence/sql/migrations/uuid_mapping_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package migrations - -import ( - "context" - dbsql "database/sql" - "testing" - "time" - - "github.com/gobuffalo/pop/v6" - "github.com/gofrs/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/ory/keto/internal/driver" - "github.com/ory/keto/internal/persistence/sql" - "github.com/ory/keto/internal/x/dbx" -) - -func TestToUUIDMappingMigrator(t *testing.T) { - const debugOnDisk = false - - for _, dsn := range dbx.GetDSNs(t, debugOnDisk) { - t.Run("db="+dsn.Name, func(t *testing.T) { - ctx := context.Background() - r := driver.NewTestRegistry(t, dsn) - m := NewToUUIDMappingMigrator(r) - p := m.d.Persister().(*sql.Persister) - conn := p.Connection(ctx) - nw, err := r.DetermineNetwork(ctx) - require.NoError(t, err) - - testCases := []struct { - name string - rt *sql.RelationTuple - expectMapping bool - }{{ - name: "with string subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - Object: "object", - SubjectID: dbsql.NullString{String: "subject", Valid: true}, - CommitTime: time.Now(), - }, - expectMapping: true, - }, { - name: "with null subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - SubjectID: dbsql.NullString{String: "", Valid: false}, - SubjectSetNamespaceID: dbsql.NullInt32{Int32: 0, Valid: true}, - SubjectSetObject: dbsql.NullString{String: "object", Valid: true}, - SubjectSetRelation: dbsql.NullString{String: "subject_set_relation", Valid: true}, - Object: "object", - CommitTime: time.Now(), - }, - expectMapping: true, - }, { - name: "with UUID subject", - rt: &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - SubjectID: dbsql.NullString{String: uuid.Must(uuid.NewV4()).String(), Valid: true}, - CommitTime: time.Now(), - }, - expectMapping: false, - }} - - for _, tc := range testCases { - t.Run("case="+tc.name, func(t *testing.T) { - require.NoError(t, conn.Create(tc.rt)) - require.NoError(t, m.MigrateUUIDMappings(ctx)) - - newRt := &sql.RelationTuple{} - require.NoError(t, conn.Find(newRt, tc.rt.ID)) - - if tc.expectMapping { - // Check subject mapping - if tc.rt.SubjectID.Valid { - assertHasMapping(t, conn, tc.rt.SubjectID.String, newRt.SubjectID.String) - } else { - assertHasMapping(t, conn, tc.rt.SubjectSetObject.String, newRt.SubjectSetObject.String) - // check that both "OBJECT" strings are mapped to same UUID. - assert.Equal(t, newRt.Object, newRt.SubjectSetObject.String) - } - assertHasMapping(t, conn, tc.rt.Object, newRt.Object) - } else { - // Nothing should have changed (ignoring commit time) - newRt.CommitTime = tc.rt.CommitTime - assert.Equal(t, tc.rt, newRt) - } - }) - } - }) - } -} - -func TestToUUIDMappingMigrator_paginates(t *testing.T) { - numTuples := 10 - perPage := 2 - - dsn := dbx.GetSqlite(t, dbx.SQLiteMemory) - ctx := context.Background() - r := driver.NewTestRegistry(t, dsn) - m := &toUUIDMappingMigrator{d: r, perPage: perPage} - p := m.d.Persister().(*sql.Persister) - conn := p.Connection(ctx) - nw, err := r.DetermineNetwork(ctx) - require.NoError(t, err) - - // Create a bunch of relation tuples - for i := 0; i < numTuples; i++ { - rt := &sql.RelationTuple{ - ID: uuid.Must(uuid.NewV4()), - NetworkID: nw.ID, - Object: "object", - SubjectID: dbsql.NullString{String: "subject", Valid: true}, - CommitTime: time.Now(), - } - require.NoError(t, conn.Create(rt)) - } - - require.NoError(t, m.MigrateUUIDMappings(ctx)) - assert.Equal(t, numTuples/perPage, m.requestedPages) -} - -// assertHasMapping checks that there is a mapping from the given string (value) -// to the given UUID (uid). -func assertHasMapping(t *testing.T, conn *pop.Connection, value, uid string) { - t.Helper() - mapping := &sql.UUIDMapping{} - require.NoError(t, conn.Find(mapping, uid), "Could not find mapping for %q", uid) - assert.NotEqual(t, value, uid, "value was not replaced by UUID") - assert.Equal(t, value, mapping.StringRepresentation) -} diff --git a/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go new file mode 100644 index 000000000..3c4f6c8f0 --- /dev/null +++ b/internal/persistence/sql/migrations/uuidmapping/uuid_mapping_migrator.go @@ -0,0 +1,190 @@ +package uuidmapping + +import ( + "database/sql" + "fmt" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/ory/x/popx" + "github.com/ory/x/sqlcon" +) + +type ( + // We copy the definitions of RelationTuple and UUIDMapping here so that the + // migration will always work on the same definitions. + RelationTuple struct { + // An ID field is required to make pop happy. The actual ID is a + // composite primary key. + ID uuid.UUID `db:"shard_id"` + NetworkID uuid.UUID `db:"nid"` + NamespaceID int32 `db:"namespace_id"` + Object string `db:"object"` + Relation string `db:"relation"` + SubjectID sql.NullString `db:"subject_id"` + SubjectSetNamespaceID sql.NullInt32 `db:"subject_set_namespace_id"` + SubjectSetObject sql.NullString `db:"subject_set_object"` + SubjectSetRelation sql.NullString `db:"subject_set_relation"` + CommitTime time.Time `db:"commit_time"` + } + UUIDMapping struct { + ID uuid.UUID `db:"id"` + StringRepresentation string `db:"string_representation"` + } + UUIDMappings []*UUIDMapping +) + +func (RelationTuple) TableName() string { return "keto_relation_tuples" } +func (UUIDMappings) TableName() string { return "keto_uuid_mappings" } +func (UUIDMapping) TableName() string { return "keto_uuid_mappings" } + +var ( + name = "migrate-strings-to-uuids" + version = "20220513210000000000" + Migrations = popx.Migrations{ + { + Version: version, + Name: name, + Path: name, + Direction: "up", + DBType: "all", + Type: "go", + Runner: func(_ popx.Migration, conn *pop.Connection, _ *pop.Tx) error { + for page := 1; ; page++ { + relationTuples, hasNext, err := getRelationTuples(conn, page) + if err != nil { + return fmt.Errorf("could not get relation tuples: %w", err) + } + + for _, rt := range relationTuples { + if err = migrateSubjectID(conn, &rt); err != nil { + return fmt.Errorf("could not migrate subject ID: %w", err) + } + if err = migrateSubjectSetObject(conn, &rt); err != nil { + return fmt.Errorf("could not migrate subject set object: %w", err) + } + if err = migrateObject(conn, &rt); err != nil { + return fmt.Errorf("could not migrate object: %w", err) + } + if err = conn.Update(&rt); err != nil { + return fmt.Errorf("failed to update relation tuple: %w", err) + } + } + + if !hasNext { + break + } + } + + return nil + }, + }, + // We have to specify the "down" migration even if it is a no-op, since + // the migrator will still manipulate the version table in the database. + { + Version: version, + Name: name, + Path: name, + Direction: "down", + DBType: "all", + Type: "go", + Runner: func(_ popx.Migration, _ *pop.Connection, _ *pop.Tx) error { + return nil + }, + }, + } +) + +func hasMapping(conn *pop.Connection, id string) (bool, error) { + found, err := conn.Where("id = ?", id).Exists(&UUIDMapping{}) + if err != nil { + return false, nil + } + return found, nil +} + +func migrateSubjectID(conn *pop.Connection, rt *RelationTuple) error { + if !rt.SubjectID.Valid || rt.SubjectID.String == "" { + return nil + } + skip, err := hasMapping(conn, rt.SubjectID.String) + if err != nil { + return err + } + if skip { + return nil + } + + rt.SubjectID.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectID.String) + return err +} + +func migrateSubjectSetObject(conn *pop.Connection, rt *RelationTuple) error { + if !rt.SubjectSetObject.Valid || rt.SubjectSetObject.String == "" { + return nil + } + skip, err := hasMapping(conn, rt.SubjectSetObject.String) + if err != nil { + return err + } + if skip { + return nil + } + + rt.SubjectSetObject.String, err = addUUIDMapping(conn, rt.NetworkID, rt.SubjectSetObject.String) + return err +} + +func migrateObject(conn *pop.Connection, rt *RelationTuple) error { + if rt.Object == "" { + return nil + } + skip, err := hasMapping(conn, rt.Object) + if err != nil { + return err + } + if skip { + return nil + } + + rt.Object, err = addUUIDMapping(conn, rt.NetworkID, rt.Object) + return err +} + +func addUUIDMapping(conn *pop.Connection, networkID uuid.UUID, value string) (uid string, err error) { + uid = uuid.NewV5(networkID, value).String() + + // We need to write manual SQL here because the INSERT should not fail if + // the UUID already exists, but we still want to return an error if anything + // else goes wrong. + var query string + switch d := conn.Dialect.Name(); d { + case "mysql": + query = ` + INSERT IGNORE INTO keto_uuid_mappings (id, string_representation) + VALUES (?, ?)` + default: + query = ` + INSERT INTO keto_uuid_mappings (id, string_representation) + VALUES (?, ?) + ON CONFLICT (id) DO NOTHING` + } + + err = sqlcon.HandleError(conn.RawQuery(query, uid, value).Exec()) + if err != nil { + return "", fmt.Errorf("failed to add UUID mapping: %w", err) + } + return +} + +func getRelationTuples(conn *pop.Connection, page int) ( + res []RelationTuple, hasNext bool, err error, +) { + q := conn.Order("nid, shard_id").Paginate(page, 100) + + if err := q.All(&res); err != nil { + return nil, false, sqlcon.HandleError(err) + } + return res, q.Paginator.Page < q.Paginator.TotalPages, nil +} diff --git a/internal/persistence/sql/persister.go b/internal/persistence/sql/persister.go index 86403a9b6..b47f97566 100644 --- a/internal/persistence/sql/persister.go +++ b/internal/persistence/sql/persister.go @@ -46,7 +46,7 @@ const ( var ( //go:embed migrations/sql/*.sql - migrations embed.FS + Migrations embed.FS _ persistence.Persister = &Persister{} ) @@ -67,7 +67,7 @@ func NewPersister(ctx context.Context, reg dependencies, nid uuid.UUID) (*Persis } func NewMigrationBox(c *pop.Connection, logger *logrusx.Logger, tracer *otelx.Tracer) (*popx.MigrationBox, error) { - return popx.NewMigrationBox(fsx.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, logger, tracer, 0)) + return popx.NewMigrationBox(fsx.Merge(Migrations, networkx.Migrations), popx.NewMigrator(c, logger, tracer, 0)) } func (p *Persister) Connection(ctx context.Context) *pop.Connection { diff --git a/internal/persistence/sql/relationtuples.go b/internal/persistence/sql/relationtuples.go index 98aee6a14..47b94713b 100644 --- a/internal/persistence/sql/relationtuples.go +++ b/internal/persistence/sql/relationtuples.go @@ -198,7 +198,7 @@ func (p *Persister) whereQuery(ctx context.Context, q *pop.Query, rq *relationtu } func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtuple.InternalRelationTuple) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { for _, r := range rs { n, err := p.GetNamespaceByName(ctx, r.Namespace) if err != nil { @@ -223,7 +223,7 @@ func (p *Persister) DeleteRelationTuples(ctx context.Context, rs ...*relationtup } func (p *Persister) DeleteAllRelationTuples(ctx context.Context, query *relationtuple.RelationQuery) error { - return p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error { sqlQuery := p.QueryWithNetwork(ctx) err := p.whereQuery(ctx, sqlQuery, query) if err != nil { diff --git a/internal/persistence/sql/uuid_mapping.go b/internal/persistence/sql/uuid_mapping.go index 95ed4b5e5..48a1ab383 100644 --- a/internal/persistence/sql/uuid_mapping.go +++ b/internal/persistence/sql/uuid_mapping.go @@ -2,9 +2,13 @@ package sql import ( "context" + "fmt" "github.com/gofrs/uuid" "github.com/ory/x/sqlcon" + + "github.com/ory/keto/internal/relationtuple" + "github.com/ory/keto/internal/x" ) type ( @@ -48,13 +52,102 @@ func (p *Persister) ToUUID(ctx context.Context, text string) (uuid.UUID, error) ) } -func (p *Persister) FromUUID(ctx context.Context, id uuid.UUID) (rep string, err error) { - p.d.Logger().Trace("looking up UUID") +func (p *Persister) FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (res []string, err error) { + p.d.Logger().Trace("looking up UUIDs") + + // We need to paginate on the ids, because we want to get the exact chunk of + // string representations for the given ids. + pagination, _ := internalPaginationFromOptions(opts...) + pageSize := pagination.PerPage + + // Build a map from UUID -> indices in the result. + idIdx := make(map[uuid.UUID][]int) + for i, id := range ids { + if ids, ok := idIdx[id]; ok { + idIdx[id] = append(ids, i) + } else { + idIdx[id] = []int{i} + } + } + + res = make([]string, len(ids)) - m := &UUIDMapping{} - if err := sqlcon.HandleError(p.Connection(ctx).Find(m, id)); err != nil { - return "", err + for i := 0; i < len(ids); i += pageSize { + end := i + pageSize + if end > len(ids) { + end = len(ids) + } + idsToLookup := ids[i:end] + mappings := &[]UUIDMapping{} + query := p.Connection(ctx).Where("id in (?)", idsToLookup) + if err := sqlcon.HandleError(query.All(mappings)); err != nil { + return []string{}, err + } + + // Write the representation to the correct index. + for _, m := range *mappings { + for _, idx := range idIdx[m.ID] { + res[idx] = m.StringRepresentation + } + } } - return m.StringRepresentation, nil + return +} + +func (p *Persister) replaceWithUUID(ctx context.Context, s *string) error { + if s == nil { + return nil + } + uuid, err := p.ToUUID(ctx, *s) + if err != nil { + return err + } + *s = uuid.String() + + return nil +} + +func (p *Persister) MapFieldsToUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + for _, s := range m.UUIDMappableFields() { + if s == nil || *s == "" { + continue + } + if err := p.replaceWithUUID(ctx, s); err != nil { + p.d.Logger().WithError(err).WithField("string", s).Error("got an error while mapping string to UUID") + return err + } + } + return nil +} + +func (p *Persister) MapFieldsFromUUID(ctx context.Context, m relationtuple.UUIDMappable) error { + ids := make([]uuid.UUID, len(m.UUIDMappableFields())) + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + id, err := uuid.FromString(*field) + if err != nil { + p.d.Logger().WithError(err).WithField("UUID", *field).Error("could not parse as UUID") + return err + } + ids[i] = id + } + reps, err := p.FromUUID(ctx, ids) + if err != nil { + p.d.Logger().WithError(err).WithField("UUIDs", ids).Error("could fetch string mappings from DB") + return err + } + for i, field := range m.UUIDMappableFields() { + if field == nil { + continue + } + if reps[i] == "" { + p.d.Logger().WithError(err).WithField("string", reps[i]).Error("could not find the corresponding UUID") + return fmt.Errorf("failed to map %s", ids[i]) + } + *field = reps[i] + } + return nil } diff --git a/internal/persistence/sql/uuid_mapping_test.go b/internal/persistence/sql/uuid_mapping_test.go index eddebd284..b45ab5a59 100644 --- a/internal/persistence/sql/uuid_mapping_test.go +++ b/internal/persistence/sql/uuid_mapping_test.go @@ -2,8 +2,11 @@ package sql_test import ( "context" + "fmt" + "strings" "testing" + "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -12,6 +15,19 @@ import ( "github.com/ory/keto/internal/x/dbx" ) +func assertCheckErr(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { + t.(*testing.T).Helper() + if err == nil { + return assert.Fail(t, "Did not receive an error", msgAndArgs...) + } + + if strings.Contains(err.Error(), "keto_uuid_mappings") || // <- normal databases + strings.Contains(err.Error(), "SQLSTATE 23514") { // <- mysql + return true + } + return assert.Fail(t, fmt.Sprintf("Did not receive check error, got:\n%+v", err), msgAndArgs...) +} + func TestUUIDMapping(t *testing.T) { for _, dsn := range dbx.GetDSNs(t, false) { t.Run("dsn="+dsn.Name, func(t *testing.T) { @@ -19,33 +35,42 @@ func TestUUIDMapping(t *testing.T) { c, err := reg.PopConnection(context.Background()) require.NoError(t, err) + testUUID := uuid.Must(uuid.NewV4()) + for _, tc := range []struct { desc string mappings interface{} - shouldErr bool + assertErr assert.ErrorAssertionFunc }{{ desc: "empty should fail on constraint", mappings: &sql.UUIDMapping{}, - shouldErr: true, + assertErr: assertCheckErr, + }, { + desc: "empty strings should fail on constraint", + mappings: &sql.UUIDMapping{uuid.Nil, ""}, + assertErr: assertCheckErr, }, { desc: "single with string rep should succeed", mappings: &sql.UUIDMapping{StringRepresentation: "foo"}, - shouldErr: false, + assertErr: assert.NoError, + }, { + desc: "two with same uuid should fail on constraint", + mappings: &[]sql.UUIDMapping{ + {ID: testUUID, StringRepresentation: "foo"}, + {ID: testUUID, StringRepresentation: "bar"}, + }, + assertErr: assertCheckErr, }, { - desc: "two with same rep should fail on constraint", - mappings: sql.UUIDMappings{ - &sql.UUIDMapping{StringRepresentation: "bar"}, - &sql.UUIDMapping{StringRepresentation: "bar"}, + desc: "two with same rep should succeed", + mappings: &[]sql.UUIDMapping{ + {ID: uuid.Must(uuid.NewV4()), StringRepresentation: "bar"}, + {ID: uuid.Must(uuid.NewV4()), StringRepresentation: "bar"}, }, - shouldErr: true, + assertErr: assert.NoError, }} { t.Run("case="+tc.desc, func(t *testing.T) { - err = c.Create(tc.mappings) - if tc.shouldErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } + err := c.Create(tc.mappings) + tc.assertErr(t, err) }) } }) diff --git a/internal/relationtuple/definitions.go b/internal/relationtuple/definitions.go index 9fc1c63c9..a3db8f77a 100644 --- a/internal/relationtuple/definitions.go +++ b/internal/relationtuple/definitions.go @@ -24,6 +24,7 @@ import ( type ( ManagerProvider interface { RelationTupleManager() Manager + UUIDMappingManager() UUIDMappingManager } Manager interface { GetRelationTuples(ctx context.Context, query *RelationQuery, options ...x.PaginationOptionSetter) ([]*InternalRelationTuple, string, error) @@ -88,6 +89,9 @@ type Subject interface { // swagger:ignore ToProto() *rts.Subject + + // swagger:ignore + UUIDMappable } // swagger:ignore @@ -97,6 +101,18 @@ type InternalRelationTuple struct { Relation string `json:"relation"` Subject Subject `json:"subject"` } +type InternalRelationTuples []*InternalRelationTuple + +func (rt *InternalRelationTuple) UUIDMappableFields() []*string { + return append([]*string{&rt.Object}, rt.Subject.UUIDMappableFields()...) +} + +func (rtt InternalRelationTuples) UUIDMappableFields() (fields []*string) { + for _, rt := range rtt { + fields = append(fields, rt.UUIDMappableFields()...) + } + return fields +} // swagger:parameters getExpand type SubjectSet struct { @@ -141,6 +157,16 @@ func SubjectFromString(s string) (Subject, error) { return (&SubjectID{}).FromString(s) } +// swagger:ignore +func (s *SubjectID) UUIDMappableFields() []*string { + return []*string{&s.ID} +} + +// swagger:ignore +func (s *SubjectSet) UUIDMappableFields() []*string { + return []*string{&s.Object} +} + // swagger:ignore func SubjectFromProto(gs *rts.Subject) (Subject, error) { switch s := gs.GetRef().(type) { @@ -447,6 +473,17 @@ func (q *RelationQuery) FromProto(query TupleData) (*RelationQuery, error) { return q, nil } +func (q *RelationQuery) UUIDMappableFields() []*string { + res := []*string{&q.Object} + if q.SubjectID != nil { + res = append(res, q.SubjectID) + } + if q.SubjectSet != nil { + res = append(res, q.SubjectSet.UUIDMappableFields()...) + } + return res +} + const ( subjectIDKey = "subject_id" subjectSetNamespaceKey = "subject_set.namespace" @@ -684,3 +721,7 @@ func (t *ManagerWrapper) TransactRelationTuples(ctx context.Context, insert []*I func (t *ManagerWrapper) RelationTupleManager() Manager { return t } + +func (t *ManagerWrapper) UUIDMappingManager() UUIDMappingManager { + return t.Reg.UUIDMappingManager() +} diff --git a/internal/relationtuple/read_server.go b/internal/relationtuple/read_server.go index cc36f682e..9004756de 100644 --- a/internal/relationtuple/read_server.go +++ b/internal/relationtuple/read_server.go @@ -16,7 +16,10 @@ import ( "github.com/ory/keto/internal/x" ) -var _ rts.ReadServiceServer = (*handler)(nil) +var ( + _ rts.ReadServiceServer = (*handler)(nil) + _ = (*getRelationsParams)(nil) +) func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationTuplesRequest) (*rts.ListRelationTuplesResponse, error) { if req.Query == nil { @@ -28,6 +31,9 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT return nil, err } + if err := h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q); err != nil { + return nil, err + } rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(ctx, q, x.WithSize(int(req.PageSize)), x.WithToken(req.PageToken), @@ -35,6 +41,9 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT if err != nil { return nil, err } + if err := h.d.UUIDMappingManager().MapFieldsFromUUID(ctx, InternalRelationTuples(rels)); err != nil { + return nil, err + } resp := &rts.ListRelationTuplesResponse{ RelationTuples: make([]*rts.RelationTuple, len(rels)), @@ -48,7 +57,6 @@ func (h *handler) ListRelationTuples(ctx context.Context, req *rts.ListRelationT } // swagger:parameters getRelationTuples -// nolint:deadcode,unused type getRelationsParams struct { // Namespace of the Relation Tuple // @@ -139,11 +147,19 @@ func (h *handler) getRelations(w http.ResponseWriter, r *http.Request, _ httprou paginationOpts = append(paginationOpts, x.WithSize(int(s))) } + if err := h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } rels, nextPage, err := h.d.RelationTupleManager().GetRelationTuples(r.Context(), query, paginationOpts...) if err != nil { h.d.Writer().WriteError(w, r, err) return } + if err := h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), InternalRelationTuples(rels)); err != nil { + h.d.Writer().WriteError(w, r, err) + return + } resp := &GetResponse{ RelationTuples: rels, diff --git a/internal/relationtuple/read_server_test.go b/internal/relationtuple/read_server_test.go index e772ea019..a432e08d0 100644 --- a/internal/relationtuple/read_server_test.go +++ b/internal/relationtuple/read_server_test.go @@ -75,7 +75,7 @@ func TestReadHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{ "namespace": {nspace.Name}, @@ -103,7 +103,7 @@ func TestReadHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) resp, err := ts.Client().Get(ts.URL + relationtuple.ReadRouteBase + "?" + url.Values{ "object": {obj}, @@ -114,7 +114,7 @@ func TestReadHandlers(t *testing.T) { var respMsg relationtuple.GetResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&respMsg)) assert.Equal(t, 1, len(respMsg.RelationTuples)) - assert.Contains(t, rts, respMsg.RelationTuples[0]) + assert.Containsf(t, rts, respMsg.RelationTuples[0], "expected to find %q in %q", respMsg.RelationTuples[0].String(), rts) assert.Equal(t, "", respMsg.NextPageToken) }) @@ -145,7 +145,7 @@ func TestReadHandlers(t *testing.T) { Subject: &relationtuple.SubjectID{ID: "s2"}, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) var firstResp relationtuple.GetResponse t.Run("case=first page", func(t *testing.T) { diff --git a/internal/relationtuple/swagger_definitions.go b/internal/relationtuple/swagger_definitions.go index 7c32e2a38..5e215a002 100644 --- a/internal/relationtuple/swagger_definitions.go +++ b/internal/relationtuple/swagger_definitions.go @@ -1,7 +1,11 @@ package relationtuple +var ( + _ = (*relationTupleWithRequired)(nil) + _ = (*patchPayload)(nil) +) + // swagger:model InternalRelationTuple -// nolint:deadcode,unused type relationTupleWithRequired struct { // Namespace of the Relation Tuple // @@ -31,7 +35,6 @@ type relationTupleWithRequired struct { // The patch request payload // // swagger:parameters patchRelationTuples -// nolint:deadcode,unused type patchPayload struct { // in:body Payload []*PatchDelta @@ -41,3 +44,15 @@ type PatchDelta struct { Action patchAction `json:"action"` RelationTuple *InternalRelationTuple `json:"relation_tuple"` } + +// swagger:ignore +type PatchDeltas []*PatchDelta + +func (p PatchDeltas) UUIDMappableFields() (res []*string) { + for _, pd := range p { + if pd.RelationTuple != nil { + res = append(res, pd.RelationTuple.UUIDMappableFields()...) + } + } + return +} diff --git a/internal/relationtuple/test_helper.go b/internal/relationtuple/test_helper.go new file mode 100644 index 000000000..e7332de45 --- /dev/null +++ b/internal/relationtuple/test_helper.go @@ -0,0 +1,18 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// MapAndWriteTuples is a test helper to write relation tuples to the database +// while mapping all strings to UUIDs. +func MapAndWriteTuples(t *testing.T, m ManagerProvider, tuples ...*InternalRelationTuple) { + t.Helper() + + require.NoError(t, m.UUIDMappingManager().MapFieldsToUUID(context.Background(), InternalRelationTuples(tuples))) + require.NoError(t, m.RelationTupleManager().WriteRelationTuples(context.Background(), tuples...)) + require.NoError(t, m.UUIDMappingManager().MapFieldsFromUUID(context.Background(), InternalRelationTuples(tuples))) +} diff --git a/internal/relationtuple/transact_server.go b/internal/relationtuple/transact_server.go index 5bdd02e72..4afd3b1e4 100644 --- a/internal/relationtuple/transact_server.go +++ b/internal/relationtuple/transact_server.go @@ -12,7 +12,11 @@ import ( "github.com/pkg/errors" ) -var _ rts.WriteServiceServer = (*handler)(nil) +var ( + _ rts.WriteServiceServer = (*handler)(nil) + _ = (*bodyRelationTuple)(nil) + _ = (*queryRelationTuple)(nil) +) func protoTuplesWithAction(deltas []*rts.RelationTupleDelta, action rts.RelationTupleDelta_Action) (filtered []*InternalRelationTuple, err error) { for _, d := range deltas { @@ -38,6 +42,9 @@ func (h *handler) TransactRelationTuples(ctx context.Context, req *rts.TransactR return nil, err } + if err = h.d.UUIDMappingManager().MapFieldsToUUID(ctx, InternalRelationTuples(append(insertTuples, deleteTuples...))); err != nil { + return nil, err + } err = h.d.RelationTupleManager().TransactRelationTuples(ctx, insertTuples, deleteTuples) if err != nil { return nil, err @@ -62,6 +69,9 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat return nil, errors.WithStack(herodot.ErrBadRequest.WithError(err.Error())) } + if err := h.d.UUIDMappingManager().MapFieldsToUUID(ctx, q); err != nil { + return nil, err + } if err := h.d.RelationTupleManager().DeleteAllRelationTuples(ctx, q); err != nil { return nil, errors.WithStack(herodot.ErrInternalServerError.WithError(err.Error())) } @@ -72,7 +82,6 @@ func (h *handler) DeleteRelationTuples(ctx context.Context, req *rts.DeleteRelat // The basic ACL relation tuple // // swagger:parameters postCheck createRelationTuple -// nolint:deadcode,unused type bodyRelationTuple struct { // in: body Payload RelationQuery @@ -81,7 +90,6 @@ type bodyRelationTuple struct { // The basic ACL relation tuple // // swagger:parameters getCheck deleteRelationTuples -// nolint:deadcode,unused type queryRelationTuple struct { // Namespace of the Relation Tuple // @@ -151,11 +159,21 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr h.d.Logger().WithFields(rel.ToLoggerFields()).Debug("creating relation tuple") + if err := h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), &rel); err != nil { + h.d.Logger().WithError(err).WithFields(rel.ToLoggerFields()).Errorf("got an error while mapping fields to UUID") + h.d.Writer().WriteError(w, r, err) + return + } if err := h.d.RelationTupleManager().WriteRelationTuples(r.Context(), &rel); err != nil { h.d.Logger().WithError(err).WithFields(rel.ToLoggerFields()).Errorf("got an error while creating the relation tuple") h.d.Writer().WriteError(w, r, err) return } + if err := h.d.UUIDMappingManager().MapFieldsFromUUID(r.Context(), &rel); err != nil { + h.d.Logger().WithError(err).WithFields(rel.ToLoggerFields()).Errorf("got an error while mapping fields from UUID") + h.d.Writer().WriteError(w, r, err) + return + } q, err := rel.ToURLQuery() if err != nil { @@ -198,6 +216,11 @@ func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ http } l.Debug("deleting relation tuples") + if err := h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), query); err != nil { + h.d.Logger().WithError(err).Errorf("got an error while mapping fields to UUID") + h.d.Writer().WriteError(w, r, err) + return + } if err := h.d.RelationTupleManager().DeleteAllRelationTuples(r.Context(), query); err != nil { l.WithError(err).Errorf("got an error while deleting relation tuples") h.d.Writer().WriteError(w, r, herodot.ErrInternalServerError.WithError(err.Error())) @@ -254,7 +277,17 @@ func (h *handler) patchRelations(w http.ResponseWriter, r *http.Request, _ httpr } } - if err := h.d.RelationTupleManager().TransactRelationTuples(r.Context(), internalTuplesWithAction(deltas, ActionInsert), internalTuplesWithAction(deltas, ActionDelete)); err != nil { + if err := h.d.UUIDMappingManager().MapFieldsToUUID(r.Context(), PatchDeltas(deltas)); err != nil { + h.d.Logger().WithError(err).Errorf("got an error while mapping fields to UUID") + h.d.Writer().WriteError(w, r, err) + return + } + if err := h.d.RelationTupleManager(). + TransactRelationTuples( + r.Context(), + internalTuplesWithAction(deltas, ActionInsert), + internalTuplesWithAction(deltas, ActionDelete)); err != nil { + h.d.Writer().WriteError(w, r, err) return } diff --git a/internal/relationtuple/transact_server_test.go b/internal/relationtuple/transact_server_test.go index 921e509c6..ae386065e 100644 --- a/internal/relationtuple/transact_server_test.go +++ b/internal/relationtuple/transact_server_test.go @@ -80,10 +80,14 @@ func TestWriteHandlers(t *testing.T) { assert.JSONEq(t, string(payload), string(body)) t.Run("check=is contained in the manager", func(t *testing.T) { + query := rt.ToQuery() + require.NoError(t, reg.UUIDMappingManager().MapFieldsToUUID(context.Background(), query)) // set a size > 1 just to make sure it gets all - actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) + actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) - assert.Equal(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs))) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), query)) + assert.Equalf(t, []*relationtuple.InternalRelationTuple{rt}, actualRTs, "want: %s\ngot: %s", rt.String(), actualRTs[0].String()) }) t.Run("check=is gettable with the returned URL", func(t *testing.T) { @@ -136,6 +140,7 @@ func TestWriteHandlers(t *testing.T) { Namespace: nspace.Name, }) require.NoError(t, err) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actual))) assert.Equal(t, "", next) assert.Len(t, actual, 2) for _, rt := range rts { @@ -154,7 +159,7 @@ func TestWriteHandlers(t *testing.T) { Relation: "deleted rel", Subject: &relationtuple.SubjectID{ID: "deleted subj"}, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rt)) + relationtuple.MapAndWriteTuples(t, reg, rt) q, err := rt.ToURLQuery() require.NoError(t, err) @@ -167,6 +172,7 @@ func TestWriteHandlers(t *testing.T) { // set a size > 1 just to make sure it gets all actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), rt.ToQuery(), x.WithSize(10)) require.NoError(t, err) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs))) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) @@ -188,7 +194,7 @@ func TestWriteHandlers(t *testing.T) { }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), rts...)) + relationtuple.MapAndWriteTuples(t, reg, rts...) q := url.Values{ "namespace": {nspace.Name}, @@ -206,6 +212,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), query, x.WithSize(10)) require.NoError(t, err) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs))) assert.Equal(t, []*relationtuple.InternalRelationTuple{}, actualRTs) }) }) @@ -234,7 +241,7 @@ func TestWriteHandlers(t *testing.T) { }, }, } - require.NoError(t, reg.RelationTupleManager().WriteRelationTuples(context.Background(), deltas[1].RelationTuple)) + relationtuple.MapAndWriteTuples(t, reg, deltas[1].RelationTuple) body, err := json.Marshal(deltas) require.NoError(t, err) @@ -249,6 +256,8 @@ func TestWriteHandlers(t *testing.T) { Relation: t.Name(), }) require.NoError(t, err) + err = reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs)) + require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) @@ -316,6 +325,7 @@ func TestWriteHandlers(t *testing.T) { actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(context.Background(), &relationtuple.RelationQuery{ Namespace: nspace.Name, }) + require.NoError(t, reg.UUIDMappingManager().MapFieldsFromUUID(context.Background(), relationtuple.InternalRelationTuples(actualRTs))) require.NoError(t, err) assert.Equal(t, []*relationtuple.InternalRelationTuple{deltas[0].RelationTuple}, actualRTs) }) diff --git a/internal/relationtuple/uuid_mapping.go b/internal/relationtuple/uuid_mapping.go new file mode 100644 index 000000000..54aea61c2 --- /dev/null +++ b/internal/relationtuple/uuid_mapping.go @@ -0,0 +1,102 @@ +package relationtuple + +import ( + "context" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/keto/internal/x" +) + +type ( + // UUIDMappable is an interface for objects that have fields that can be + // mapped to and from UUIDs. + UUIDMappable interface{ UUIDMappableFields() []*string } + + UUIDMappingManager interface { + // ToUUID returns the mapped UUID for the given string representation. + // If the string representation is not mapped, a new UUID will be + // created automatically. + ToUUID(ctx context.Context, representation string) (uuid.UUID, error) + + // MapFields maps all fields of the given object to UUIDs. + MapFieldsToUUID(ctx context.Context, m UUIDMappable) error + + // MapFieldsFromUUID maps all fields of the given object from UUIDs to + // their string value. + MapFieldsFromUUID(ctx context.Context, m UUIDMappable) error + + // FromUUID returns the text representations for the given UUIDs, such + // that ids[i] is mapped to reps[i]. + // + // Of the pagination options, only the page size is considered and used + // as a batch size. + FromUUID(ctx context.Context, ids []uuid.UUID, opts ...x.PaginationOptionSetter) (reps []string, err error) + } +) + +func UUIDMappingManagerTest(t *testing.T, m UUIDMappingManager) { + ctx := context.Background() + + t.Run("case=ToUUID_FromUUID", func(t *testing.T) { + rep1 := "foo" + id, err := m.ToUUID(ctx, rep1) + require.NoError(t, err) + + rep2, err := m.FromUUID(ctx, []uuid.UUID{id}) + assert.NoError(t, err) + assert.Equal(t, rep1, rep2[0]) + }) + + t.Run("case=Idempotent_ToUUID", func(t *testing.T) { + id1, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + id2, err := m.ToUUID(ctx, "string") + assert.NoError(t, err) + assert.Equal(t, id1, id2) + }) + + // Test that the batch mapping preserves ordering, i.e. id[i] is mapped to + // rep[i]. + t.Run("case=Batch_ToUUID_Paginates", func(t *testing.T) { + expected := []string{"foo", "foo", "bar", "baz"} + ids := make([]uuid.UUID, len(expected)) + for i, s := range expected { + var err error + ids[i], err = m.ToUUID(ctx, s) + assert.NoError(t, err) + } + + actual, err := m.FromUUID(ctx, ids, x.WithSize(1)) + assert.NoError(t, err) + assert.Equal(t, expected, actual) + }) + + t.Run("case=IdempotentMapFieldsToAndFromUUIDs", func(t *testing.T) { + tc := []struct { + name string + obj UUIDMappable + copy UUIDMappable + }{ + { + name: "RelationTuple", + obj: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + copy: &InternalRelationTuple{Namespace: "n", Relation: "r", Object: "Object", Subject: &SubjectID{ID: "Subject"}}, + }, { + name: "SubjectID", + obj: &SubjectID{ID: "sub"}, + copy: &SubjectID{ID: "sub"}, + }, + } + for _, tt := range tc { + t.Run("type="+tt.name, func(t *testing.T) { + assert.NoError(t, m.MapFieldsToUUID(ctx, tt.obj)) + assert.NoError(t, m.MapFieldsFromUUID(ctx, tt.obj)) + assert.Equal(t, tt.copy, tt.obj) + }) + } + }) +} diff --git a/internal/uuidmapping/definitions.go b/internal/uuidmapping/definitions.go deleted file mode 100644 index 9ddeac751..000000000 --- a/internal/uuidmapping/definitions.go +++ /dev/null @@ -1,48 +0,0 @@ -package uuidmapping - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/gofrs/uuid" -) - -type ( - ManagerProvider interface { - UUIDMappingManager() Manager - } - Manager interface { - // ToUUID returns the mapped UUID for the given string representation. - // If the string representation is not mapped, a new UUID will be - // created automatically. - ToUUID(ctx context.Context, representation string) (uuid.UUID, error) - - // FromUUID returns the text representation for the given UUID. - FromUUID(ctx context.Context, id uuid.UUID) (text string, err error) - } -) - -func ManagerTest(t *testing.T, m Manager) { - ctx := context.Background() - - t.Run("case=ToUUID_FromUUID", func(t *testing.T) { - rep1 := "foo" - id, err := m.ToUUID(ctx, rep1) - require.NoError(t, err) - - rep2, err := m.FromUUID(ctx, id) - assert.NoError(t, err) - assert.Equal(t, rep1, rep2) - }) - - t.Run("case=FromUUID", func(t *testing.T) { - id1, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - id2, err := m.ToUUID(ctx, "string") - assert.NoError(t, err) - assert.Equal(t, id1, id2) - }) -} diff --git a/internal/x/dbx/dsn_testutils.go b/internal/x/dbx/dsn_testutils.go index 6c146a5dc..207615423 100644 --- a/internal/x/dbx/dsn_testutils.go +++ b/internal/x/dbx/dsn_testutils.go @@ -1,16 +1,20 @@ package dbx import ( + "crypto/rand" "fmt" "io/ioutil" + "net/url" "os" "path/filepath" + "strings" "testing" + "github.com/go-sql-driver/mysql" + "github.com/gobuffalo/pop/v6" + "github.com/ory/x/sqlcon/dockertest" "github.com/stretchr/testify/require" "github.com/tidwall/sjson" - - "github.com/ory/x/sqlcon/dockertest" ) type DsnT struct { @@ -19,6 +23,53 @@ type DsnT struct { MigrateUp, MigrateDown bool } +const mySQLSchema = "mysql://" + +func mySQLWithDbName(dsn string, db string) string { + cfg, err := mysql.ParseDSN(strings.TrimPrefix(dsn, mySQLSchema)) + if err != nil { + return "" + } + cfg.DBName = db + return mySQLSchema + cfg.FormatDSN() +} + +func withDbName(dsn string, db string) string { + // Special case for mysql, because their URLs are not parsable. + if strings.HasPrefix(dsn, mySQLSchema) { + return mySQLWithDbName(dsn, db) + } + + u, err := url.Parse(dsn) + if err != nil { + return "" + } + u.Path = db + + return u.String() +} + +// dbName returns a name for the database based on the test name. +func dbName(_ string) string { + var buf [20]byte + rand.Read(buf[:]) + return fmt.Sprintf("testdb_%x", buf) +} + +func createDB(t testing.TB, dsn string) (err error) { + var conn *pop.Connection + + if conn, err = pop.NewConnection(&pop.ConnectionDetails{URL: dsn}); err != nil { + return fmt.Errorf("failed to connect to %q: %w", dsn, err) + } + if err = pop.CreateDB(conn); err != nil { + return fmt.Errorf("failed to create db in %q: %w", dsn, err) + } + t.Cleanup(func() { pop.DropDB(conn) }) + + return +} + func GetDSNs(t testing.TB, debugSqliteOnDisk bool) []*DsnT { sqliteMode := SQLiteFile if debugSqliteOnDisk { @@ -33,16 +84,26 @@ func GetDSNs(t testing.TB, debugSqliteOnDisk bool) []*DsnT { if !testing.Short() { var mysql, postgres, cockroach string + db := dbName(t.Name()) dockertest.Parallel([]func(){ func() { - mysql = dockertest.RunTestMySQL(t) + mysql = withDbName(dockertest.RunTestMySQL(t), db) + if err := createDB(t, mysql); err != nil { + t.Fatal(err) + } }, func() { - postgres = dockertest.RunTestPostgreSQL(t) + postgres = withDbName(dockertest.RunTestPostgreSQL(t), db) + if err := createDB(t, postgres); err != nil { + t.Fatal(err) + } }, func() { - cockroach = dockertest.RunTestCockroachDB(t) + cockroach = withDbName(dockertest.RunTestCockroachDB(t), db) + if err := createDB(t, cockroach); err != nil { + t.Fatal(err) + } }, }) diff --git a/internal/x/dbx/dsn_testutils_test.go b/internal/x/dbx/dsn_testutils_test.go new file mode 100644 index 000000000..5591dfa5f --- /dev/null +++ b/internal/x/dbx/dsn_testutils_test.go @@ -0,0 +1,61 @@ +package dbx + +import ( + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_withDbName(t *testing.T) { + type args struct { + dsn string + db string + } + tests := []struct { + name string + args args + want string + }{{ + name: "postgres", + args: args{ + dsn: "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable", + db: "mydb", + }, + want: "postgres://postgres:postgres@localhost:5432/mydb?sslmode=disable", + }, { + name: "cockroach", + args: args{ + dsn: "cockroach://root@localhost:49364/defaultdb?sslmode=disable", + db: "foo", + }, + want: "cockroach://root@localhost:49364/foo?sslmode=disable", + }, { + name: "mysql", + args: args{ + dsn: "mysql://root:secret@(localhost:49394)/mysql?parseTime=true&multiStatements=true", + db: "testdb", + }, + want: "mysql://root:secret@tcp(localhost:49394)/testdb?multiStatements=true&parseTime=true", + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := withDbName(tt.args.dsn, tt.args.db); got != tt.want { + t.Errorf("\nwant %q\ngot %q", tt.want, got) + } + }) + } +} + +func Test_GetDSNs_can_connect_to_each_db(t *testing.T) { + for _, db := range GetDSNs(t, false) { + t.Run("dsn="+db.Name, func(t *testing.T) { + conn, err := pop.NewConnection(&pop.ConnectionDetails{URL: db.Conn}) + require.NoError(t, err) + assert.NoError(t, conn.Open()) + assert.NoError(t, Ping(conn)) + assert.NoError(t, conn.Close()) + }) + } +}