diff --git a/go.mod b/go.mod index 5fc96c4..62bca5f 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,12 @@ require ( gopkg.in/yaml.v2 v2.4.0 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect @@ -42,6 +48,7 @@ require ( github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/shopspring/decimal v1.3.1 // indirect github.com/spf13/cast v1.5.0 // indirect + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum index e338ef3..1bbc785 100644 --- a/go.sum +++ b/go.sum @@ -116,8 +116,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/migrate.go b/migrate.go index 7fb56f1..38d7b59 100644 --- a/migrate.go +++ b/migrate.go @@ -42,6 +42,8 @@ type MigrationSet struct { IgnoreUnknown bool // DisableCreateTable disable the creation of the migration table DisableCreateTable bool + // Limits precision of time values inserted into the database to 6 milliseconds + LimitTimePrecision bool } var migSet = MigrationSet{} @@ -111,6 +113,11 @@ func SetSchema(name string) { } } +// Resets schema to empty +func ResetSchema() { + migSet.SchemaName = "" +} + // SetDisableCreateTable sets the boolean to disable the creation of the migration table func SetDisableCreateTable(disable bool) { migSet.DisableCreateTable = disable @@ -124,6 +131,12 @@ func SetIgnoreUnknown(v bool) { migSet.IgnoreUnknown = v } +// LimitTimePrecision limits the precision of time values inserted into the database +// to be no greater than 6 milliseconds +func LimitTimePrecision(v bool) { + migSet.LimitTimePrecision = v +} + type Migration struct { Id string Up []string @@ -563,9 +576,13 @@ func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, switch dir { case Up: + appliedAt := time.Now() + if migSet.LimitTimePrecision { + appliedAt = appliedAt.Round(6 * time.Millisecond) + } err = executor.Insert(&MigrationRecord{ Id: migration.Id, - AppliedAt: time.Now(), + AppliedAt: appliedAt, }) if err != nil { if trans, ok := executor.(*gorp.Transaction); ok { @@ -743,9 +760,13 @@ func SkipMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti } } + appliedAt := time.Now() + if migSet.LimitTimePrecision { + appliedAt = appliedAt.Round(6 * time.Millisecond) + } err = executor.Insert(&MigrationRecord{ Id: migration.Id, - AppliedAt: time.Now(), + AppliedAt: appliedAt, }) if err != nil { if trans, ok := executor.(*gorp.Transaction); ok { diff --git a/migrate_test.go b/migrate_test.go index f1d66d6..d3af2af 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -96,6 +96,10 @@ func (s *SqliteMigrateSuite) TestMigrateIncremental(c *C) { Migrations: sqliteMigrations[:1], } + LimitTimePrecision(true) + defer func() { + LimitTimePrecision(false) + }() // Executes one migration n, err := Exec(s.Db, "sqlite3", migrations, Up) c.Assert(err, IsNil) @@ -384,6 +388,10 @@ func (s *SqliteMigrateSuite) TestSkipMigration(c *C) { }, }, } + LimitTimePrecision(true) + defer func() { + LimitTimePrecision(false) + }() n, err := SkipMax(s.Db, "sqlite3", migrations, Up, 0) // there should be no errors c.Assert(err, IsNil) diff --git a/sql-migrate/config.go b/sql-migrate/config.go index 6947a53..5ffd1c6 100644 --- a/sql-migrate/config.go +++ b/sql-migrate/config.go @@ -35,12 +35,13 @@ func ConfigFlags(f *flag.FlagSet) { } type Environment struct { - Dialect string `yaml:"dialect"` - DataSource string `yaml:"datasource"` - Dir string `yaml:"dir"` - TableName string `yaml:"table"` - SchemaName string `yaml:"schema"` - IgnoreUnknown bool `yaml:"ignoreunknown"` + Dialect string `yaml:"dialect"` + DataSource string `yaml:"datasource"` + Dir string `yaml:"dir"` + TableName string `yaml:"table"` + SchemaName string `yaml:"schema"` + IgnoreUnknown bool `yaml:"ignoreunknown"` + LimitTimePrecision bool `yaml:"limitprecision"` } func ReadConfig() (map[string]*Environment, error) { @@ -66,15 +67,15 @@ func GetEnvironment() (*Environment, error) { env := config[ConfigEnvironment] if env == nil { - return nil, errors.New("No environment: " + ConfigEnvironment) + return nil, errors.New("no environment: " + ConfigEnvironment) } if env.Dialect == "" { - return nil, errors.New("No dialect specified") + return nil, errors.New("no dialect specified") } if env.DataSource == "" { - return nil, errors.New("No data source specified") + return nil, errors.New("no data source specified") } env.DataSource = os.ExpandEnv(env.DataSource) @@ -92,19 +93,21 @@ func GetEnvironment() (*Environment, error) { migrate.SetIgnoreUnknown(env.IgnoreUnknown) + migrate.LimitTimePrecision(env.LimitTimePrecision) + return env, nil } func GetConnection(env *Environment) (*sql.DB, string, error) { db, err := sql.Open(env.Dialect, env.DataSource) if err != nil { - return nil, "", fmt.Errorf("Cannot connect to database: %w", err) + return nil, "", fmt.Errorf("cannot connect to database: %w", err) } // Make sure we only accept dialects that were compiled in. _, exists := dialects[env.Dialect] if !exists { - return nil, "", fmt.Errorf("Unsupported dialect: %s", env.Dialect) + return nil, "", fmt.Errorf("unsupported dialect: %s", env.Dialect) } return db, env.Dialect, nil diff --git a/sql-migrate/config_test.go b/sql-migrate/config_test.go new file mode 100644 index 0000000..dbe0150 --- /dev/null +++ b/sql-migrate/config_test.go @@ -0,0 +1,334 @@ +package main + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func generateFile(lines []string) (string, error) { + file, err := os.CreateTemp("", "") + if err != nil { + return "", err + } + defer func() { + file.Close() + }() + + for _, l := range lines { + file.WriteString(fmt.Sprintf("%s\n", l)) + } + + return file.Name(), nil +} + +func cleanup(filePath string) error { + err := os.Remove(filePath) + if err != nil { + if !strings.Contains(err.Error(), "no such file or directory") { + return err + } + } + + return nil +} + +func TestReadConfig(t *testing.T) { + // Bad lines + lines := []string{ + "development:", + " dialect: postgres", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " dialect: postgres", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", // Bad line + " table: migrations", + " limitprecision: false", + } + configFile, err := generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + _, err = ReadConfig() + require.Error(t, err) + + // Good config + lines = []string{ + "development:", + " dialect: postgres", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " dialect: postgres", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + envMap, err := ReadConfig() + if err != nil { + t.Error(err) + } + + require.Equal(t, 2, len(envMap)) +} + +func TestGetEnvironment(t *testing.T) { + // Bad lines - can't find key + lines := []string{ + "development:", + " dialect: postgres", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " dialect: postgres", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", // Bad line + " table: migrations", + " limitprecision: false", + } + configFile, err := generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + _, err = GetEnvironment() + require.ErrorContains(t, err, "yaml: line 12: did not find expected key") + + // Parseable - unmatched environment + lines = []string{ + "development:", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "foobar" + _, err = GetEnvironment() + require.ErrorContains(t, err, "no environment: foobar") + + // Error - missing dialect + lines = []string{ + "development:", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "development" + _, err = GetEnvironment() + require.ErrorContains(t, err, "no dialect") + + // Error - missing datasource + lines = []string{ + "development:", + " dialect: postgres", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: false", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "development" + _, err = GetEnvironment() + require.ErrorContains(t, err, "no data source") + + // No migration dir + lines = []string{ + "development:", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dialect: postgres", + " schema: public", + " table: migrations", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "development" + env, err := GetEnvironment() + require.Equal(t, "migrations", env.Dir) + + // Test setting table and absent limitprecision + lines = []string{ + "development:", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dialect: postgres", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "development" + env, err = GetEnvironment() + require.Equal(t, "migrations", env.TableName) + require.False(t, env.LimitTimePrecision) + + // Test setting table and absent limitprecision + lines = []string{ + "development:", + " datasource: host=127.0.0.1 dbname=reporting user=root password=root sslmode=disable", + " dialect: postgres", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + " limitprecision: true", + "", + "docker:", + " datasource: host=portal_db dbname=reporting user=root password=root sslmode=disable", + " dir: my_db/migrations", + " schema: public", + " table: migrations", + } + configFile, err = generateFile(lines) + if err != nil { + t.Error(err) + } + defer func() { + err = cleanup(configFile) + if err != nil { + t.Error(err) + } + }() + + ConfigFile = configFile + ConfigEnvironment = "development" + env, err = GetEnvironment() + require.True(t, env.LimitTimePrecision) +}