Skip to content

Commit

Permalink
Merge pull request #156 from marema31/MigrationSet
Browse files Browse the repository at this point in the history
Add type MigrationSet to allow concurrent usage of sql-migrate
  • Loading branch information
rubenv authored Nov 20, 2019
2 parents ce2300b + cac4cff commit 7c4b0a9
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
56 changes: 45 additions & 11 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,24 @@ const (
Down
)

var tableName = "gorp_migrations"
var schemaName = ""
// MigrationSet provides database parameters for a migration execution
type MigrationSet struct {
// TableName name of the table used to store migration info.
TableName string
// SchemaName schema that the migration table be referenced.
SchemaName string
}

var migSet = MigrationSet{}

// NewMigrationSet returns a parametrized Migration object
func (ms MigrationSet) getTableName() string {
if ms.TableName == "" {
return "gorp_migrations"
}
return ms.TableName
}

var numberPrefixRegex = regexp.MustCompile(`^(\d+).*$`)

// PlanError happens where no migration plan could be created between the sets
Expand Down Expand Up @@ -73,14 +89,14 @@ func (e *TxError) Error() string {
// Should be called before any other call such as (Exec, ExecMax, ...).
func SetTable(name string) {
if name != "" {
tableName = name
migSet.TableName = name
}
}

// SetSchema sets the name of a schema that the migration table be referenced.
func SetSchema(name string) {
if name != "" {
schemaName = name
migSet.SchemaName = name
}
}

Expand Down Expand Up @@ -366,13 +382,23 @@ func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection)
return ExecMax(db, dialect, m, dir, 0)
}

// Returns the number of applied migrations.
func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMax(db, dialect, m, dir, 0)
}

// Execute a set of migrations
//
// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec).
//
// Returns the number of applied migrations.
func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
migrations, dbMap, err := PlanMigration(db, dialect, m, dir, max)
return migSet.ExecMax(db, dialect, m, dir, max)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -443,7 +469,11 @@ func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti

// Plan a migration.
func PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) {
dbMap, err := getMigrationDbMap(db, dialect)
return migSet.PlanMigration(db, dialect, m, dir, max)
}

func (ms MigrationSet) PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) {
dbMap, err := ms.getMigrationDbMap(db, dialect)
if err != nil {
return nil, nil, err
}
Expand All @@ -454,7 +484,7 @@ func PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationD
}

var migrationRecords []MigrationRecord
_, err = dbMap.Select(&migrationRecords, fmt.Sprintf("SELECT * FROM %s", dbMap.Dialect.QuotedTableForQuery(schemaName, tableName)))
_, err = dbMap.Select(&migrationRecords, fmt.Sprintf("SELECT * FROM %s", dbMap.Dialect.QuotedTableForQuery(ms.SchemaName, ms.getTableName())))
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -621,13 +651,17 @@ func ToCatchup(migrations, existingMigrations []*Migration, lastRun *Migration)
}

func GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) {
dbMap, err := getMigrationDbMap(db, dialect)
return migSet.GetMigrationRecords(db, dialect)
}

func (ms MigrationSet) GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) {
dbMap, err := ms.getMigrationDbMap(db, dialect)
if err != nil {
return nil, err
}

var records []*MigrationRecord
query := fmt.Sprintf("SELECT * FROM %s ORDER BY id ASC", dbMap.Dialect.QuotedTableForQuery(schemaName, tableName))
query := fmt.Sprintf("SELECT * FROM %s ORDER BY id ASC", dbMap.Dialect.QuotedTableForQuery(ms.SchemaName, ms.getTableName()))
_, err = dbMap.Select(&records, query)
if err != nil {
return nil, err
Expand All @@ -636,7 +670,7 @@ func GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error)
return records, nil
}

func getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) {
func (ms MigrationSet) getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) {
d, ok := MigrationDialects[dialect]
if !ok {
return nil, fmt.Errorf("Unknown dialect: %s", dialect)
Expand Down Expand Up @@ -664,7 +698,7 @@ Check https://github.com/go-sql-driver/mysql#parsetime for more info.`)

// Create migration database map
dbMap := &gorp.DbMap{Db: db, Dialect: d}
dbMap.AddTableWithNameAndSchema(MigrationRecord{}, schemaName, tableName).SetKeys(false, "Id")
dbMap.AddTableWithNameAndSchema(MigrationRecord{}, ms.SchemaName, ms.getTableName()).SetKeys(false, "Id")
//dbMap.TraceOn("", log.New(os.Stdout, "migrate: ", log.Lmicroseconds))

err := dbMap.CreateTablesIfNotExists()
Expand Down
50 changes: 50 additions & 0 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,53 @@ func (s *SqliteMigrateSuite) TestExecWithUnknownMigrationInDatabase(c *C) {
_, err = s.DbMap.Exec("SELECT age FROM people")
c.Assert(err, NotNil)
}

func (s *SqliteMigrateSuite) TestRunMigrationObjDefaultTable(c *C) {
migrations := &MemoryMigrationSource{
Migrations: sqliteMigrations[:1],
}

ms := MigrationSet{}
// Executes one migration
n, err := ms.Exec(s.Db, "sqlite3", migrations, Up)
c.Assert(err, IsNil)
c.Assert(n, Equals, 1)

// Can use table now
_, err = s.DbMap.Exec("SELECT * FROM people")
c.Assert(err, IsNil)

// Uses default tableName
_, err = s.DbMap.Exec("SELECT * FROM gorp_migrations")
c.Assert(err, IsNil)

// Shouldn't apply migration again
n, err = ms.Exec(s.Db, "sqlite3", migrations, Up)
c.Assert(err, IsNil)
c.Assert(n, Equals, 0)
}

func (s *SqliteMigrateSuite) TestRunMigrationObjOtherTable(c *C) {
migrations := &MemoryMigrationSource{
Migrations: sqliteMigrations[:1],
}

ms := MigrationSet{TableName: "other_migrations"}
// Executes one migration
n, err := ms.Exec(s.Db, "sqlite3", migrations, Up)
c.Assert(err, IsNil)
c.Assert(n, Equals, 1)

// Can use table now
_, err = s.DbMap.Exec("SELECT * FROM people")
c.Assert(err, IsNil)

// Uses default tableName
_, err = s.DbMap.Exec("SELECT * FROM other_migrations")
c.Assert(err, IsNil)

// Shouldn't apply migration again
n, err = ms.Exec(s.Db, "sqlite3", migrations, Up)
c.Assert(err, IsNil)
c.Assert(n, Equals, 0)
}

0 comments on commit 7c4b0a9

Please sign in to comment.