Skip to content

Commit

Permalink
Merge pull request #604 from kinvolk/migrate_down
Browse files Browse the repository at this point in the history
backend: Add migration down support
  • Loading branch information
joaquimrocha authored May 30, 2022
2 parents eaa4ce0 + 66252ad commit e4f2e06
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 21 deletions.
2 changes: 1 addition & 1 deletion backend/cmd/initdb/initdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func main() {
if _, err := api.New(api.OptionInitDB); err != nil {
if _, err := api.NewWithMigrations(api.OptionInitDB); err != nil {
log.Fatal(err)
}
}
16 changes: 15 additions & 1 deletion backend/cmd/nebraska/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,22 @@ func main() {
log.Fatalf("Config is invaliad, err: %w\n", err)
}

if conf.RollbackDBTo != "" {
db, err := db.New()
if err != nil {
log.Fatal("DB connection err:", err)
}

count, err := db.MigrateDown(conf.RollbackDBTo)
if err != nil {
log.Fatal("DB migration down err:", err)
}
log.Infof("DB migration down successful, migrated %d levels down", count)
return
}

// create new DB
db, err := db.New()
db, err := db.NewWithMigrations()
if err != nil {
log.Fatal("DB connection err:", err)
}
Expand Down
85 changes: 76 additions & 9 deletions backend/pkg/api/api.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package api

import (
"database/sql"
"errors"
"fmt"
"os"
"strconv"

//register "pgx" sql driver
"github.com/doug-martin/goqu/v9"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/jmoiron/sqlx"
migrate "github.com/rubenv/sql-migrate"
Expand All @@ -14,7 +18,6 @@ import (
// Postgresql driver
_ "github.com/lib/pq"

"strconv"
"time"
)

Expand Down Expand Up @@ -55,6 +58,8 @@ var (
ErrArchMismatch = errors.New("nebraska: mismatched arches")
)

const migrationsTable = "database_migrations"

// API represents an api instance used to interact with Nebraska entities.
type API struct {
db *sqlx.DB
Expand All @@ -66,8 +71,7 @@ type API struct {
disableUpdatesOnFailedRollout bool
}

// New creates a new API instance, creating the underlying db connection and
// applying db migrations available.
// New creates a new API instance, creates the underlying db connection.
func New(options ...func(*API) error) (*API, error) {
api := &API{
dbDriver: "pgx",
Expand Down Expand Up @@ -117,13 +121,20 @@ func New(options ...func(*API) error) (*API, error) {
return nil, err
}
}
return api, nil
}

migrate.SetTable("database_migrations")
migrations := &migrate.AssetMigrationSource{
Asset: Asset,
AssetDir: AssetDir,
Dir: "db/migrations",
// NewWithMigrations creates a new API instance, creates the underlying db connection and
// applies all available db migrations.
func NewWithMigrations(options ...func(*API) error) (*API, error) {
api, err := New(options...)
if err != nil {
return nil, err
}

migrate.SetTable(migrationsTable)
migrations := migrationAssets()

if _, err := migrate.Exec(api.db.DB, "postgres", migrations, migrate.Up); err != nil {
return nil, err
}
Expand All @@ -133,6 +144,62 @@ func New(options ...func(*API) error) (*API, error) {
return api, nil
}

type migration struct {
ID string `db:"id"`
AppliedAt time.Time `db:"applied_at"`
}

func (api *API) MigrateDown(version string) (int, error) {

migrate.SetTable(migrationsTable)
migrations := migrationAssets()

// find version based on input string
query, _, err := goqu.Select("*").From(migrationsTable).Where(goqu.C("id").Like(fmt.Sprintf("%s%%", version))).ToSQL()
if err != nil {
return 0, err
}

var mig migration
err = api.db.QueryRowx(query).StructScan(&mig)
if err != nil {
if err == sql.ErrNoRows {
return 0, fmt.Errorf("no migrations found for: %s, err: %v", version, err)
}
return 0, err
}

// find count of migrations that have been applied after the version
query, _, err = goqu.Select(goqu.COUNT("*")).From(migrationsTable).Where(goqu.C("applied_at").Gt(mig.AppliedAt)).ToSQL()
if err != nil {
return 0, err
}

countMap := make(map[string]interface{})

err = api.db.QueryRowx(query).MapScan(countMap)
if err != nil {
return 0, err
}

levels := countMap["count"].(int64)
logger.Info().Msgf("migrating down %d levels", levels)
count, err := migrate.ExecMax(api.db.DB, "postgres", migrations, migrate.Down, int(levels))
if err != nil {
return 0, err
}
logger.Info().Msg("successfully migrated down")
return count, nil
}

func migrationAssets() *migrate.AssetMigrationSource {
return &migrate.AssetMigrationSource{
Asset: Asset,
AssetDir: AssetDir,
Dir: "db/migrations",
}
}

// OptionInitDB will initialize the database during the API instance creation,
// dropping all existing tables, which will force all migration scripts to be
// re-executed. Use with caution, this will DESTROY ALL YOUR DATA.
Expand Down Expand Up @@ -166,7 +233,7 @@ func (api *API) Close() {
// NewForTest creates a new API instance with given options and fills
// the database with sample data for testing purposes.
func NewForTest(options ...func(*API) error) (*API, error) {
a, err := New(options...)
a, err := NewWithMigrations(options...)
if err != nil {
return nil, err
}
Expand Down
32 changes: 31 additions & 1 deletion backend/pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"os"
"testing"

"github.com/doug-martin/goqu/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -31,7 +33,7 @@ func TestMain(m *testing.M) {
_ = os.Setenv("NEBRASKA_DB_URL", defaultTestDbURL)
}

a, err := New(OptionInitDB)
a, err := NewWithMigrations(OptionInitDB)
if err != nil {
log.Printf("Failed to init DB: %v\n", err)
log.Println("These tests require PostgreSQL running and a tests database created, please adjust NEBRASKA_DB_URL as needed.")
Expand All @@ -41,3 +43,31 @@ func TestMain(m *testing.M) {

os.Exit(m.Run())
}

func TestMigrateDown(t *testing.T) {

// Create New DB
db, err := NewWithMigrations(OptionInitDB)
require.NoError(t, err)
defer db.Close()

_, err = db.MigrateDown("0004")
require.NoError(t, err)

query, _, err := goqu.Select("*").From(migrationsTable).ToSQL()
require.NoError(t, err)

var migrations []migration
rows, err := db.db.Queryx(query)
require.NoError(t, err)

defer rows.Close()
for rows.Next() {
var mig migration
err := rows.StructScan(&mig)
require.NoError(t, err)
migrations = append(migrations, mig)
}

assert.Equal(t, 4, len(migrations))
}
16 changes: 8 additions & 8 deletions backend/pkg/api/bindata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e4f2e06

Please sign in to comment.