diff --git a/commands/generate.go b/commands/generate.go index 1daa99f..7c2699c 100644 --- a/commands/generate.go +++ b/commands/generate.go @@ -35,7 +35,7 @@ func GenerateCmd() *cobra.Command { func generateCmdF(cmd *cobra.Command, args []string) { dir, _ := cmd.Flags().GetString("dir") driver, _ := cmd.Flags().GetString("driver") - extention := getExtension(driver) + extension := getExtension(driver) fileName := args[0] if ts, _ := cmd.Flags().GetBool("timestamp"); ts { @@ -57,7 +57,7 @@ func generateCmdF(cmd *cobra.Command, args []string) { fileName = strings.Join([]string{date.Format(tf), fileName}, "_") } } else if seq, _ := cmd.Flags().GetBool("sequence"); seq { - next, err := sequelNumber(dir, extention) + next, err := sequelNumber(dir, extension, driver) if err != nil { cmd.PrintErrln(err) return @@ -80,7 +80,7 @@ func generateCmdF(cmd *cobra.Command, args []string) { migrations := []string{"down", "up"} for _, migration := range migrations { - filePath := strings.Join([]string{filepath.Join(dir, fileName), migration, extention}, ".") + filePath := strings.Join([]string{filepath.Join(dir, fileName), migration, extension}, ".") f, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666) if err != nil { cmd.PrintErrln(err) @@ -90,8 +90,8 @@ func generateCmdF(cmd *cobra.Command, args []string) { } } -func sequelNumber(dir, extention string) (int, error) { - matches, err := filepath.Glob(filepath.Join(dir, "*"+extention)) +func sequelNumber(dir, extension, driver string) (int, error) { + matches, err := filepath.Glob(filepath.Join(dir, driver, "*"+extension)) if err != nil { return 0, err } @@ -119,7 +119,7 @@ func sequelNumber(dir, extention string) (int, error) { func getExtension(driver string) string { switch driver { - case "postgresql", "mysql": + case "postgres", "mysql": return "sql" default: return "txt" diff --git a/commands/generate_test.go b/commands/generate_test.go new file mode 100644 index 0000000..3373370 --- /dev/null +++ b/commands/generate_test.go @@ -0,0 +1,97 @@ +package commands + +import ( + "errors" + "log" + "os" + "path/filepath" + "testing" + + "github.com/go-morph/morph/commands/testlib" + "github.com/stretchr/testify/require" +) + +func TestGenerateCMD(t *testing.T) { + dir := "./tmp" + cmd := GenerateCmd() + cmd.PersistentFlags().String("dir", ".", "the migrations directory") + + t.Run("should generate migration files correctly when using --sequence", func(t *testing.T) { + name := "create_saiyans" + + defer func() { + err := os.RemoveAll(dir) + if err != nil { + log.Fatal(err) + } + }() + + // ensure that directory doesn't exist + _, err := os.Stat("./tmp") + require.Equal(t, errors.Is(err, os.ErrNotExist), true) + + cases := []struct { + driver string + args []string + sequence string + }{ + { + driver: "postgres", + args: []string{name, "--dir", dir, "--sequence"}, + sequence: "000001", + }, + { + driver: "mysql", + args: []string{name, "--dir", dir, "--sequence"}, + sequence: "000001", + }, + { + driver: "postgres", + args: []string{name, "--dir", dir, "--sequence"}, + sequence: "000002", + }, + { + driver: "mysql", + args: []string{name, "--dir", dir, "--sequence"}, + sequence: "000002", + }, + { + driver: "postgres", + args: []string{name, "--dir", dir, "-s"}, + sequence: "000003", + }, + { + driver: "mysql", + args: []string{name, "--dir", dir, "-s"}, + sequence: "000003", + }, + } + + for _, tc := range cases { + args := append(tc.args, "--driver", tc.driver) + + _, err := testlib.ExecuteCommand(t, cmd, args...) + require.NoError(t, err) + + _, fErr := os.Stat(filepath.Join("./tmp", tc.driver, tc.sequence+"_"+name+".down.sql")) + require.NoError(t, fErr) + + _, fErr = os.Stat(filepath.Join("./tmp/", tc.driver, tc.sequence+"_"+name+".up.sql")) + require.NoError(t, fErr) + } + }) + + t.Run("should correctly return extension", func(t *testing.T) { + ext := getExtension("postgres") + require.Equal(t, ext, "sql") + + ext = getExtension("mysql") + require.Equal(t, ext, "sql") + + ext = getExtension("postgresql") + require.Equal(t, ext, "txt") + + ext = getExtension("mysqlite") + require.Equal(t, ext, "txt") + }) +} diff --git a/commands/testlib/testing.go b/commands/testlib/testing.go new file mode 100644 index 0000000..00b870b --- /dev/null +++ b/commands/testlib/testing.go @@ -0,0 +1,21 @@ +package testlib + +import ( + "bytes" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +func ExecuteCommand(t *testing.T, c *cobra.Command, args ...string) (string, error) { + t.Helper() + + buf := new(bytes.Buffer) + c.SetOut(buf) + c.SetErr(buf) + c.SetArgs(args) + + err := c.Execute() + return strings.TrimSpace(buf.String()), err +}