Skip to content

Commit

Permalink
Small refactor for addreference (yoyo-project#8)
Browse files Browse the repository at this point in the history
The Migrator.AddReference signature was set up to expect a db, but that
puts too much burden on the concretions to do  business logic outside of
normal dbms-specific implementation work.

This change updates the interface, migrator logic,  and mysql implementation
to clean up the implementation.
  • Loading branch information
dotvezz authored Dec 6, 2020
1 parent dd766c2 commit c866023
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 28 deletions.
2 changes: 1 addition & 1 deletion internal/dbms/base/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (d *Base) TypeString(dt datatype.Datatype) (string, error) {
return s, nil
}

func (d *Base) AddReference(table, reference string, db schema.Database, i schema.Reference) (string, error) {
func (d *Base) AddReference(table, reference string, rt schema.Table, i schema.Reference) (string, error) {
panic("implement me")
}

Expand Down
2 changes: 1 addition & 1 deletion internal/dbms/base/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestBase_Panics(t *testing.T) {

t.Run("AddReference", func(t *testing.T) {
panicked := didPanic(func() {
_, _ = b.AddReference("", "", schema.Database{}, schema.Reference{})
_, _ = b.AddReference("", "", schema.Table{}, schema.Reference{})
})
if !panicked {
t.Errorf("Expected a panic but didn't see one")
Expand Down
34 changes: 12 additions & 22 deletions internal/dbms/mysql/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,17 @@ func (m *migrator) AddIndex(tName, iName string, i schema.Index) string {
return fmt.Sprintf("ALTER TABLE `%s` ADD %s `%s` (%s);", tName, indexType, iName, cols.String())
}

func (m *migrator) AddReference(table, referencedTable string, db schema.Database, r schema.Reference) (string, error) {
sw := strings.Builder{}

if r.HasMany { // swap the tables if it's a HasMany
table, referencedTable = referencedTable, table
}

rt, ok := db.Tables[referencedTable]
if !ok { // This should technically be caught by validation, but still
return "", fmt.Errorf("referenced table `%s` does not exist in dbms definition", referencedTable)
}
func (m *migrator) AddReference(table, referencedTable string, refTable schema.Table, ref schema.Reference) (string, error) {
var (
fcols []string
fknames []string
fkname string
refColNames = r.ColumnNames
hasPK bool
refColNames = ref.ColumnNames
sw = strings.Builder{}
)

var hasPK bool

for cname, col := range rt.Columns {
for cname, col := range refTable.Columns {
if !col.PrimaryKey {
continue
}
Expand All @@ -127,8 +117,8 @@ func (m *migrator) AddReference(table, referencedTable string, db schema.Databas
switch {
case len(refColNames) > 0:
fkname, refColNames = refColNames[0], refColNames[1:len(refColNames)]
case len(r.ColumnName) > 0:
fkname = r.ColumnName
case len(ref.ColumnName) > 0:
fkname = ref.ColumnName
default:
fkname = fmt.Sprintf("fk_%s_%s", referencedTable, cname)
}
Expand All @@ -140,7 +130,7 @@ func (m *migrator) AddReference(table, referencedTable string, db schema.Databas
sw.WriteRune('\n')
}

if len(r.ColumnNames) > 0 && len(r.ColumnNames) != len(fcols) {
if len(ref.ColumnNames) > 0 && len(ref.ColumnNames) != len(fcols) {
return "", fmt.Errorf("cannot add reference from `%s` to `%s`: length of column_names does not match length of primary keys", table, referencedTable)
}

Expand All @@ -152,12 +142,12 @@ func (m *migrator) AddReference(table, referencedTable string, db schema.Databas
table, referencedTable, strings.Join(fknames, ", "), referencedTable, strings.Join(fcols, ", "),
))

if r.OnDelete != "" {
sw.WriteString(fmt.Sprintf(" ON DELETE %s", r.OnDelete))
if ref.OnDelete != "" {
sw.WriteString(fmt.Sprintf(" ON DELETE %s", ref.OnDelete))
}

if r.OnUpdate != "" {
sw.WriteString(fmt.Sprintf(" ON UPDATE %s", r.OnUpdate))
if ref.OnUpdate != "" {
sw.WriteString(fmt.Sprintf(" ON UPDATE %s", ref.OnUpdate))
}

return sw.String(), nil
Expand Down
2 changes: 1 addition & 1 deletion internal/dbms/mysql/reverser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ func Test_reverser_GetReference(t *testing.T) {
t.Errorf("ListTables() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr == "" && !reflect.DeepEqual(got, tt.want) {
if tt.wantErr == "" && !reflect.DeepEqual(got, tt.want) {
t.Errorf("ListTables() got = %v, want %v", got, tt.want)
}
})
Expand Down
2 changes: 1 addition & 1 deletion internal/migration/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Dialect interface {
AddIndex(table, index string, i schema.Index) string

// AddReference returns a string query which adds the specified index to a table
AddReference(table, referencedTable string, db schema.Database, i schema.Reference) (string, error)
AddReference(table, referencedTable string, dt schema.Table, i schema.Reference) (string, error)
}

func LoadDialect(name string) (d Dialect, err error) {
Expand Down
11 changes: 10 additions & 1 deletion internal/migration/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func NewRefAdder(
) RefGenerator {
return func(localTable string, refs map[string]schema.Reference, sw io.StringWriter) error {
for foreignTable, ref := range refs {
if ref.HasMany { // swap the tables if it's a HasMany
localTable, foreignTable = foreignTable, localTable
}
ft, ok := db.Tables[foreignTable]

if options&AddMissing > 0 {
exists, err := hasReference(localTable, foreignTable)
if err != nil {
Expand All @@ -157,7 +162,11 @@ func NewRefAdder(
continue
}
}
s, err := d.AddReference(localTable, foreignTable, db, ref)

if !ok { // This should technically be caught by validation, but still
return fmt.Errorf("referenced table `%s` does not exist in dbms definition", foreignTable)
}
s, err := d.AddReference(localTable, foreignTable, ft, ref)
if err != nil {
return fmt.Errorf("unable to generate migration: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/migration/generator_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (d *mockDialect) DataSourceName(host, username, schema, password, port stri
return ""
}

func (d *mockDialect) AddReference(table, referencedTable string, db schema.Database, i schema.Reference) (string, error) {
func (d *mockDialect) AddReference(table, referencedTable string, rt schema.Table, i schema.Reference) (string, error) {
return "", nil
}

Expand Down

0 comments on commit c866023

Please sign in to comment.