diff --git a/tapdb/migrations.go b/tapdb/migrations.go index a19ecbea2..3f6c96642 100644 --- a/tapdb/migrations.go +++ b/tapdb/migrations.go @@ -7,11 +7,13 @@ import ( "io" "io/fs" "net/http" + "os" "strings" "github.com/btcsuite/btclog" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source/httpfs" "github.com/lightninglabs/taproot-assets/fn" ) @@ -26,20 +28,29 @@ const ( ) // MigrationTarget is a functional option that can be passed to applyMigrations -// to specify a target version to migrate to. -type MigrationTarget func(mig *migrate.Migrate) error +// to specify a target version to migrate to. `currentDbVersion` is the current +// (migration) version of the database, or None if unknown. +// `maxMigrationVersion` is the maximum migration version known to the driver, +// or None if unknown. +type MigrationTarget func(mig *migrate.Migrate, + currentDbVersion fn.Option[uint], + maxMigrationVersion fn.Option[uint]) error var ( // TargetLatest is a MigrationTarget that migrates to the latest // version available. - TargetLatest = func(mig *migrate.Migrate) error { + TargetLatest = func(mig *migrate.Migrate, _ fn.Option[uint], + _ fn.Option[uint]) error { + return mig.Up() } // TargetVersion is a MigrationTarget that migrates to the given // version. TargetVersion = func(version uint) MigrationTarget { - return func(mig *migrate.Migrate) error { + return func(mig *migrate.Migrate, _ fn.Option[uint], + _ fn.Option[uint]) error { + return mig.Migrate(version) } } @@ -109,6 +120,31 @@ func (m *migrationLogger) Verbose() bool { return m.log.Level() <= btclog.LevelDebug } +// maxKnownMigrationVersion returns the maximum migration version known to the +// given source driver. +func maxKnownMigrationVersion(sourceDriver source.Driver) (fn.Option[uint], + error) { + + var ( + // We start at version 1, if we start with 0 the next version + // is not returned correctly. This project is already beyond + // version 0, so this should be fine. + currentVersion uint = 1 + ) + for { + nextVersion, err := sourceDriver.Next(currentVersion) + switch { + case errors.Is(err, os.ErrNotExist): + return fn.Some(currentVersion), nil + case err != nil: + return fn.None[uint](), err + } + + // Set the current version to the next version. + currentVersion = nextVersion + } +} + // applyMigrations executes database migration files found in the given file // system under the given path, using the passed database driver and database // name, up to or down to the given target version. @@ -146,17 +182,43 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string, ErrMigrationDowngrade, migrationVersion, latestVersion) } - log.Infof("Applying migrations from version=%v", migrationVersion) + // Compute the max migration version known to the driver. + maxMigrationVersion, err := maxKnownMigrationVersion(migrateFileServer) + if err != nil { + return fmt.Errorf("failed to determine max migration "+ + "version: %v", err) + } + + // Report the current version of the database before the migration. + currentDbVersion, _, err := driver.Version() + if err != nil { + return fmt.Errorf("unable to get current db version: %w", err) + } + log.Infof("Attempting to apply migration(s) "+ + "(current_db_version=%v, max_migration_version=%v)", + currentDbVersion, maxMigrationVersion) + + var currentDbV fn.Option[uint] + if currentDbVersion >= 0 { + currentDbV = fn.Some(uint(currentDbVersion)) + } // Apply our local logger to the migration instance. sqlMigrate.Log = &migrationLogger{log} // Execute the migration based on the target given. - err = targetVersion(sqlMigrate) + err = targetVersion(sqlMigrate, currentDbV, maxMigrationVersion) if err != nil && !errors.Is(err, migrate.ErrNoChange) { return err } + // Report the current version of the database after the migration. + currentDbVersion, _, err = driver.Version() + if err != nil { + return fmt.Errorf("unable to get current db version: %w", err) + } + log.Infof("Database version after migration: %v", currentDbVersion) + return nil }