Skip to content

Commit

Permalink
feat: improve FK handling
Browse files Browse the repository at this point in the history
Additionally:
- allow custom FK names (limited applicability rn)
- implement GetReverse for AddFK and DropFK
- detect renamed foreign keys
- EqualSignature handles empty models (no columns)
  • Loading branch information
bevzzz committed Oct 27, 2024
1 parent 4c1dfdb commit a822fc5
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 121 deletions.
32 changes: 18 additions & 14 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pgdialect

import (
"context"
"fmt"

"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
Expand All @@ -20,15 +19,18 @@ type Migrator struct {

var _ sqlschema.Migrator = (*Migrator)(nil)

func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error {
query := fmt.Sprintf("ALTER TABLE %s RENAME TO %s", oldName, newName)
_, err := m.db.ExecContext(ctx, query)
if err != nil {
func (m *Migrator) exec(ctx context.Context, q *bun.RawQuery) error {
if _, err := q.Exec(ctx); err != nil {
return err
}
return nil
}

func (m *Migrator) RenameTable(ctx context.Context, oldName, newName string) error {
q := m.db.NewRaw("ALTER TABLE ? RENAME TO ?", bun.Ident(oldName), bun.Ident(newName))
return m.exec(ctx, q)
}

func (m *Migrator) AddContraint(ctx context.Context, fk sqlschema.FK, name string) error {
q := m.db.NewRaw(
"ALTER TABLE ?.? ADD CONSTRAINT ? FOREIGN KEY (?) REFERENCES ?.? (?)",
Expand All @@ -37,19 +39,21 @@ func (m *Migrator) AddContraint(ctx context.Context, fk sqlschema.FK, name strin
bun.Safe(fk.To.Schema), bun.Safe(fk.To.Table),
bun.Safe(fk.To.Column.String()),
)
if _, err := q.Exec(ctx); err != nil {
return err
}
return nil
return m.exec(ctx, q)
}

func (m *Migrator) DropContraint(ctx context.Context, schema, table, name string) error {
q := m.db.NewRaw(
"ALTER TABLE ?.? DROP CONSTRAINT ?",
bun.Safe(schema), bun.Safe(table), bun.Safe(name),
bun.Ident(schema), bun.Ident(table), bun.Ident(name),
)
if _, err := q.Exec(ctx); err != nil {
return err
}
return nil
return m.exec(ctx, q)
}

func (m *Migrator) RenameConstraint(ctx context.Context, schema, table, oldName, newName string) error {
q := m.db.NewRaw(
"ALTER TABLE ?.? RENAME CONSTRAINT ? TO ?",
bun.Ident(schema), bun.Ident(table), bun.Ident(oldName), bun.Ident(newName),
)
return m.exec(ctx, q)
}
197 changes: 168 additions & 29 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"sort"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -171,6 +172,8 @@ func TestAutoMigrator_Run(t *testing.T) {
{testRenameTable},
{testCreateDropTable},
{testAlterForeignKeys},
{testCustomFKNameFunc},
{testForceRenameFK},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -328,6 +331,8 @@ func testAlterForeignKeys(t *testing.T, db *bun.DB) {
require.NoError(t, err)

defaultSchema := db.Dialect().DefaultSchema()

// Crated 2 new constraints
require.Contains(t, state.FKs, sqlschema.FK{
From: sqlschema.C(defaultSchema, "things_to_owners", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
Expand All @@ -336,13 +341,146 @@ func testAlterForeignKeys(t *testing.T, db *bun.DB) {
From: sqlschema.C(defaultSchema, "things_to_owners", "thing_id"),
To: sqlschema.C(defaultSchema, "things", "id"),
})

// Dropped the initial one
require.NotContains(t, state.FKs, sqlschema.FK{
From: sqlschema.C(defaultSchema, "things", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
})
}

func TestDetector_Diff(t *testing.T) {
func testForceRenameFK(t *testing.T, db *bun.DB) {
// Database state
type Owner struct {
ID int64 `bun:",pk"`
}

type OwnedThing struct {
bun.BaseModel `bun:"table:things"`
ID int64 `bun:",pk"`
OwnerID int64 `bun:"owner_id,notnull"`

Owner *Owner `bun:"rel:belongs-to,join:owner_id=id"`
}

// Model state
type Person struct {
ID int64 `bun:",pk"`
}

type PersonalThing struct {
bun.BaseModel `bun:"table:things"`
ID int64 `bun:",pk"`
PersonID int64 `bun:"owner_id,notnull"`

Owner *Person `bun:"rel:belongs-to,join:owner_id=id"`
}

ctx := context.Background()
dbInspector, err := sqlschema.NewInspector(db)
if err != nil {
t.Skip(err)
}

mustCreateTableWithFKs(t, ctx, db,
(*Owner)(nil),
(*OwnedThing)(nil),
)
mustDropTableOnCleanup(t, ctx, db, (*Person)(nil))

m, err := migrate.NewAutoMigrator(db,
migrate.WithTableNameAuto(migrationsTable),
migrate.WithLocksTableNameAuto(migrationLocksTable),
migrate.WithModel(
(*Person)(nil),
(*PersonalThing)(nil),
),
migrate.WithFKNameFunc(func(fk sqlschema.FK) string {
return strings.Join([]string{
fk.From.Table, fk.To.Table, "fkey",
}, "_")
}),
migrate.WithRenameFK(true),
)
require.NoError(t, err)

// Act
err = m.Run(ctx)
require.NoError(t, err)

// Assert
state, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

schema := db.Dialect().DefaultSchema()
wantName, ok := state.FKs[sqlschema.FK{
From: sqlschema.C(schema, "things", "owner_id"),
To: sqlschema.C(schema, "people", "id"),
}]
require.True(t, ok, "expect state.FKs to contain things_people_fkey")
require.Equal(t, wantName, "things_people_fkey")
}

func testCustomFKNameFunc(t *testing.T, db *bun.DB) {
// Database state
type Column struct {
OID int64 `bun:",pk"`
RelID int64 `bun:"attrelid,notnull"`
}
type Table struct {
OID int64 `bun:",pk"`
}

// Model state
type ColumnM struct {
bun.BaseModel `bun:"table:columns"`
OID int64 `bun:",pk"`
RelID int64 `bun:"attrelid,notnull"`

Table *Table `bun:"rel:belongs-to,join:attrelid=oid"`
}
type TableM struct {
bun.BaseModel `bun:"table:tables"`
OID int64 `bun:",pk"`
}

ctx := context.Background()
dbInspector, err := sqlschema.NewInspector(db)
if err != nil {
t.Skip(err)
}

mustCreateTableWithFKs(t, ctx, db,
(*Table)(nil),
(*Column)(nil),
)

m, err := migrate.NewAutoMigrator(db,
migrate.WithTableNameAuto(migrationsTable),
migrate.WithLocksTableNameAuto(migrationLocksTable),
migrate.WithFKNameFunc(func(sqlschema.FK) string { return "test_fkey" }),
migrate.WithModel((*TableM)(nil)),
migrate.WithModel((*ColumnM)(nil)),
)
require.NoError(t, err)

// Act
err = m.Run(ctx)
require.NoError(t, err)

// Assert
state, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

fkName := state.FKs[sqlschema.FK{
From: sqlschema.C(db.Dialect().DefaultSchema(), "columns", "attrelid"),
To: sqlschema.C(db.Dialect().DefaultSchema(), "tables", "oid"),
}]
require.Equal(t, fkName, "test_fkey")
}

// TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package
func TestDiff(t *testing.T) {
type Journal struct {
ISBN string `bun:"isbn,pk"`
Title string `bun:"title,notnull"`
Expand Down Expand Up @@ -378,6 +516,8 @@ func TestDetector_Diff(t *testing.T) {
}

testEachDialect(t, func(t *testing.T, dialectName string, dialect schema.Dialect) {
defaultSchema := dialect.DefaultSchema()

for _, tt := range []struct {
name string
states func(testing.TB, context.Context, schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State)
Expand Down Expand Up @@ -418,7 +558,7 @@ func TestDetector_Diff(t *testing.T) {
},
want: []migrate.Operation{
&migrate.RenameTable{
Schema: dialect.DefaultSchema(),
Schema: defaultSchema,
From: "journals",
To: "journals_renamed",
},
Expand All @@ -430,7 +570,7 @@ func TestDetector_Diff(t *testing.T) {
Name: "billing.subscriptions", // TODO: fix once schema is used correctly
},
&migrate.DropTable{
Schema: dialect.DefaultSchema(),
Schema: defaultSchema,
Name: "reviews",
},
},
Expand All @@ -454,7 +594,7 @@ func TestDetector_Diff(t *testing.T) {
},
want: []migrate.Operation{
&migrate.DropTable{
Schema: dialect.DefaultSchema(),
Schema: defaultSchema,
Name: "external_users",
},
&migrate.CreateTable{
Expand Down Expand Up @@ -500,21 +640,19 @@ func TestDetector_Diff(t *testing.T) {
)
},
want: []migrate.Operation{
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "users",
SourceColumns: []string{"pet_kind", "pet_name"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "pets",
TargetColumns: []string{"kind", "nickname"},
&migrate.AddFK{
FK: sqlschema.FK{
From: sqlschema.C(defaultSchema, "users", "pet_kind", "pet_name"),
To: sqlschema.C(defaultSchema, "pets", "kind", "nickname"),
},
ConstraintName: "users_pet_kind_pet_name_fkey",
},
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "users",
SourceColumns: []string{"friend"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "users",
TargetColumns: []string{"username"},
&migrate.AddFK{
FK: sqlschema.FK{
From: sqlschema.C(defaultSchema, "users", "friend"),
To: sqlschema.C(defaultSchema, "users", "username"),
},
ConstraintName: "users_friend_fkey",
},
},
},
Expand All @@ -532,13 +670,12 @@ func TestDetector_Diff(t *testing.T) {
&migrate.CreateTable{
Model: &Owner{},
},
&migrate.AddForeignKey{
SourceSchema: dialect.DefaultSchema(),
SourceTable: "things",
SourceColumns: []string{"owner_id"},
TargetSchema: dialect.DefaultSchema(),
TargetTable: "owners",
TargetColumns: []string{"id"},
&migrate.AddFK{
FK: sqlschema.FK{
From: sqlschema.C(defaultSchema, "things", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
},
ConstraintName: "things_owner_id_fkey",
},
},
},
Expand All @@ -557,12 +694,14 @@ func TestDetector_Diff(t *testing.T) {
},
want: []migrate.Operation{
&migrate.DropTable{
Schema: dialect.DefaultSchema(),
Schema: defaultSchema,
Name: "owners",
},
&migrate.DropForeignKey{
Schema: dialect.DefaultSchema(),
Table: "things",
&migrate.DropFK{
FK: sqlschema.FK{
From: sqlschema.C(defaultSchema, "things", "owner_id"),
To: sqlschema.C(defaultSchema, "owners", "id"),
},
ConstraintName: "test_fkey",
},
},
Expand Down
Loading

0 comments on commit a822fc5

Please sign in to comment.