diff --git a/ddlmod.go b/ddlmod.go index 7cb1642..c8db01f 100644 --- a/ddlmod.go +++ b/ddlmod.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "regexp" + "strconv" "strings" "gorm.io/gorm/migrator" @@ -12,12 +13,13 @@ import ( var ( sqliteSeparator = "`|\"|'|\t" - indexRegexp = regexp.MustCompile(fmt.Sprintf("CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator)) - tableRegexp = regexp.MustCompile(fmt.Sprintf("(?i)(CREATE TABLE [%v]?[\\w\\d]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator)) + indexRegexp = regexp.MustCompile(fmt.Sprintf("CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator)) + tableRegexp = regexp.MustCompile(fmt.Sprintf("(?i)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator)) separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) columnsRegexp = regexp.MustCompile(fmt.Sprintf("\\([%v]?([\\w\\d]+)[%v]?(?:,[%v]?([\\w\\d]+)[%v]){0,}\\)", sqliteSeparator, sqliteSeparator, sqliteSeparator, sqliteSeparator)) columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator)) defaultValueRegexp = regexp.MustCompile("(?i) DEFAULT \\(?(.+)?\\)?( |COLLATE|GENERATED|$)") + regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) ) type ddl struct { @@ -116,7 +118,7 @@ func parseDDL(strs ...string) (*ddl, error) { PrimaryKeyValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, NullableValue: sql.NullBool{Valid: true}, - DefaultValueValue: sql.NullString{Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, } matchUpper := strings.ToUpper(matches[3]) @@ -135,6 +137,14 @@ func parseDDL(strs ...string) (*ddl, error) { columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true} } + // data type length + matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1) + if len(matches) == 1 && len(matches[0]) == 2 { + size, _ := strconv.Atoi(matches[0][1]) + columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)} + columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0]) + } + result.columns = append(result.columns, columnType) } } diff --git a/ddlmod_test.go b/ddlmod_test.go index 2439640..5e7a840 100644 --- a/ddlmod_test.go +++ b/ddlmod_test.go @@ -4,8 +4,8 @@ import ( "database/sql" "testing" - "github.com/stretchr/testify/assert" "gorm.io/gorm/migrator" + "gorm.io/gorm/utils/tests" ) func TestParseDDL(t *testing.T) { @@ -19,29 +19,48 @@ func TestParseDDL(t *testing.T) { "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: true}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, {"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }}, {"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, {"no brackets", []string{"create table test"}, 0, nil}, {"with_special_characters", []string{ "CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试,\")", }, 1, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试,", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试,", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, + { + "table_name_with_dash", + []string{ + "CREATE TABLE `test-a` (`id` int NOT NULL)", + "CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "id", Valid: true}, + DataTypeValue: sql.NullString{String: "int", Valid: true}, + ColumnTypeValue: sql.NullString{String: "int", Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, + UniqueValue: sql.NullBool{Valid: true}, + PrimaryKeyValue: sql.NullBool{Valid: true}, + }, + }, + }, } for _, p := range params { @@ -52,9 +71,11 @@ func TestParseDDL(t *testing.T) { panic(err.Error()) } - assert.Equal(t, p.sql[0], ddl.compile()) - assert.Len(t, ddl.fields, p.nFields) - assert.Equal(t, ddl.columns, p.columns) + tests.AssertEqual(t, p.sql[0], ddl.compile()) + if len(ddl.fields) != p.nFields { + t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields)) + } + tests.AssertEqual(t, ddl.columns, p.columns) }) } } @@ -122,7 +143,7 @@ func TestAddConstraint(t *testing.T) { testDDL := ddl{fields: p.fields} testDDL.addConstraint(p.cName, p.sql) - assert.Equal(t, p.expect, testDDL.fields) + tests.AssertEqual(t, p.expect, testDDL.fields) }) } } @@ -164,8 +185,8 @@ func TestRemoveConstraint(t *testing.T) { success := testDDL.removeConstraint(p.cName) - assert.Equal(t, p.success, success) - assert.Equal(t, p.expect, testDDL.fields) + tests.AssertEqual(t, p.success, success) + tests.AssertEqual(t, p.expect, testDDL.fields) }) } } @@ -202,7 +223,7 @@ func TestGetColumns(t *testing.T) { cols := testDDL.getColumns() - assert.Equal(t, p.columns, cols) + tests.AssertEqual(t, p.columns, cols) }) } } diff --git a/go.mod b/go.mod index daeef5e..9cf9eb3 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,7 @@ module gorm.io/driver/sqlite go 1.14 require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-sqlite3 v1.14.12 - github.com/stretchr/testify v1.7.0 gorm.io/gorm v1.23.4 ) diff --git a/go.sum b/go.sum index c9f46fb..98e3d78 100644 --- a/go.sum +++ b/go.sum @@ -1,24 +1,9 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= -github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/gorm v1.23.4 h1:1BKWM67O6CflSLcwGQR7ccfmC4ebOxQrTfOQGRE9wjg= gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= diff --git a/migrator.go b/migrator.go index 641b1f4..9d11cfc 100644 --- a/migrator.go +++ b/migrator.go @@ -80,12 +80,17 @@ func (m Migrator) AlterColumn(value interface{}, name string) error { return m.RunWithoutForeignKey(func() error { return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { if field := stmt.Schema.LookUpField(name); field != nil { - reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?,") + // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)` + reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)") if err != nil { return "", nil, err } - createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?,", field.DBName)) + createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName)) + + if createSQL == rawDDL { + return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL) + } return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil }