diff --git a/Makefile b/Makefile index 65ccfbfc43..ec2bc419fd 100644 --- a/Makefile +++ b/Makefile @@ -78,7 +78,7 @@ perms-table: gen: cleangen proto api migrations fmt migrations: - $(MAKE) --environment-overrides -C internal/db/migrations/genmigrations migrations + $(MAKE) --environment-overrides -C internal/db/schema/migrations/generate migrations ### oplog requires protoc-gen-go v1.20.0 or later # GO111MODULE=on go get -u github.com/golang/protobuf/protoc-gen-go@v1.40 diff --git a/go.mod b/go.mod index c436e456ef..4ab6fbc445 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ replace github.com/hashicorp/boundary/sdk => ./sdk require ( github.com/armon/go-metrics v0.3.5 github.com/bufbuild/buf v0.33.0 + github.com/dhui/dktest v0.3.3 github.com/fatih/color v1.10.0 github.com/favadi/protoc-go-inject-tag v1.1.0 github.com/go-bindata/go-bindata/v3 v3.1.3 diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 464c2dff10..d2f62d6616 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/config" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" @@ -442,7 +443,7 @@ func (b *Server) ConnectToDatabase(dialect string) error { return nil } -func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { +func (b *Server) CreateDevDatabase(ctx context.Context, dialect string, opt ...Option) error { opts := getOpts(opt...) var container, url string @@ -470,17 +471,25 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err) } - _, err := db.InitStore(dialect, c, url) + _, err := schema.InitStore(ctx, dialect, url) if err != nil { - return fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err) + err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err) + if c != nil { + err = multierror.Append(err, c()) + } + return err } b.DevDatabaseCleanupFunc = c b.DatabaseUrl = url default: - if _, err := db.InitStore(dialect, c, b.DatabaseUrl); err != nil { - return fmt.Errorf("error initializing store: %w", err) + if _, err := schema.InitStore(ctx, dialect, b.DatabaseUrl); err != nil { + err = fmt.Errorf("error initializing store: %w", err) + if c != nil { + err = multierror.Append(err, c()) + } + return err } } @@ -492,16 +501,25 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { } if err := b.ConnectToDatabase(dialect); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } return err } b.Database.LogMode(true) - if err := b.CreateGlobalKmsKeys(context.Background()); err != nil { + if err := b.CreateGlobalKmsKeys(ctx); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } return err } - if _, err := b.CreateInitialLoginRole(context.Background()); err != nil { + if _, err := b.CreateInitialLoginRole(ctx); err != nil { + if c != nil { + err = multierror.Append(err, c()) + } return err } @@ -512,7 +530,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { return nil } - if _, _, err := b.CreateInitialAuthMethod(context.Background()); err != nil { + if _, _, err := b.CreateInitialAuthMethod(ctx); err != nil { return err } @@ 
-523,7 +541,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { return nil } - if _, _, err := b.CreateInitialScopes(context.Background()); err != nil { + if _, _, err := b.CreateInitialScopes(ctx); err != nil { return err } @@ -545,7 +563,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error { return nil } - if _, err := b.CreateInitialTarget(context.Background()); err != nil { + if _, err := b.CreateInitialTarget(ctx); err != nil { return err } diff --git a/internal/cmd/commands.go b/internal/cmd/commands.go index 6f9f959586..320be4b89c 100644 --- a/internal/cmd/commands.go +++ b/internal/cmd/commands.go @@ -36,20 +36,14 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) { Commands = map[string]cli.CommandFactory{ "server": func() (cli.Command, error) { return &server.Command{ - Server: base.NewServer(&base.Command{ - UI: serverCmdUi, - ShutdownCh: base.MakeShutdownCh(), - }), + Server: base.NewServer(base.NewCommand(serverCmdUi)), SighupCh: MakeSighupCh(), SigUSR2Ch: MakeSigUSR2Ch(), }, nil }, "dev": func() (cli.Command, error) { return &dev.Command{ - Server: base.NewServer(&base.Command{ - UI: serverCmdUi, - ShutdownCh: base.MakeShutdownCh(), - }), + Server: base.NewServer(base.NewCommand(serverCmdUi)), SighupCh: MakeSighupCh(), SigUSR2Ch: MakeSigUSR2Ch(), }, nil diff --git a/internal/cmd/commands/database/init.go b/internal/cmd/commands/database/init.go index c275d319a9..92ce10a3c7 100644 --- a/internal/cmd/commands/database/init.go +++ b/internal/cmd/commands/database/init.go @@ -7,9 +7,7 @@ import ( "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" - "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/db/migrations" - "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/types/scope" "github.com/hashicorp/boundary/sdk/wrapper" wrapping "github.com/hashicorp/go-kms-wrapping" @@ -186,8 +184,10 @@ func (c *InitCommand) Run(args []string) (retCode int) { }() } - if migrations.DevMigration != c.flagAllowDevMigrations { - if migrations.DevMigration { + dialect := "postgres" + + if schema.DevMigration(dialect) != c.flagAllowDevMigrations { + if schema.DevMigration(dialect) { c.UI.Error(base.WrapAtLength("This version of the binary has " + "dev database schema updates which may not be supported in the " + "next official release. 
To proceed anyway please use the " + @@ -263,34 +263,63 @@ func (c *InitCommand) Run(args []string) (retCode int) { return 1 } - migrationUrl, err := config.ParseAddress(migrationUrlToParse) - if err != nil && err != config.ErrNotAUrl { - c.UI.Error(fmt.Errorf("Error parsing migration url: %w", err).Error()) + // This connection is used to hold an exclusive lock on the database for the + // remainder of the command + dBase, err := sql.Open(dialect, dbaseUrl) + if err != nil { + c.UI.Error(fmt.Errorf("Error establishing db connection for locking: %w", err).Error()) + return 1 + } + man, err := schema.NewManager(c.Context, dialect, dBase) + if err != nil { + c.UI.Error(fmt.Errorf("Error setting up schema manager for locking: %w", err).Error()) return 1 } - - // Core migrations using the migration URL { - c.srv.DatabaseUrl = strings.TrimSpace(migrationUrl) - ldb, err := sql.Open("postgres", c.srv.DatabaseUrl) + st, err := man.CurrentState(c.Context) if err != nil { - c.UI.Error(fmt.Errorf("Error opening database to check init status: %w", err).Error()) + c.UI.Error(fmt.Errorf("Error getting database state: %w", err).Error()) return 1 } - _, err = ldb.QueryContext(c.Context, "select version from schema_migrations") - switch { - case err == nil: - if base.Format(c.UI) == "table" { - c.UI.Info("Database already initialized.") - return 0 - } - case errors.IsMissingTableError(err): - // Doesn't exist so we continue on - default: - c.UI.Error(fmt.Errorf("Error querying database for init status: %w", err).Error()) + if st.Dirty { + c.UI.Error(base.WrapAtLength("Database is in a bad initialization " + "state. Please revert to the last known good state.")) + return 1 + } + if st.InitializationStarted { + // TODO: Separate from the "dirty" bit maintained by the schema + // manager, maintain a bit which indicates that this full command + // was completed successfully (with all default resources being created). + // Use that bit to determine if a previous init was completed + // successfully or not. + c.UI.Error(base.WrapAtLength("Database has already been " + "initialized. If the initialization did not complete successfully " + "please revert the database to its fresh state.")) return 1 } - ran, err := db.InitStore("postgres", nil, c.srv.DatabaseUrl) + } + + // This is an advisory lock on the DB which is released when the db session ends.
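+ // The exclusive flavor of the lock ensures no other boundary binary (for example a running server, which holds the shared flavor of the same advisory lock) can operate on the schema while init is in progress.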
+ if err := man.ExclusiveLock(c.Context); err != nil { + c.UI.Error(fmt.Errorf("Error capturing an exclusive lock: %w", err).Error()) + return 1 + } + defer func() { + if err := man.ExclusiveUnlock(c.Context); err != nil { + c.UI.Error(fmt.Errorf("Unable to release exclusive lock on the database: %w", err).Error()) + } + }() + + migrationUrl, err := config.ParseAddress(migrationUrlToParse) + if err != nil && err != config.ErrNotAUrl { + c.UI.Error(fmt.Errorf("Error parsing migration url: %w", err).Error()) + return 1 + } + + // Core migrations using the migration URL + { + migrationUrl = strings.TrimSpace(migrationUrl) + ran, err := schema.InitStore(c.Context, dialect, migrationUrl) if err != nil { c.UI.Error(fmt.Errorf("Error running database migrations: %w", err).Error()) return 1 @@ -308,7 +337,7 @@ func (c *InitCommand) Run(args []string) (retCode int) { // Everything after is done with normal database URL and is affecting actual data c.srv.DatabaseUrl = strings.TrimSpace(dbaseUrl) - if err := c.srv.ConnectToDatabase("postgres"); err != nil { + if err := c.srv.ConnectToDatabase(dialect); err != nil { c.UI.Error(fmt.Errorf("Error connecting to database after migrations: %w", err).Error()) return 1 } diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index d91f170ccb..f504133169 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -399,7 +399,7 @@ func (c *Command) Run(args []string) int { if c.flagDisableDatabaseDestruction { opts = append(opts, base.WithSkipDatabaseDestruction()) } - if err := c.CreateDevDatabase("postgres", opts...); err != nil { + if err := c.CreateDevDatabase(c.Context, "postgres", opts...); err != nil { if err == docker.ErrDockerUnsupported { c.UI.Error("Automatically starting a Docker container running Postgres is not currently supported on this platform. Please use -database-url to pass in a URL (or an env var or file reference to a URL) for connecting to an existing empty database.") return 1 @@ -417,7 +417,7 @@ func (c *Command) Run(args []string) int { return 1 } c.DatabaseUrl = strings.TrimSpace(dbaseUrl) - if err := c.CreateDevDatabase("postgres"); err != nil { + if err := c.CreateDevDatabase(c.Context, "postgres"); err != nil { c.UI.Error(fmt.Errorf("Error connecting to database: %w", err).Error()) return 1 } diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index 2f53b4cf01..5d064fb701 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/servers/worker" "github.com/hashicorp/boundary/sdk/wrapper" @@ -341,6 +342,46 @@ func (c *Command) Run(args []string) int { c.UI.Error(fmt.Errorf("Error connecting to database: %w", err).Error()) return 1 } + + sMan, err := schema.NewManager(c.Context, "postgres", c.Database.DB()) + if err != nil { + c.UI.Error(fmt.Errorf("Can't get schema manager: %w", err).Error()) + return 1 + } + // This is an advisory lock on the DB which is released when the db session ends.
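+ // The shared flavor of the lock can be held by many servers concurrently; it only conflicts with the exclusive flavor taken by `boundary database init`, so running servers never block one another here.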
+ if err := sMan.SharedLock(c.Context); err != nil { + c.UI.Error(fmt.Errorf("Unable to gain shared access to the database: %w", err).Error()) + return 1 + } + defer func() { + if err := sMan.SharedUnlock(c.Context); err != nil { + c.UI.Error(fmt.Errorf("Unable to release shared lock on the database: %w", err).Error()) + } + }() + ckState, err := sMan.CurrentState(c.Context) + if err != nil { + c.UI.Error(fmt.Errorf("Error checking schema state: %w", err).Error()) + return 1 + } + if !ckState.InitializationStarted { + c.UI.Error("Database has not been initialized. Please run `boundary database init`.") + return 1 + } + if ckState.Dirty { + c.UI.Error(base.WrapAtLength("Database is in a bad state. Please revert the database to the last known good state.")) + return 1 + } + if ckState.BinarySchemaVersion > ckState.DatabaseSchemaVersion { + // TODO: Add the command to migrate up the schema version once that command exists. + c.UI.Error("Database schema version is older than this binary expects.") + return 1 + } + if ckState.BinarySchemaVersion < ckState.DatabaseSchemaVersion { + c.UI.Error(base.WrapAtLength(fmt.Sprintf("Newer schema version (%d) "+ "than this binary expects. Please use a newer version of the boundary "+ "binary.", ckState.DatabaseSchemaVersion))) + return 1 + } } defer func() { diff --git a/internal/db/db.go b/internal/db/db.go index 5537aa303e..64c4fff452 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,15 +1,10 @@ package db import ( - "errors" "fmt" - "os" - "github.com/golang-migrate/migrate/v4" - "github.com/hashicorp/boundary/internal/db/migrations" "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/go-hclog" - "github.com/hashicorp/go-multierror" "github.com/jinzhu/gorm" "github.com/lib/pq" ) @@ -40,66 +35,6 @@ func Open(dbType DbType, connectionUrl string) (*gorm.DB, error) { return db, nil } -// Migrate a database schema -func Migrate(connectionUrl string, migrationsDirectory string) error { - if connectionUrl == "" { - return errors.New("connection url is unset") - } - if _, err := os.Stat(migrationsDirectory); os.IsNotExist(err) { - return errors.New("error migrations directory does not exist") - } - // run migrations - m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDirectory), connectionUrl) - if err != nil { - return fmt.Errorf("unable to create migrations: %w", err) - } - if err := m.Up(); err != nil && err != migrate.ErrNoChange { - return fmt.Errorf("unable to run migrations: %w", err) - } - return nil -} - -// InitStore will execute the migrations needed to initialize the store. It -// returns true if migrations actually ran; false if we were already current.
-func InitStore(dialect string, cleanup func() error, url string) (bool, error) { - var mErr *multierror.Error - // run migrations - source, err := migrations.NewMigrationSource(dialect) - if err != nil { - mErr = multierror.Append(mErr, fmt.Errorf("error creating migration driver: %w", err)) - if cleanup != nil { - if err := cleanup(); err != nil { - mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from creating driver: %w", err)) - } - } - return false, mErr.ErrorOrNil() - } - m, err := migrate.NewWithSourceInstance("httpfs", source, url) - if err != nil { - mErr = multierror.Append(mErr, fmt.Errorf("error creating migrations: %w", err)) - if cleanup != nil { - if err := cleanup(); err != nil { - mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from creating migrations: %w", err)) - } - } - return false, mErr.ErrorOrNil() - - } - if err := m.Up(); err != nil { - if err == migrate.ErrNoChange { - return false, nil - } - mErr = multierror.Append(mErr, fmt.Errorf("error running migrations: %w", err)) - if cleanup != nil { - if err := cleanup(); err != nil { - mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from running migrations: %w", err)) - } - } - return false, mErr.ErrorOrNil() - } - return true, mErr.ErrorOrNil() -} - func GetGormLogFormatter(log hclog.Logger) func(values ...interface{}) (messages []interface{}) { return func(values ...interface{}) (messages []interface{}) { if len(values) > 2 && values[0].(string) == "log" { diff --git a/internal/db/db_test.go b/internal/db/db_test.go index fd3bfaae87..f3f762d45a 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -58,56 +58,3 @@ func TestOpen(t *testing.T) { }) } } - -func TestMigrate(t *testing.T) { - cleanup, url, _, err := StartDbInDocker("postgres") - if err != nil { - t.Fatal(err) - } - defer func() { - if err := cleanup(); err != nil { - t.Error(err) - } - }() - type args struct { - connectionUrl string - migrationsDirectory string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "valid", - args: args{ - connectionUrl: url, - migrationsDirectory: "migrations/postgres/0", - }, - wantErr: false, - }, - { - name: "bad-url", - args: args{ - connectionUrl: "", - migrationsDirectory: "migrations/postgres/0", - }, - wantErr: true, - }, - { - name: "bad-dir", - args: args{ - connectionUrl: url, - migrationsDirectory: "", - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := Migrate(tt.args.connectionUrl, tt.args.migrationsDirectory); (err != nil) != tt.wantErr { - t.Errorf("Migrate() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/internal/db/migrations/driver.go b/internal/db/migrations/driver.go deleted file mode 100644 index 4d842cf13e..0000000000 --- a/internal/db/migrations/driver.go +++ /dev/null @@ -1,128 +0,0 @@ -package migrations - -import ( - "bytes" - "fmt" - "net/http" - "os" - "sort" - "strings" - "time" - - "github.com/golang-migrate/migrate/v4/source" - "github.com/golang-migrate/migrate/v4/source/httpfs" -) - -// migrationDriver satisfies the remaining need of the Driver interface, since -// the package uses PartialDriver under the hood -type migrationDriver struct { - dialect string -} - -// Open returns the given "file" -func (m *migrationDriver) Open(name string) (http.File, error) { - return newFakeFile(m.dialect, name) -} - -// NewMigrationSource creates a source.Driver using httpfs with the given dialect -func NewMigrationSource(dialect string) 
(source.Driver, error) { - switch dialect { - case "postgres": - default: - return nil, fmt.Errorf("unknown migrations dialect %s", dialect) - } - return httpfs.New(&migrationDriver{dialect}, "migrations") -} - -// fakeFile is used to satisfy the http.File interface -type fakeFile struct { - name string - bytes []byte - reader *bytes.Reader - dialect string -} - -func newFakeFile(dialect string, name string) (*fakeFile, error) { - var ff *fakeFile - switch dialect { - case "postgres": - ff = postgresMigrations[name] - } - if ff == nil { - return nil, os.ErrNotExist - } - ff.name = strings.TrimPrefix(name, "migrations/") - ff.reader = bytes.NewReader(ff.bytes) - ff.dialect = dialect - return ff, nil -} - -func (f *fakeFile) Read(p []byte) (n int, err error) { - return f.reader.Read(p) -} - -func (f *fakeFile) Seek(offset int64, whence int) (int64, error) { - return f.reader.Seek(offset, whence) -} - -func (f *fakeFile) Close() error { return nil } - -// Readdir returns os.FileInfo values, in sorted order, and eliding the -// migrations "dir" -func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { - // Get the right map - var migrationsMap map[string]*fakeFile - switch f.dialect { - case "postgres": - migrationsMap = postgresMigrations - default: - return nil, fmt.Errorf("unknown database dialect %s", f.dialect) - } - - // Sort the keys. May not be necessary but feels nice. - keys := make([]string, 0, len(migrationsMap)) - for k := range migrationsMap { - keys = append(keys, k) - } - sort.Strings(keys) - - // Create the slice of fileinfo objects to return - ret := make([]os.FileInfo, 0, len(migrationsMap)) - for i, v := range keys { - // We need "migrations" in the map for the initial Open call but we - // should not return it as part of the "directory"'s "files". 
- if v == "migrations" { - continue - } - stat, err := migrationsMap[v].Stat() - if err != nil { - return nil, err - } - ret = append(ret, stat) - if count > 0 && count == i { - break - } - } - return ret, nil -} - -// Stat returns a new fakeFileInfo object with the necessary bits -func (f *fakeFile) Stat() (os.FileInfo, error) { - return &fakeFileInfo{ - name: f.name, - size: int64(len(f.bytes)), - }, nil -} - -// fakeFileInfo satisfies os.FileInfo but represents our fake "files" -type fakeFileInfo struct { - name string - size int64 -} - -func (f *fakeFileInfo) Name() string { return f.name } -func (f *fakeFileInfo) Size() int64 { return f.size } -func (f *fakeFileInfo) Mode() os.FileMode { return os.ModePerm } -func (f *fakeFileInfo) ModTime() time.Time { return time.Now() } -func (f *fakeFileInfo) IsDir() bool { return false } -func (f *fakeFileInfo) Sys() interface{} { return nil } diff --git a/internal/db/migrations/driver_test.go b/internal/db/migrations/driver_test.go deleted file mode 100644 index 75d993fc54..0000000000 --- a/internal/db/migrations/driver_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package migrations - -import ( - "os" - "reflect" - "testing" - - "github.com/golang-migrate/migrate/v4/source" - "github.com/golang-migrate/migrate/v4/source/httpfs" - "github.com/stretchr/testify/assert" -) - -func TestNewMigrationSource(t *testing.T) { - type args struct { - dialect string - } - tests := []struct { - name string - args args - want source.Driver - wantErr bool - }{ - { - name: "postgres", - args: args{dialect: "postgres"}, - want: func() source.Driver { - d, err := httpfs.New(&migrationDriver{"postgres"}, "migrations") - if err != nil { - t.Errorf("NewMigrationSource() error creating httpfs = %w", err) - } - return d - }(), - wantErr: false, - }, - { - name: "no-dialect", - args: args{dialect: ""}, - want: nil, - wantErr: true, - }, - { - name: "bad-dialect", - args: args{dialect: "rainbows-and-unicorns-db"}, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NewMigrationSource(tt.args.dialect) - if (err != nil) != tt.wantErr { - t.Errorf("NewMigrationSource() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewMigrationSource() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_migrationDriver_Open(t *testing.T) { - type args struct { - name string - } - tests := []struct { - name string - dialect string - args args - wantErr bool - }{ - { - name: "valid-file", - dialect: "postgres", - args: args{name: "migrations/01_domain_types.up.sql"}, - wantErr: false, - }, - { - name: "bad-file", - dialect: "postgres", - args: args{name: "migrations/unicorns-and-rainbows.up.sql"}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &migrationDriver{ - dialect: tt.dialect, - } - _, err := m.Open(tt.args.name) - if (err != nil) != tt.wantErr { - t.Errorf("migrationDriver.Open() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} - -func Test_fakeFile_Read(t *testing.T) { - assert := assert.New(t) - t.Run("valid", func(t *testing.T) { - ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql") - assert.NoError(err) - buf := make([]byte, len(ff.bytes)) - n, err := ff.Read(buf) - assert.NoError(err) - assert.Equal(len(buf), n) - }) -} - -func Test_fakeFile_Seek(t *testing.T) { - assert := assert.New(t) - t.Run("valid", func(t *testing.T) { - ff, err := newFakeFile("postgres", 
"migrations/01_domain_types.up.sql") - assert.NoError(err) - buf := make([]byte, len(ff.bytes)) - n, err := ff.Seek(10, 0) - assert.NoError(err) - assert.Equal(int64(10), n) - - n2, err := ff.Read(buf) - assert.NoError(err) - assert.Equal(len(ff.bytes)-10, n2) - }) -} - -func Test_fakeFile_Close(t *testing.T) { - assert := assert.New(t) - t.Run("valid", func(t *testing.T) { - m := &migrationDriver{ - dialect: "postgres", - } - f, err := m.Open("migrations/01_domain_types.up.sql") - assert.NoError(err) - err = f.Close() - assert.NoError(err) - }) -} - -func Test_fakeFile_Stat(t *testing.T) { - assert := assert.New(t) - t.Run("valid", func(t *testing.T) { - name := "migrations/01_domain_types.up.sql" - ff, err := newFakeFile("postgres", name) - assert.NoError(err) - info, err := ff.Stat() - assert.NoError(err) - assert.Equal(ff.name, info.Name()) - assert.Equal(int64(len(ff.bytes)), info.Size()) - assert.Equal(os.ModePerm, info.Mode()) - assert.Equal(false, info.IsDir()) - assert.Equal(nil, info.Sys()) - }) -} - -func Test_fakeFile_Readdir(t *testing.T) { - assert := assert.New(t) - t.Run("valid", func(t *testing.T) { - name := "migrations/01_domain_types.up.sql" - ff, err := newFakeFile("postgres", name) - assert.NoError(err) - info, err := ff.Readdir(0) - assert.NoError(err) - assert.NotNil(info) - - info, err = ff.Readdir(1) - assert.NoError(err) - assert.NotNil(info) - assert.Equal(1, len(info)) - - info, err = ff.Readdir(0) - assert.NoError(err) - assert.NotNil(info) - // we don't want to count "migrations", so we're len - 1 - assert.Equal(len(postgresMigrations)-1, len(info)) - }) -} diff --git a/internal/db/migrations/genmigrations/generate.go b/internal/db/migrations/genmigrations/generate.go deleted file mode 100644 index 6ddfe5912f..0000000000 --- a/internal/db/migrations/genmigrations/generate.go +++ /dev/null @@ -1,149 +0,0 @@ -package main - -import ( - "bytes" - "fmt" - "io/ioutil" - "os" - "sort" - "strconv" - "strings" - "text/template" -) - -// generate looks for migration sql in a directory for the given dialect and -// applies the templates below to the contents of the files, building up a -// migrations map for the dialect -func generate(dialect string) { - baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/migrations" - dir, err := os.Open(fmt.Sprintf("%s/%s", baseDir, dialect)) - if err != nil { - fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err) - os.Exit(1) - } - versions, err := dir.Readdirnames(0) - if err != nil { - fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err) - os.Exit(1) - } - outBuf := bytes.NewBuffer(nil) - valuesBuf := bytes.NewBuffer(nil) - - sort.Strings(versions) - - isDev := false - largestVer := 0 - for _, ver := range versions { - var verVal int - switch ver { - case "dev": - verVal = largestVer + 1 - default: - if verVal, err = strconv.Atoi(ver); err != nil { - fmt.Printf("error reading major schema version directory %q. 
Must be a number or 'dev'\n", ver) - os.Exit(1) - } - if verVal > largestVer { - largestVer = verVal - } - } - - dir, err := os.Open(fmt.Sprintf("%s/%s/%s", baseDir, dialect, ver)) - if err != nil { - fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err) - os.Exit(1) - } - names, err := dir.Readdirnames(0) - if err != nil { - fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err) - os.Exit(1) - } - - if ver == "dev" && len(names) > 0 { - isDev = true - } - - sort.Strings(names) - for _, name := range names { - if !strings.HasSuffix(name, ".sql") { - continue - } - - contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", baseDir, dialect, ver, name)) - if err != nil { - fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err) - os.Exit(1) - } - - vName := name - nameParts := strings.SplitN(name, "_", 2) - if len(nameParts) != 2 { - continue - } - - nameVer, err := strconv.Atoi(nameParts[0]) - if err != nil { - fmt.Printf("Unable to get file version from %q\n", name) - continue - } - vName = fmt.Sprintf("%02d_%s", (verVal*1000)+nameVer, nameParts[1]) - - if err := migrationsValueTemplate.Execute(valuesBuf, struct { - Name string - Contents string - }{ - Name: vName, - Contents: string(contents), - }); err != nil { - fmt.Printf("error executing migrations value template for file %s/%s: %s", ver, name, err) - os.Exit(1) - } - } - } - if err := migrationsTemplate.Execute(outBuf, struct { - Type string - Values string - DevMigration bool - }{ - Type: dialect, - Values: valuesBuf.String(), - DevMigration: isDev, - }); err != nil { - fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err) - os.Exit(1) - } - - outFile := fmt.Sprintf("%s/%s.gen.go", baseDir, dialect) - if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil { - fmt.Printf("error writing file %q: %v\n", outFile, err) - os.Exit(1) - } -} - -var migrationsTemplate = template.Must(template.New("").Parse( - `// Code generated by "make migrations"; DO NOT EDIT. -package migrations - -import ( - "bytes" -) - -// DevMigration is true if the database schema that would be applied by -// InitStore would be from files in the /dev directory which indicates it would -// not be safe to run in a non dev environment. -var DevMigration = {{ .DevMigration }} - -var {{ .Type }}Migrations = map[string]*fakeFile{ - "migrations": { - name: "migrations", - }, - {{ .Values }} -} -`)) - -var migrationsValueTemplate = template.Must(template.New("").Parse( - `"migrations/{{ .Name }}": { - name: "{{ .Name }}", - bytes: []byte(` + "`\n{{ .Contents }}\n`" + `), - }, -`)) diff --git a/internal/db/schema/manager.go b/internal/db/schema/manager.go new file mode 100644 index 0000000000..a00977dba9 --- /dev/null +++ b/internal/db/schema/manager.go @@ -0,0 +1,187 @@ +package schema + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + + "github.com/hashicorp/boundary/internal/db/schema/postgres" + "github.com/hashicorp/boundary/internal/errors" +) + +// driver provides functionality to a database. +type driver interface { + TrySharedLock(context.Context) error + TryLock(context.Context) error + Lock(context.Context) error + Unlock(context.Context) error + UnlockShared(context.Context) error + Run(context.Context, io.Reader) error + // A value of -1 indicates no version is set. + SetVersion(context.Context, int, bool) error + // A value of -1 indicates no version is set. 
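+ // Version reports the schema version currently recorded in the database and whether the dirty bit is set.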
+ Version(context.Context) (int, bool, error) +} + +// Manager provides a way to run operations and retrieve information regarding +// the underlying boundary database schema. +// Manager is not thread safe. +type Manager struct { + db *sql.DB + driver driver + dialect string +} + +// NewManager creates a new schema manager. An error is returned +// if the provided dialect is unrecognized or if the passed-in db is unreachable. +func NewManager(ctx context.Context, dialect string, db *sql.DB) (*Manager, error) { + const op = "schema.NewManager" + dbM := Manager{db: db, dialect: dialect} + switch dialect { + case "postgres": + var err error + dbM.driver, err = postgres.New(ctx, db) + if err != nil { + return nil, errors.Wrap(err, op) + } + default: + return nil, errors.New(errors.InvalidParameter, op, fmt.Sprintf("unknown dialect %q", dialect)) + } + return &dbM, nil +} + +// State contains information regarding the current state of a boundary database's schema. +type State struct { + // InitializationStarted indicates if the current database has already been initialized + // (successfully or not) at least once. + InitializationStarted bool + // Dirty is set to true if the database failed in a previous migration/initialization. + Dirty bool + // DatabaseSchemaVersion is the schema version that is currently running in the database. + DatabaseSchemaVersion int + // BinarySchemaVersion is the schema version which this boundary binary supports. + BinarySchemaVersion int +} + +// CurrentState provides the state of the boundary schema contained in the backing database. +func (b *Manager) CurrentState(ctx context.Context) (*State, error) { + dbS := State{ + BinarySchemaVersion: BinarySchemaVersion(b.dialect), + } + v, dirty, err := b.driver.Version(ctx) + if err != nil { + return nil, err + } + if v == nilVersion { + return &dbS, nil + } + dbS.InitializationStarted = true + dbS.DatabaseSchemaVersion = v + dbS.Dirty = dirty + return &dbS, nil +} + +// SharedLock attempts to obtain a shared lock on the database. This can fail if +// an exclusive lock is already held with the same key. An error is returned if +// the lock could not be obtained. +func (b *Manager) SharedLock(ctx context.Context) error { + const op = "schema.(Manager).SharedLock" + if err := b.driver.TrySharedLock(ctx); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// SharedUnlock releases a shared lock on the database. If this +// fails for whatever reason an error is returned. Unlocking a lock +// that is not held is not an error. +func (b *Manager) SharedUnlock(ctx context.Context) error { + const op = "schema.(Manager).SharedUnlock" + if err := b.driver.UnlockShared(ctx); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// ExclusiveLock attempts to obtain an exclusive lock on the database. +// An error is returned if the lock could not be obtained. +func (b *Manager) ExclusiveLock(ctx context.Context) error { + const op = "schema.(Manager).ExclusiveLock" + if err := b.driver.TryLock(ctx); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// ExclusiveUnlock releases an exclusive lock on the database. If this +// fails for whatever reason an error is returned. Unlocking a lock +// that is not held is not an error.
+func (b *Manager) ExclusiveUnlock(ctx context.Context) error { + const op = "schema.(Manager).ExclusiveUnlock" + if err := b.driver.Unlock(ctx); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// RollForward updates the database schema to match the latest version known by +// the boundary binary. An error is not returned if the database is already at +// the most recent version. +func (b *Manager) RollForward(ctx context.Context) error { + const op = "schema.(Manager).RollForward" + + // Capturing a lock that this db session already possesses is okay. + if err := b.driver.Lock(ctx); err != nil { + return errors.Wrap(err, op) + } + defer func() { + b.driver.Unlock(ctx) + }() + + curVersion, dirty, err := b.driver.Version(ctx) + if err != nil { + return errors.Wrap(err, op) + } + + if dirty { + return errors.New(errors.NotSpecificIntegrity, op, fmt.Sprintf("schema is dirty with version %d", curVersion)) + } + + sp, err := newStatementProvider(b.dialect, curVersion) + if err != nil { + return errors.Wrap(err, op) + } + return b.runMigrations(ctx, sp) +} + +// runMigrations passes migration queries to a database driver and manages +// the version and dirty bit. Cancellation or deadline/timeout is managed +// through the passed-in context. +func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) error { + const op = "schema.(Manager).runMigrations" + for qp.Next() { + select { + case <-ctx.Done(): + return errors.Wrap(ctx.Err(), op) + default: + // context is not done yet. Continue on to the next query to execute. + } + + // set version with dirty state + if err := b.driver.SetVersion(ctx, qp.Version(), true); err != nil { + return errors.Wrap(err, op) + } + + if err := b.driver.Run(ctx, bytes.NewReader(qp.ReadUp())); err != nil { + return errors.Wrap(err, op) + } + + // set clean state + if err := b.driver.SetVersion(ctx, qp.Version(), false); err != nil { + return errors.Wrap(err, op) + } + } + return nil +} diff --git a/internal/db/schema/manager_test.go b/internal/db/schema/manager_test.go new file mode 100644 index 0000000000..ceaa96ce61 --- /dev/null +++ b/internal/db/schema/manager_test.go @@ -0,0 +1,201 @@ +package schema + +import ( + "context" + "database/sql" + "testing" + + "github.com/hashicorp/boundary/internal/db/schema/postgres" + "github.com/hashicorp/boundary/internal/docker" + "github.com/hashicorp/boundary/internal/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewManager(t *testing.T) { + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := sql.Open("postgres", u) + require.NoError(t, err) + + ctx := context.Background() + _, err = NewManager(ctx, "postgres", d) + require.NoError(t, err) + _, err = NewManager(ctx, "unknown", d) + assert.True(t, errors.Match(errors.T(errors.InvalidParameter), err)) + + d.Close() + _, err = NewManager(ctx, "postgres", d) + assert.True(t, errors.Match(errors.T(errors.Op("schema.NewManager")), err)) +} + +func TestCurrentState(t *testing.T) { + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + ctx := context.Background() + d, err := sql.Open("postgres", u) + require.NoError(t, err) + + m, err := NewManager(ctx, "postgres", d) + require.NoError(t, err) + want := &State{ +
BinarySchemaVersion: BinarySchemaVersion("postgres"), + } + s, err := m.CurrentState(ctx) + require.NoError(t, err) + assert.Equal(t, want, s) + + testDriver, err := postgres.New(ctx, d) + require.NoError(t, err) + require.NoError(t, testDriver.SetVersion(ctx, 2, true)) + + want = &State{ + InitializationStarted: true, + BinarySchemaVersion: BinarySchemaVersion("postgres"), + Dirty: true, + DatabaseSchemaVersion: 2, + } + s, err = m.CurrentState(ctx) + require.NoError(t, err) + assert.Equal(t, want, s) +} + +func TestRollForward(t *testing.T) { + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := sql.Open("postgres", u) + require.NoError(t, err) + + ctx := context.Background() + m, err := NewManager(ctx, "postgres", d) + require.NoError(t, err) + assert.NoError(t, m.RollForward(ctx)) + + // Now set to dirty at an early version + testDriver, err := postgres.New(ctx, d) + require.NoError(t, err) + testDriver.SetVersion(ctx, 0, true) + assert.Error(t, m.RollForward(ctx)) +} + +func TestRollForward_NotFromFresh(t *testing.T) { + dialect := "postgres" + oState := migrationStates[dialect] + + nState := createPartialMigrationState(oState, 8) + migrationStates[dialect] = nState + + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := sql.Open(dialect, u) + require.NoError(t, err) + + // Initialize the DB with only a portion of the current sql scripts. + ctx := context.Background() + m, err := NewManager(ctx, dialect, d) + require.NoError(t, err) + assert.NoError(t, m.RollForward(ctx)) + + ver, dirty, err := m.driver.Version(ctx) + assert.NoError(t, err) + assert.Equal(t, nState.binarySchemaVersion, ver) + assert.False(t, dirty) + + // Restore the full set of sql scripts and roll the rest of the way forward. 
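+ // A second manager created against the same db then sees the higher binary version and applies the remaining migrations in place.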
+ migrationStates[dialect] = oState + + newM, err := NewManager(ctx, dialect, d) + require.NoError(t, err) + assert.NoError(t, newM.RollForward(ctx)) + ver, dirty, err = newM.driver.Version(ctx) + assert.NoError(t, err) + assert.Equal(t, oState.binarySchemaVersion, ver) + assert.False(t, dirty) +} + +func TestManager_ExclusiveLock(t *testing.T) { + ctx := context.Background() + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d1, err := sql.Open("postgres", u) + require.NoError(t, err) + m1, err := NewManager(ctx, "postgres", d1) + require.NoError(t, err) + + d2, err := sql.Open("postgres", u) + require.NoError(t, err) + m2, err := NewManager(ctx, "postgres", d2) + require.NoError(t, err) + + assert.NoError(t, m1.ExclusiveLock(ctx)) + assert.NoError(t, m1.ExclusiveLock(ctx)) + assert.Error(t, m2.ExclusiveLock(ctx)) + assert.Error(t, m2.SharedLock(ctx)) +} + +func TestManager_SharedLock(t *testing.T) { + ctx := context.Background() + c, u, _, err := docker.StartDbInDocker("postgres") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d1, err := sql.Open("postgres", u) + require.NoError(t, err) + m1, err := NewManager(ctx, "postgres", d1) + require.NoError(t, err) + + d2, err := sql.Open("postgres", u) + require.NoError(t, err) + m2, err := NewManager(ctx, "postgres", d2) + require.NoError(t, err) + + assert.NoError(t, m1.SharedLock(ctx)) + assert.NoError(t, m2.SharedLock(ctx)) + assert.NoError(t, m1.SharedLock(ctx)) + assert.NoError(t, m2.SharedLock(ctx)) + + assert.Error(t, m1.ExclusiveLock(ctx)) + assert.Error(t, m2.ExclusiveLock(ctx)) +} + +// createPartialMigrationState creates a new migrationState with only the versions <= the provided maxVer +func createPartialMigrationState(om migrationState, maxVer int) migrationState { + nState := migrationState{ + devMigration: om.devMigration, + upMigrations: make(map[int][]byte), + downMigrations: make(map[int][]byte), + } + for k := range om.upMigrations { + if k > maxVer { + // Don't store any versions past our test version. + continue + } + nState.upMigrations[k] = om.upMigrations[k] + nState.downMigrations[k] = om.downMigrations[k] + if nState.binarySchemaVersion < k { + nState.binarySchemaVersion = k + } + } + return nState +} diff --git a/internal/db/schema/migrations/README.md b/internal/db/schema/migrations/README.md new file mode 100644 index 0000000000..4167cdc461 --- /dev/null +++ b/internal/db/schema/migrations/README.md @@ -0,0 +1,27 @@ +# migrations package +This package handles the generation of the database schema in a format that can +be compiled into the boundary binary. + +## Organization + +* `./generate`: contains the makefile, code, and templates needed to generate the schema info. +* `./postgres`: contains the versioned schema folders. The contents of these folders, except + for `dev`, should not be modified. + +## Usage +To regenerate the schema information into the format the boundary binary uses, run +`make migrations` or `make gen` to recreate all generated code. + +The contents of the folders under `./postgres` are compiled into the +boundary binary; when the `boundary database init` or `boundary database migrate` +commands are executed, they are applied in order of their version. + +The `./postgres/dev` directory contains schema files that are under development and +not yet included in a release, so it is the only directory where additions and +modifications are allowed.
When a boundary binary is built while this directory is not +empty, a special flag is required to run the `boundary database init` command, indicating +the user is aware that this is a development release and that running this command can +result in a completely broken schema and data loss. + +When a new release is made, the contents of the `dev` directory are moved into a new +versioned directory. \ No newline at end of file diff --git a/internal/db/migrations/genmigrations/Makefile b/internal/db/schema/migrations/generate/Makefile similarity index 73% rename from internal/db/migrations/genmigrations/Makefile rename to internal/db/schema/migrations/generate/Makefile index dcc0e8a9e3..4e51dea151 100644 --- a/internal/db/migrations/genmigrations/Makefile +++ b/internal/db/schema/migrations/generate/Makefile @@ -3,6 +3,6 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST)) migrations: go run . - goimports -w ${GEN_BASEPATH}/internal/db/migrations + goimports -w ${GEN_BASEPATH}/internal/db/schema .PHONY: migrations diff --git a/internal/db/schema/migrations/generate/generate.go b/internal/db/schema/migrations/generate/generate.go new file mode 100644 index 0000000000..48e352e63f --- /dev/null +++ b/internal/db/schema/migrations/generate/generate.go @@ -0,0 +1,158 @@ +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "sort" + "strconv" + "strings" + "text/template" +) + +// generate looks for migration sql in a directory for the given dialect and +// applies the templates below to the contents of the files, building up a +// migrations map for the dialect +func generate(dialect string) { + baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/schema" + srcDir := baseDir + "/migrations" + dir, err := os.Open(fmt.Sprintf("%s/%s", srcDir, dialect)) + if err != nil { + fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + versions, err := dir.Readdirnames(0) + if err != nil { + fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + sort.Strings(versions) + + type ContentValues struct { + Name string + Content string + } + var upContents []ContentValues + var downContents []ContentValues + + isDev := false + var lRelVer, largestSchemaVersion int + for _, ver := range versions { + var verVal int + switch ver { + case "dev": + verVal = lRelVer + 1 + default: + v, err := strconv.Atoi(ver) + if err != nil { + fmt.Printf("error reading major schema version directory %q. 
Must be a number or 'dev'\n", ver) + os.Exit(1) + } + verVal = v + if verVal > lRelVer { + lRelVer = verVal + } + } + + dir, err := os.Open(fmt.Sprintf("%s/%s/%s", srcDir, dialect, ver)) + if err != nil { + fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + names, err := dir.Readdirnames(0) + if err != nil { + fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + + if ver == "dev" && len(names) > 0 { + isDev = true + } + + sort.Strings(names) + for _, name := range names { + if !strings.HasSuffix(name, ".sql") { + continue + } + + contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", srcDir, dialect, ver, name)) + if err != nil { + fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err) + os.Exit(1) + } + + nameParts := strings.SplitN(name, "_", 2) + if len(nameParts) != 2 { + continue + } + + v, err := strconv.Atoi(nameParts[0]) + if err != nil { + fmt.Printf("Unable to get file version from %q\n", name) + continue + } + + fullV := (verVal * 1000) + v + if fullV > largestSchemaVersion { + largestSchemaVersion = fullV + } + cv := ContentValues{ + Name: fmt.Sprint(fullV), + Content: string(contents), + } + switch { + case strings.Contains(nameParts[1], ".down."): + downContents = append(downContents, cv) + case strings.Contains(nameParts[1], ".up."): + upContents = append(upContents, cv) + } + } + } + + // fmt.Printf("Got upcontent: %#v\n\n downcontents: %#v", upContents[:1], downContents[:1]) + + outBuf := bytes.NewBuffer(nil) + if err := migrationsTemplate.Execute(outBuf, struct { + Type string + UpValues []ContentValues + DownValues []ContentValues + DevMigration bool + BinarySchemaVersion int + }{ + Type: dialect, + UpValues: upContents, + DownValues: downContents, + DevMigration: isDev, + BinarySchemaVersion: largestSchemaVersion, + }); err != nil { + fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err) + os.Exit(1) + } + + outFile := fmt.Sprintf("%s/%s_migration.gen.go", baseDir, dialect) + if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil { + fmt.Printf("error writing file %q: %v\n", outFile, err) + os.Exit(1) + } +} + +var migrationsTemplate = template.Must(template.Must(template.New("Content").Parse( + `{{ .Name }}: []byte(` + "`\n{{ .Content }}\n`" + `), +`)).New("MainPage").Parse(`package schema + +// Code generated by "make migrations"; DO NOT EDIT. + +func init() { + migrationStates["{{ .Type }}"] = migrationState{ + devMigration: {{ .DevMigration }}, + binarySchemaVersion: {{ .BinarySchemaVersion }}, + upMigrations: map[int][]byte{ + {{range .UpValues }}{{ template "Content" . }}{{end}} + }, + downMigrations: map[int][]byte{ + {{range .DownValues }}{{ template "Content" . 
}}{{end}} + }, + } +} +`)) diff --git a/internal/db/migrations/genmigrations/main.go b/internal/db/schema/migrations/generate/main.go similarity index 100% rename from internal/db/migrations/genmigrations/main.go rename to internal/db/schema/migrations/generate/main.go diff --git a/internal/db/migrations/postgres/0/01_domain_types.down.sql b/internal/db/schema/migrations/postgres/0/01_domain_types.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/01_domain_types.down.sql rename to internal/db/schema/migrations/postgres/0/01_domain_types.down.sql diff --git a/internal/db/migrations/postgres/0/01_domain_types.up.sql b/internal/db/schema/migrations/postgres/0/01_domain_types.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/01_domain_types.up.sql rename to internal/db/schema/migrations/postgres/0/01_domain_types.up.sql diff --git a/internal/db/migrations/postgres/0/02_oplog.down.sql b/internal/db/schema/migrations/postgres/0/02_oplog.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/02_oplog.down.sql rename to internal/db/schema/migrations/postgres/0/02_oplog.down.sql diff --git a/internal/db/migrations/postgres/0/02_oplog.up.sql b/internal/db/schema/migrations/postgres/0/02_oplog.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/02_oplog.up.sql rename to internal/db/schema/migrations/postgres/0/02_oplog.up.sql diff --git a/internal/db/migrations/postgres/0/03_db.down.sql b/internal/db/schema/migrations/postgres/0/03_db.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/03_db.down.sql rename to internal/db/schema/migrations/postgres/0/03_db.down.sql diff --git a/internal/db/migrations/postgres/0/03_db.up.sql b/internal/db/schema/migrations/postgres/0/03_db.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/03_db.up.sql rename to internal/db/schema/migrations/postgres/0/03_db.up.sql diff --git a/internal/db/migrations/postgres/0/06_iam.down.sql b/internal/db/schema/migrations/postgres/0/06_iam.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/06_iam.down.sql rename to internal/db/schema/migrations/postgres/0/06_iam.down.sql diff --git a/internal/db/migrations/postgres/0/06_iam.up.sql b/internal/db/schema/migrations/postgres/0/06_iam.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/06_iam.up.sql rename to internal/db/schema/migrations/postgres/0/06_iam.up.sql diff --git a/internal/db/migrations/postgres/0/07_auth.down.sql b/internal/db/schema/migrations/postgres/0/07_auth.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/07_auth.down.sql rename to internal/db/schema/migrations/postgres/0/07_auth.down.sql diff --git a/internal/db/migrations/postgres/0/07_auth.up.sql b/internal/db/schema/migrations/postgres/0/07_auth.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/07_auth.up.sql rename to internal/db/schema/migrations/postgres/0/07_auth.up.sql diff --git a/internal/db/migrations/postgres/0/08_servers.down.sql b/internal/db/schema/migrations/postgres/0/08_servers.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/08_servers.down.sql rename to internal/db/schema/migrations/postgres/0/08_servers.down.sql diff --git a/internal/db/migrations/postgres/0/08_servers.up.sql b/internal/db/schema/migrations/postgres/0/08_servers.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/08_servers.up.sql 
rename to internal/db/schema/migrations/postgres/0/08_servers.up.sql diff --git a/internal/db/migrations/postgres/0/11_auth_token.down.sql b/internal/db/schema/migrations/postgres/0/11_auth_token.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/11_auth_token.down.sql rename to internal/db/schema/migrations/postgres/0/11_auth_token.down.sql diff --git a/internal/db/migrations/postgres/0/11_auth_token.up.sql b/internal/db/schema/migrations/postgres/0/11_auth_token.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/11_auth_token.up.sql rename to internal/db/schema/migrations/postgres/0/11_auth_token.up.sql diff --git a/internal/db/migrations/postgres/0/12_auth_password.down.sql b/internal/db/schema/migrations/postgres/0/12_auth_password.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/12_auth_password.down.sql rename to internal/db/schema/migrations/postgres/0/12_auth_password.down.sql diff --git a/internal/db/migrations/postgres/0/12_auth_password.up.sql b/internal/db/schema/migrations/postgres/0/12_auth_password.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/12_auth_password.up.sql rename to internal/db/schema/migrations/postgres/0/12_auth_password.up.sql diff --git a/internal/db/migrations/postgres/0/13_auth_password_argon.down.sql b/internal/db/schema/migrations/postgres/0/13_auth_password_argon.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/13_auth_password_argon.down.sql rename to internal/db/schema/migrations/postgres/0/13_auth_password_argon.down.sql diff --git a/internal/db/migrations/postgres/0/13_auth_password_argon.up.sql b/internal/db/schema/migrations/postgres/0/13_auth_password_argon.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/13_auth_password_argon.up.sql rename to internal/db/schema/migrations/postgres/0/13_auth_password_argon.up.sql diff --git a/internal/db/migrations/postgres/0/14_auth_password_views.down.sql b/internal/db/schema/migrations/postgres/0/14_auth_password_views.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/14_auth_password_views.down.sql rename to internal/db/schema/migrations/postgres/0/14_auth_password_views.down.sql diff --git a/internal/db/migrations/postgres/0/14_auth_password_views.up.sql b/internal/db/schema/migrations/postgres/0/14_auth_password_views.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/14_auth_password_views.up.sql rename to internal/db/schema/migrations/postgres/0/14_auth_password_views.up.sql diff --git a/internal/db/migrations/postgres/0/20_host.down.sql b/internal/db/schema/migrations/postgres/0/20_host.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/20_host.down.sql rename to internal/db/schema/migrations/postgres/0/20_host.down.sql diff --git a/internal/db/migrations/postgres/0/20_host.up.sql b/internal/db/schema/migrations/postgres/0/20_host.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/20_host.up.sql rename to internal/db/schema/migrations/postgres/0/20_host.up.sql diff --git a/internal/db/migrations/postgres/0/22_static_host.down.sql b/internal/db/schema/migrations/postgres/0/22_static_host.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/22_static_host.down.sql rename to internal/db/schema/migrations/postgres/0/22_static_host.down.sql diff --git a/internal/db/migrations/postgres/0/22_static_host.up.sql 
b/internal/db/schema/migrations/postgres/0/22_static_host.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/22_static_host.up.sql rename to internal/db/schema/migrations/postgres/0/22_static_host.up.sql diff --git a/internal/db/migrations/postgres/0/30_keys.down.sql b/internal/db/schema/migrations/postgres/0/30_keys.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/30_keys.down.sql rename to internal/db/schema/migrations/postgres/0/30_keys.down.sql diff --git a/internal/db/migrations/postgres/0/30_keys.up.sql b/internal/db/schema/migrations/postgres/0/30_keys.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/30_keys.up.sql rename to internal/db/schema/migrations/postgres/0/30_keys.up.sql diff --git a/internal/db/migrations/postgres/0/31_keys.down.sql b/internal/db/schema/migrations/postgres/0/31_keys.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/31_keys.down.sql rename to internal/db/schema/migrations/postgres/0/31_keys.down.sql diff --git a/internal/db/migrations/postgres/0/31_keys.up.sql b/internal/db/schema/migrations/postgres/0/31_keys.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/31_keys.up.sql rename to internal/db/schema/migrations/postgres/0/31_keys.up.sql diff --git a/internal/db/migrations/postgres/0/40_targets.down.sql b/internal/db/schema/migrations/postgres/0/40_targets.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/40_targets.down.sql rename to internal/db/schema/migrations/postgres/0/40_targets.down.sql diff --git a/internal/db/migrations/postgres/0/40_targets.up.sql b/internal/db/schema/migrations/postgres/0/40_targets.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/40_targets.up.sql rename to internal/db/schema/migrations/postgres/0/40_targets.up.sql diff --git a/internal/db/migrations/postgres/0/41_targets.down.sql b/internal/db/schema/migrations/postgres/0/41_targets.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/41_targets.down.sql rename to internal/db/schema/migrations/postgres/0/41_targets.down.sql diff --git a/internal/db/migrations/postgres/0/41_targets.up.sql b/internal/db/schema/migrations/postgres/0/41_targets.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/41_targets.up.sql rename to internal/db/schema/migrations/postgres/0/41_targets.up.sql diff --git a/internal/db/migrations/postgres/0/50_session.down.sql b/internal/db/schema/migrations/postgres/0/50_session.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/50_session.down.sql rename to internal/db/schema/migrations/postgres/0/50_session.down.sql diff --git a/internal/db/migrations/postgres/0/50_session.up.sql b/internal/db/schema/migrations/postgres/0/50_session.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/50_session.up.sql rename to internal/db/schema/migrations/postgres/0/50_session.up.sql diff --git a/internal/db/migrations/postgres/0/51_connection.down.sql b/internal/db/schema/migrations/postgres/0/51_connection.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/51_connection.down.sql rename to internal/db/schema/migrations/postgres/0/51_connection.down.sql diff --git a/internal/db/migrations/postgres/0/51_connection.up.sql b/internal/db/schema/migrations/postgres/0/51_connection.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/51_connection.up.sql 
rename to internal/db/schema/migrations/postgres/0/51_connection.up.sql diff --git a/internal/db/migrations/postgres/0/60_wh_domain_types.down.sql b/internal/db/schema/migrations/postgres/0/60_wh_domain_types.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/60_wh_domain_types.down.sql rename to internal/db/schema/migrations/postgres/0/60_wh_domain_types.down.sql diff --git a/internal/db/migrations/postgres/0/60_wh_domain_types.up.sql b/internal/db/schema/migrations/postgres/0/60_wh_domain_types.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/60_wh_domain_types.up.sql rename to internal/db/schema/migrations/postgres/0/60_wh_domain_types.up.sql diff --git a/internal/db/migrations/postgres/0/62_wh_datetime.down.sql b/internal/db/schema/migrations/postgres/0/62_wh_datetime.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/62_wh_datetime.down.sql rename to internal/db/schema/migrations/postgres/0/62_wh_datetime.down.sql diff --git a/internal/db/migrations/postgres/0/62_wh_datetime.up.sql b/internal/db/schema/migrations/postgres/0/62_wh_datetime.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/62_wh_datetime.up.sql rename to internal/db/schema/migrations/postgres/0/62_wh_datetime.up.sql diff --git a/internal/db/migrations/postgres/0/65_wh_session_dimensions.down.sql b/internal/db/schema/migrations/postgres/0/65_wh_session_dimensions.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/65_wh_session_dimensions.down.sql rename to internal/db/schema/migrations/postgres/0/65_wh_session_dimensions.down.sql diff --git a/internal/db/migrations/postgres/0/65_wh_session_dimensions.up.sql b/internal/db/schema/migrations/postgres/0/65_wh_session_dimensions.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/65_wh_session_dimensions.up.sql rename to internal/db/schema/migrations/postgres/0/65_wh_session_dimensions.up.sql diff --git a/internal/db/migrations/postgres/0/66_wh_session_dimensions.down.sql b/internal/db/schema/migrations/postgres/0/66_wh_session_dimensions.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/66_wh_session_dimensions.down.sql rename to internal/db/schema/migrations/postgres/0/66_wh_session_dimensions.down.sql diff --git a/internal/db/migrations/postgres/0/66_wh_session_dimensions.up.sql b/internal/db/schema/migrations/postgres/0/66_wh_session_dimensions.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/66_wh_session_dimensions.up.sql rename to internal/db/schema/migrations/postgres/0/66_wh_session_dimensions.up.sql diff --git a/internal/db/migrations/postgres/0/68_wh_session_facts.down.sql b/internal/db/schema/migrations/postgres/0/68_wh_session_facts.down.sql similarity index 100% rename from internal/db/migrations/postgres/0/68_wh_session_facts.down.sql rename to internal/db/schema/migrations/postgres/0/68_wh_session_facts.down.sql diff --git a/internal/db/migrations/postgres/0/68_wh_session_facts.up.sql b/internal/db/schema/migrations/postgres/0/68_wh_session_facts.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/68_wh_session_facts.up.sql rename to internal/db/schema/migrations/postgres/0/68_wh_session_facts.up.sql diff --git a/internal/db/migrations/postgres/0/69_wh_session_facts.down.sql b/internal/db/schema/migrations/postgres/0/69_wh_session_facts.down.sql similarity index 100% rename from 
internal/db/migrations/postgres/0/69_wh_session_facts.down.sql rename to internal/db/schema/migrations/postgres/0/69_wh_session_facts.down.sql diff --git a/internal/db/migrations/postgres/0/69_wh_session_facts.up.sql b/internal/db/schema/migrations/postgres/0/69_wh_session_facts.up.sql similarity index 100% rename from internal/db/migrations/postgres/0/69_wh_session_facts.up.sql rename to internal/db/schema/migrations/postgres/0/69_wh_session_facts.up.sql diff --git a/internal/db/schema/postgres/postgres.go b/internal/db/schema/postgres/postgres.go new file mode 100644 index 0000000000..932313d630 --- /dev/null +++ b/internal/db/schema/postgres/postgres.go @@ -0,0 +1,356 @@ +// The MIT License (MIT) +// +// Original Work +// Copyright (c) 2016 Matthias Kadenbach +// https://github.com/mattes/migrate +// +// Modified Work +// Copyright (c) 2018 Dale Hui +// https://github.com/golang-migrate/migrate +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "io" + "io/ioutil" + "strconv" + "strings" + + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/go-multierror" + "github.com/lib/pq" +) + +// schemaAccessLockId is a Lock key used to ensure a single boundary binary is operating +// on a postgres server at a time. The value has no meaning and was picked randomly. +const ( + schemaAccessLockId int64 = 3865661975 + nilVersion = -1 +) + +var defaultMigrationsTable = "boundary_schema_version" + +// Postgres is a driver usable by a boundary schema manager. +type Postgres struct { + // Locking and unlocking need to use the same connection + conn *sql.Conn + db *sql.DB +} + +// New returns a postgres pointer with the provided db verified as +// connectable and a version table being initialized. +func New(ctx context.Context, instance *sql.DB) (*Postgres, error) { + const op = "postgres.New" + if err := instance.PingContext(ctx); err != nil { + return nil, errors.Wrap(err, op) + } + conn, err := instance.Conn(ctx) + if err != nil { + return nil, errors.Wrap(err, op) + } + + px := &Postgres{ + conn: conn, + db: instance, + } + + if err := px.ensureVersionTable(ctx); err != nil { + return nil, errors.Wrap(err, op) + } + + return px, nil +} + +// TrySharedLock attempts to capture a shared lock. If it is not successful it returns an error. 
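+// The shared lock can be held by many holders at once, but conflicts with the
+// exclusive lock taken by TryLock. A hypothetical read-only caller (not part
+// of this change) might use it as:
+//
+//	if err := p.TrySharedLock(ctx); err != nil {
+//		return err // an exclusive lock is held elsewhere
+//	}
+//	defer p.UnlockShared(ctx)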
+// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS +func (p *Postgres) TrySharedLock(ctx context.Context) error { + const op = "postgres.(Postgres).TrySharedLock" + r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock_shared($1)", schemaAccessLockId) + if r.Err() != nil { + return errors.Wrap(r.Err(), op) + } + var gotLock bool + if err := r.Scan(&gotLock); err != nil { + return errors.Wrap(err, op) + } + if !gotLock { + return errors.New(errors.MigrationLock, op, "Lock failed") + } + return nil +} + +// TryLock attempts to capture an exclusive lock. If it is not successful it returns an error. +// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS +func (p *Postgres) TryLock(ctx context.Context) error { + const op = "postgres.(Postgres).TryLock" + + r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", schemaAccessLockId) + if r.Err() != nil { + return errors.Wrap(r.Err(), op) + } + var gotLock bool + if err := r.Scan(&gotLock); err != nil { + return errors.Wrap(err, op) + } + if !gotLock { + return errors.New(errors.MigrationLock, op, "Lock failed") + } + return nil +} + +// Lock calls pg_advisory_lock with the provided context and returns an error +// if we were unable to get the lock before the context cancels. +func (p *Postgres) Lock(ctx context.Context) error { + const op = "postgres.(Postgres).Lock" + + // This will wait indefinitely until the Lock can be acquired. + query := `SELECT pg_advisory_lock($1)` + if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { + return errors.Wrap(err, op) + } + + return nil +} + +// Unlock calls pg_advisory_unlock and returns an error if we were unable to +// release the lock before the context cancels. +func (p *Postgres) Unlock(ctx context.Context) error { + const op = "postgres.(Postgres).Unlock" + + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// UnlockShared calls pg_advisory_unlock_shared and returns an error if we were unable to +// release the lock before the context cancels. +func (p *Postgres) UnlockShared(ctx context.Context) error { + const op = "postgres.(Postgres).UnlockShared" + query := `SELECT pg_advisory_unlock_shared($1)` + if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +// Run executes the SQL provided by the passed in io.Reader. The contents of the reader must +// fit in memory, as the full content is read into a string before being passed to the +// backing database.
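+//
+// A minimal usage sketch (hypothetical; callers are expected to hold the
+// advisory lock first):
+//
+//	stmts := strings.NewReader("begin; create table t (id int); commit;")
+//	if err := p.Run(ctx, stmts); err != nil {
+//		// the wrapped error includes line/column details when postgres
+//		// reports a position
+//	}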
+func (p *Postgres) Run(ctx context.Context, migration io.Reader) error { + const op = "postgres.(Postgres).Run" + migr, err := ioutil.ReadAll(migration) + if err != nil { + return errors.Wrap(err, op) + } + // Run migration + query := string(migr) + if _, err := p.conn.ExecContext(ctx, query); err != nil { + if pgErr, ok := err.(*pq.Error); ok { + var line uint + var col uint + var lineColOK bool + if pgErr.Position != "" { + if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { + line, col, lineColOK = computeLineFromPos(query, int(pos)) + } + } + message := "migration failed" + if lineColOK { + message = fmt.Sprintf("%s on line %d (column %d)", message, line, col) + } + if pgErr.Detail != "" { + message = fmt.Sprintf("%s, %s", message, pgErr.Detail) + } + message = fmt.Sprintf("%s: %s", message, migr) + return errors.Wrap(err, op, errors.WithMsg(message)) + } + return errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("migration failed: %s", migr))) + } + + return nil +} + +func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { + // replace crlf with lf + s = strings.Replace(s, "\r\n", "\n", -1) + // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes + runes := []rune(s) + if pos > len(runes) { + return 0, 0, false + } + sel := runes[:pos] + line = uint(runesCount(sel, newLine) + 1) + col = uint(pos - 1 - runesLastIndex(sel, newLine)) + return line, col, true +} + +const newLine = '\n' + +func runesCount(input []rune, target rune) int { + var count int + for _, r := range input { + if r == target { + count++ + } + } + return count +} + +func runesLastIndex(input []rune, target rune) int { + for i := len(input) - 1; i >= 0; i-- { + if input[i] == target { + return i + } + } + return -1 +} + +// SetVersion sets the version number, and whether the database is in a dirty state. +// A version value of -1 indicates no version is set. +func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error { + const op = "postgres.(Postgres).SetVersion" + tx, err := p.conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return errors.Wrap(err, op) + } + + query := `TRUNCATE ` + pq.QuoteIdentifier(defaultMigrationsTable) + if _, err := tx.ExecContext(ctx, query); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return errors.Wrap(err, op) + } + + // Also re-write the schema Version for nil dirty versions to prevent + // empty schema Version for failed down migration on the first migration + // See: https://github.com/golang-migrate/migrate/issues/330 + if version >= 0 || (version == nilVersion && dirty) { + query = `INSERT INTO ` + pq.QuoteIdentifier(defaultMigrationsTable) + + ` (Version, dirty) VALUES ($1, $2)` + if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return errors.Wrap(err, op) + } + } + + if err := tx.Commit(); err != nil { + return errors.Wrap(err, op) + } + + return nil +} + +// Version returns the current version, whether the database is in a dirty state, and any error. +// A version value of -1 indicates no version is set.
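+//
+// For example, a fresh database with no boundary_schema_version table
+// yields (-1, false, nil), while a database whose last migration failed
+// yields the recorded version with dirty == true:
+//
+//	v, dirty, err := p.Version(ctx)
+//	if err == nil && v == nilVersion && !dirty {
+//		// no migrations have been run yet
+//	}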
+func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) { + const op = "postgres.(Postgres).Version" + query := `SELECT Version, dirty FROM ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` LIMIT 1` + err = p.conn.QueryRowContext(ctx, query).Scan(&version, &dirty) + switch { + case err == sql.ErrNoRows: + return nilVersion, false, nil + + case err != nil: + if e, ok := err.(*pq.Error); ok { + if e.Code.Name() == "undefined_table" { + return nilVersion, false, nil + } + } + return 0, false, errors.Wrap(err, op) + + default: + return version, dirty, nil + } +} + +func (p *Postgres) drop(ctx context.Context) (err error) { + const op = "postgres.(Postgres).drop" + // select all tables in current schema + query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` + tables, err := p.conn.QueryContext(ctx, query) + if err != nil { + return errors.Wrap(err, op) + } + defer func() { + if errClose := tables.Close(); errClose != nil { + err = multierror.Append(err, errClose) + err = errors.Wrap(err, op) + } + }() + + // delete one table after another + tableNames := make([]string, 0) + for tables.Next() { + var tableName string + if err := tables.Scan(&tableName); err != nil { + return errors.Wrap(err, op) + } + if len(tableName) > 0 { + tableNames = append(tableNames, tableName) + } + } + if err := tables.Err(); err != nil { + return errors.Wrap(err, op) + } + + if len(tableNames) > 0 { + // delete one by one ... + for _, t := range tableNames { + query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` + if _, err := p.conn.ExecContext(ctx, query); err != nil { + return errors.Wrap(err, op) + } + } + } + + return nil +} + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the postgres type. 
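+// Without the lock, two binaries calling New at the same time could race on
+// the create-if-not-exists DDL below; TestWithInstance_Concurrent exercises
+// this code path concurrently.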
+func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) { + const op = "postgres.(Postgres).ensureVersionTable" + if err = p.Lock(ctx); err != nil { + return errors.Wrap(err, op) + } + + defer func() { + if e := p.Unlock(ctx); e != nil { + err = multierror.Append(err, e) + } + }() + + query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (Version bigint primary key, dirty boolean not null)` + if _, err = p.conn.ExecContext(ctx, query); err != nil { + return errors.Wrap(err, op) + } + + return nil +} diff --git a/internal/db/schema/postgres/postgres_test.go b/internal/db/schema/postgres/postgres_test.go new file mode 100644 index 0000000000..3432972dae --- /dev/null +++ b/internal/db/schema/postgres/postgres_test.go @@ -0,0 +1,433 @@ +// The MIT License (MIT) +// +// Original Work +// Copyright (c) 2016 Matthias Kadenbach +// https://github.com/mattes/migrate +// +// Modified Work +// Copyright (c) 2018 Dale Hui +// https://github.com/golang-migrate/migrate +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
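+
+// Note: these tests start a real postgres:12 container via dktest, so a
+// working Docker daemon is required to run them.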
+ +package postgres + +// error codes https://github.com/lib/pq/blob/master/error.go + +import ( + "bytes" + "context" + "database/sql" + sqldriver "database/sql/driver" + "fmt" + "io" + "log" + "strconv" + "strings" + "sync" + "testing" + + "github.com/dhui/dktest" + "github.com/golang-migrate/migrate/v4/dktesting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + pgPassword = "postgres" +) + +var ( + opts = dktest.Options{ + Env: map[string]string{"POSTGRES_PASSWORD": pgPassword}, + PortRequired: true, ReadyFunc: isReady, + } + // Supported versions: https://www.postgresql.org/support/versioning/ + specs = []dktesting.ContainerSpec{ + {ImageName: "postgres:12", Options: opts}, + } +) + +func pgConnectionString(host, port string) string { + return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port) +} + +func isReady(ctx context.Context, c dktest.ContainerInfo) bool { + ip, port, err := c.FirstPort() + if err != nil { + return false + } + + db, err := sql.Open("postgres", pgConnectionString(ip, port)) + if err != nil { + return false + } + defer func() { + if err := db.Close(); err != nil { + log.Println("close error:", err) + } + }() + if err = db.PingContext(ctx); err != nil { + switch err { + case sqldriver.ErrBadConn, io.EOF: + return false + default: + log.Println(err) + } + return false + } + + return true +} + +func TestDbStuff(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + d, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.close(t); err != nil { + t.Error(err) + } + }() + test(t, d, []byte("SELECT 1")) + }) +} + +func TestVersion_NoVersionTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + d, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.close(t); err != nil { + t.Error(err) + } + }() + // Drop the version table so calls to Version don't rely on it + require.NoError(t, d.drop(ctx)) + + v, dirt, err := d.Version(ctx) + assert.NoError(t, err) + assert.Equal(t, nilVersion, v) + assert.False(t, dirt) + }) +} + +func TestMultiStatement(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + d, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.close(t); err != nil { + t.Error(err) + } + }() + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure second table exists + var exists bool + if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table bar to exist") + } + }) +} + +func TestWithSchema(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { +
ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + d, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.close(t); err != nil { + t.Fatal(err) + } + }() + + // create foobar schema + if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil { + t.Fatal(err) + } + if err := d.SetVersion(ctx, 1, false); err != nil { + t.Fatal(err) + } + + // re-connect using that schema + d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar", + pgPassword, ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d2.close(t); err != nil { + t.Fatal(err) + } + }() + + version, _, err := d2.Version(ctx) + if err != nil { + t.Fatal(err) + } + if version != nilVersion { + t.Fatal("expected NilVersion") + } + + // now update Version and compare + if err := d2.SetVersion(ctx, 2, false); err != nil { + t.Fatal(err) + } + version, _, err = d2.Version(ctx) + if err != nil { + t.Fatal(err) + } + if version != 2 { + t.Fatal("expected Version 2") + } + + // meanwhile, the public schema still has the other Version + version, _, err = d.Version(ctx) + if err != nil { + t.Fatal(err) + } + if version != 1 { + t.Fatal("expected Version 1") + } + }) +} + +func TestPostgres_Lock(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + ps, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + + test(t, ps, []byte("SELECT 1")) + + err = ps.Lock(ctx) + if err != nil { + t.Fatal(err) + } + + err = ps.Unlock(ctx) + if err != nil { + t.Fatal(err) + } + + err = ps.Lock(ctx) + if err != nil { + t.Fatal(err) + } + + err = ps.Unlock(ctx) + if err != nil { + t.Fatal(err) + } + }) +} + +func TestWithInstance_Concurrent(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + // The number of concurrent processes running New + const concurrency = 30 + + // We can instantiate a single database handle because it is + // actually a connection pool, and so, each of the below go + // routines will have a high probability of using a separate + // connection, which is something we want to exercise.
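+		// Capping both idle and open connections at the concurrency level
+		// (below) keeps each goroutine on its own connection rather than
+		// queueing behind a smaller pool.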
+ db, err := sql.Open("postgres", pgConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + db.SetMaxIdleConns(concurrency) + db.SetMaxOpenConns(concurrency) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func(i int) { + defer wg.Done() + _, err := New(context.Background(), db) + if err != nil { + t.Errorf("process %d error: %s", i, err) + } + }(i) + } + }) +} + +func TestRun_Error(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p, err := open(t, ctx, addr) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, p.close(t)) + }) + + err = p.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo"))) + assert.Error(t, err) + }) +} + +func Test_computeLineFromPos(t *testing.T) { + testcases := []struct { + pos int + wantLine uint + wantCol uint + input string + wantOk bool + }{ + { + 15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exist + }, + { + 16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exist, empty line + }, + { + 25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error + }, + { + 27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines + }, + { + 10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo + }, + { + 11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line + }, + { + 17, 2, 8, "SELECT *\nFROM foo", true, // last character + }, + { + 18, 0, 0, "SELECT *\nFROM foo", false, // invalid position + }, + } + for i, tc := range testcases { + t.Run("tc"+strconv.Itoa(i), func(t *testing.T) { + run := func(crlf bool, nonASCII bool) { + var name string + if crlf { + name = "crlf" + } else { + name = "lf" + } + if nonASCII { + name += "-nonascii" + } else { + name += "-ascii" + } + t.Run(name, func(t *testing.T) { + input := tc.input + if crlf { + input = strings.Replace(input, "\n", "\r\n", -1) + } + if nonASCII { + input = strings.Replace(input, "FROM", "FRÖM", -1) + } + gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos) + + if tc.wantOk { + t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input) + } + + if gotOK != tc.wantOk { + t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK) + } + if gotLine != tc.wantLine { + t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine) + } + if gotCol != tc.wantCol { + t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol) + } + }) + } + run(false, false) + run(true, false) + run(false, true) + run(true, true) + }) + } +} diff --git a/internal/db/schema/postgres/testing.go b/internal/db/schema/postgres/testing.go new file mode 100644 index 0000000000..c1f10f5f0d --- /dev/null +++ b/internal/db/schema/postgres/testing.go @@ -0,0 +1,169 @@ +// The MIT License (MIT) +// +// Original Work +// Copyright (c) 2016 Matthias Kadenbach +// https://github.com/mattes/migrate +// +// Modified Work +// Copyright (c) 2018 Dale Hui +// https://github.com/golang-migrate/migrate +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package postgres + +import ( + "bytes" + "context" + "database/sql" + "io" + "testing" + "time" + + "github.com/golang-migrate/migrate/v4/database" + "github.com/stretchr/testify/require" +) + +// test runs tests against database implementations. +func test(t *testing.T, d *Postgres, migration []byte) { + if migration == nil { + t.Fatal("test must provide migration reader") + } + + testNilVersion(t, d) // test first + testLockAndUnlock(t, d) + testRun(t, d, bytes.NewReader(migration)) + testSetVersion(t, d) // also tests Version() + // drop breaks the driver, so test it last. + testDrop(t, d) +} + +func testNilVersion(t *testing.T, d *Postgres) { + ctx := context.Background() + v, _, err := d.Version(ctx) + if err != nil { + t.Fatal(err) + } + if v != database.NilVersion { + t.Fatalf("Version: expected Version to be NilVersion (-1), got %v", v) + } +} + +func testLockAndUnlock(t *testing.T, d *Postgres) { + ctx := context.Background() + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + // locking twice is ok, no error + if err := d.Lock(ctx); err != nil { + t.Fatalf("got error, expected none: %v", err) + } + if err := d.Lock(ctx); err != nil { + t.Fatalf("got error, expected none: %v", err) + } + + // Unlock + if err := d.Unlock(ctx); err != nil { + t.Fatalf("error unlocking: %v", err) + } + + // try to Lock + if err := d.Lock(ctx); err != nil { + t.Fatalf("got error, expected none: %v", err) + } + if err := d.Unlock(ctx); err != nil { + t.Fatalf("got error, expected none: %v", err) + } +} + +func testRun(t *testing.T, d *Postgres, migration io.Reader) { + ctx := context.Background() + if migration == nil { + t.Fatal("migration can't be nil") + } + + if err := d.Run(ctx, migration); err != nil { + t.Fatal(err) + } +} + +func testDrop(t *testing.T, d *Postgres) { + ctx := context.Background() + if err := d.drop(ctx); err != nil { + t.Fatal(err) + } +} + +func testSetVersion(t *testing.T, d *Postgres) { + ctx := context.Background() + // nolint:maligned + testCases := []struct { + name string + version int + dirty bool + expectedErr error + expectedReadErr error + expectedVersion int + expectedDirty bool + }{ + {name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true}, + {name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true}, + {name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false}, + {name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false}, + {name: "last migration dirty", version: database.NilVersion, dirty: true,
expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true}, + {name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := d.SetVersion(ctx, tc.version, tc.dirty) + if err != tc.expectedErr { + t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr) + } + v, dirty, readErr := d.Version(ctx) + if readErr != tc.expectedReadErr { + t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr) + } + if v != tc.expectedVersion { + t.Error("Got unexpected Version:", v, "!=", tc.expectedVersion) + } + if dirty != tc.expectedDirty { + t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty) + } + }) + } +} + +func open(t *testing.T, ctx context.Context, u string) (*Postgres, error) { + t.Helper() + db, err := sql.Open("postgres", u) + require.NoError(t, err) + + px, err := New(ctx, db) + require.NoError(t, err) + + return px, nil +} + +func (p *Postgres) close(t *testing.T) error { + t.Helper() + require.NoError(t, p.conn.Close()) + require.NoError(t, p.db.Close()) + return nil +} diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/schema/postgres_migration.gen.go similarity index 97% rename from internal/db/migrations/postgres.gen.go rename to internal/db/schema/postgres_migration.gen.go index c6f02c710e..054957de89 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/schema/postgres_migration.gen.go @@ -1,39 +1,13 @@ -// Code generated by "make migrations"; DO NOT EDIT. -package migrations - -// DevMigration is true if the database schema that would be applied by -// InitStore would be from files in the /dev directory which indicates it would -// not be safe to run in a non dev environment. -var DevMigration = false - -var postgresMigrations = map[string]*fakeFile{ - "migrations": { - name: "migrations", - }, - "migrations/01_domain_types.down.sql": { - name: "01_domain_types.down.sql", - bytes: []byte(` -begin; - -drop domain wt_timestamp; -drop domain wt_public_id; -drop domain wt_private_id; -drop domain wt_scope_id; -drop domain wt_user_id; -drop domain wt_version; - -drop function default_create_time; -drop function update_time_column; -drop function update_version_column; -drop function immutable_columns; +package schema -commit; +// Code generated by "make migrations"; DO NOT EDIT. 
-`), - }, - "migrations/01_domain_types.up.sql": { - name: "01_domain_types.up.sql", - bytes: []byte(` +func init() { + migrationStates["postgres"] = migrationState{ + devMigration: false, + binarySchemaVersion: 69, + upMigrations: map[int][]byte{ + 1: []byte(` begin; create domain wt_public_id as text @@ -194,23 +168,7 @@ is commit; `), - }, - "migrations/02_oplog.down.sql": { - name: "02_oplog.down.sql", - bytes: []byte(` -begin; - -drop table oplog_metadata cascade; -drop table oplog_ticket cascade; -drop table oplog_entry cascade; - -commit; - -`), - }, - "migrations/02_oplog.up.sql": { - name: "02_oplog.up.sql", - bytes: []byte(` + 2: []byte(` begin; -- TODO (jimlambrt 7/2020) remove update_time @@ -326,24 +284,7 @@ commit; `), - }, - "migrations/03_db.down.sql": { - name: "03_db.down.sql", - bytes: []byte(` -begin; - -drop table db_test_rental cascade; -drop table db_test_car cascade; -drop table db_test_user cascade; -drop table db_test_scooter cascade; - -commit; - -`), - }, - "migrations/03_db.up.sql": { - name: "03_db.up.sql", - bytes: []byte(` + 3: []byte(` begin; -- create test tables used in the unit tests for the internal/db package @@ -477,50 +418,7 @@ insert on db_test_scooter commit; `), - }, - "migrations/06_iam.down.sql": { - name: "06_iam.down.sql", - bytes: []byte(` -BEGIN; - -drop table iam_group cascade; -drop table iam_user cascade; -drop table iam_scope_project cascade; -drop table iam_scope_org cascade; -drop table iam_scope_global cascade; -drop table iam_scope cascade; -drop table iam_scope_type_enm cascade; -drop table iam_role cascade; -drop view iam_principal_role cascade; -drop table iam_group_role cascade; -drop table iam_user_role cascade; -drop table iam_group_member_user cascade; -drop view iam_group_member cascade; -drop table iam_role_grant cascade; - -drop function iam_sub_names cascade; -drop function iam_immutable_scope_type_func cascade; -drop function iam_sub_scopes_func cascade; -drop function iam_immutable_role_principal cascade; -drop function iam_user_role_scope_check cascade; -drop function iam_group_role_scope_check cascade; -drop function iam_group_member_scope_check cascade; -drop function iam_immutable_group_member cascade; -drop function get_scoped_member_id cascade; -drop function grant_scope_id_valid cascade; -drop function disallow_global_scope_deletion cascade; -drop function user_scope_id_valid cascade; -drop function iam_immutable_role_grant cascade; -drop function disallow_iam_predefined_user_deletion cascade; -drop function recovery_user_not_allowed cascade; - -COMMIT; - -`), - }, - "migrations/06_iam.up.sql": { - name: "06_iam.up.sql", - bytes: []byte(` + 6: []byte(` begin; create table iam_scope_type_enm ( @@ -1230,26 +1128,7 @@ where commit; `), - }, - "migrations/07_auth.down.sql": { - name: "07_auth.down.sql", - bytes: []byte(` -begin; - - drop function update_iam_user_auth_account; - drop function insert_auth_account_subtype; - drop function insert_auth_method_subtype; - - drop table auth_account cascade; - drop table auth_method cascade; - -commit; - -`), - }, - "migrations/07_auth.up.sql": { - name: "07_auth.up.sql", - bytes: []byte(` + 7: []byte(` begin; /* @@ -1418,21 +1297,7 @@ begin; commit; `), - }, - "migrations/08_servers.down.sql": { - name: "08_servers.down.sql", - bytes: []byte(` -begin; - - drop table server; - -commit; - -`), - }, - "migrations/08_servers.up.sql": { - name: "08_servers.up.sql", - bytes: []byte(` + 8: []byte(` begin; -- For now at least the IDs will be the same as the name, because this allows 
us @@ -1485,26 +1350,7 @@ update on recovery_nonces commit; `), - }, - "migrations/11_auth_token.down.sql": { - name: "11_auth_token.down.sql", - bytes: []byte(` -begin; - - drop view auth_token_account cascade; - drop table auth_token cascade; - - drop function update_last_access_time cascade; - drop function immutable_auth_token_columns cascade; - drop function expire_time_not_older_than_token cascade; - -commit; - -`), - }, - "migrations/11_auth_token.up.sql": { - name: "11_auth_token.up.sql", - bytes: []byte(` + 11: []byte(` begin; -- an auth token belongs to 1 and only 1 auth account @@ -1643,27 +1489,7 @@ begin; commit; `), - }, - "migrations/12_auth_password.down.sql": { - name: "12_auth_password.down.sql", - bytes: []byte(` -begin; - - drop table auth_password_credential; - drop table auth_password_conf cascade; - drop table if exists auth_password_account; - drop table if exists auth_password_method; - - drop function insert_auth_password_credential_subtype; - drop function insert_auth_password_conf_subtype; - -commit; - -`), - }, - "migrations/12_auth_password.up.sql": { - name: "12_auth_password.up.sql", - bytes: []byte(` + 12: []byte(` begin; /* @@ -1962,22 +1788,7 @@ begin; commit; `), - }, - "migrations/13_auth_password_argon.down.sql": { - name: "13_auth_password_argon.down.sql", - bytes: []byte(` -begin; - - drop table auth_password_argon2_cred; - drop table auth_password_argon2_conf; - -commit; - -`), - }, - "migrations/13_auth_password_argon.up.sql": { - name: "13_auth_password_argon.up.sql", - bytes: []byte(` + 13: []byte(` begin; create table auth_password_argon2_conf ( @@ -2129,22 +1940,7 @@ begin; commit; `), - }, - "migrations/14_auth_password_views.down.sql": { - name: "14_auth_password_views.down.sql", - bytes: []byte(` -begin; - - drop view auth_password_current_conf; - drop view auth_password_conf_union; - -commit; - -`), - }, - "migrations/14_auth_password_views.up.sql": { - name: "14_auth_password_views.up.sql", - bytes: []byte(` + 14: []byte(` begin; -- auth_password_conf_union is a union of the configuration settings @@ -2172,35 +1968,7 @@ begin; commit; `), - }, - "migrations/20_host.down.sql": { - name: "20_host.down.sql", - bytes: []byte(` -begin; - - drop table host_set; - drop table host; - drop table host_catalog; - - drop function insert_host_set_subtype; - drop function insert_host_subtype; - drop function insert_host_catalog_subtype; - - delete - from oplog_ticket - where name in ( - 'host_catalog', - 'host', - 'host_set' - ); - -commit; - -`), - }, - "migrations/20_host.up.sql": { - name: "20_host.up.sql", - bytes: []byte(` + 20: []byte(` begin; /* @@ -2371,33 +2139,7 @@ begin; commit; `), - }, - "migrations/22_static_host.down.sql": { - name: "22_static_host.down.sql", - bytes: []byte(` -begin; - - drop table static_host_set_member cascade; - drop table static_host_set cascade; - drop table static_host cascade; - drop table static_host_catalog cascade; - - delete - from oplog_ticket - where name in ( - 'static_host_catalog', - 'static_host', - 'static_host_set', - 'static_host_set_member' - ); - -commit; - -`), - }, - "migrations/22_static_host.up.sql": { - name: "22_static_host.up.sql", - bytes: []byte(` + 22: []byte(` begin; /* @@ -2601,21 +2343,7 @@ begin; commit; `), - }, - "migrations/30_keys.down.sql": { - name: "30_keys.down.sql", - bytes: []byte(` -begin; - -drop function kms_version_column cascade; - -commit; - -`), - }, - "migrations/30_keys.up.sql": { - name: "30_keys.up.sql", - bytes: []byte(` + 30: []byte(` begin; -- 
kms_version_column() will increment the version column whenever row data @@ -2646,28 +2374,7 @@ is commit; `), - }, - "migrations/31_keys.down.sql": { - name: "31_keys.down.sql", - bytes: []byte(` -begin; - -drop table kms_root_key cascade; -drop table kms_root_key_version cascade; -drop table kms_database_key cascade; -drop table kms_database_key_version cascade; -drop table kms_oplog_key cascade; -drop table kms_oplog_key_version cascade; -drop table kms_session_key cascade; -drop table kms_session_key_version cascade; - -commit; - -`), - }, - "migrations/31_keys.up.sql": { - name: "31_keys.up.sql", - bytes: []byte(` + 31: []byte(` begin; /* @@ -2994,24 +2701,7 @@ before insert on kms_token_key_version commit; `), - }, - "migrations/40_targets.down.sql": { - name: "40_targets.down.sql", - bytes: []byte(` -begin; - -drop function insert_target_subtype; -drop function delete_target_subtype; -drop function target_scope_valid; -drop function target_host_set_scope_valid - -commit; - -`), - }, - "migrations/40_targets.up.sql": { - name: "40_targets.up.sql", - bytes: []byte(` + 40: []byte(` begin; @@ -3083,31 +2773,7 @@ $$ language plpgsql; commit; `), - }, - "migrations/41_targets.down.sql": { - name: "41_targets.down.sql", - bytes: []byte(` -begin; - -drop table target cascade; -drop table target_host_set cascade; -drop table target_tcp; -drop view target_all_subtypes; -drop view target_host_set_catalog; - - -delete -from oplog_ticket -where name in ( - 'target_tcp' - ); - -commit; -`), - }, - "migrations/41_targets.up.sql": { - name: "41_targets.up.sql", - bytes: []byte(` + 41: []byte(` /* ┌─────────────────┐ ┌─────────────────┐ │ target_tcp │ @@ -3300,36 +2966,7 @@ values commit; `), - }, - "migrations/50_session.down.sql": { - name: "50_session.down.sql", - bytes: []byte(` -begin; - - drop table session_state; - drop table session_state_enm; - drop table session; - drop table session_termination_reason_enm; - drop function insert_session_state; - drop function insert_new_session_state; - drop function insert_session; - drop function update_session_state_on_termination_reason; - drop function insert_session_state; - - - delete - from oplog_ticket - where name in ( - 'session' - ); - -commit; - -`), - }, - "migrations/50_session.up.sql": { - name: "50_session.up.sql", - bytes: []byte(` + 50: []byte(` begin; /* @@ -3819,26 +3456,7 @@ begin; commit; `), - }, - "migrations/51_connection.down.sql": { - name: "51_connection.down.sql", - bytes: []byte(` -begin; - - drop table session_connection_state; - drop table session_connection_state_enm; - drop table session_connection; - drop table session_connection_closed_reason_enm; - drop function insert_session_connection_state; - drop function insert_new_connection_state; - drop function update_connection_state_on_closed_reason; -commit; - -`), - }, - "migrations/51_connection.up.sql": { - name: "51_connection.up.sql", - bytes: []byte(` + 51: []byte(` begin; /* @@ -4219,31 +3837,7 @@ create or replace function commit; `), - }, - "migrations/60_wh_domain_types.down.sql": { - name: "60_wh_domain_types.down.sql", - bytes: []byte(` -begin; - - drop function wh_current_time_id; - drop function wh_current_date_id; - drop function wh_time_id; - drop function wh_date_id; - drop domain wh_dim_text; - drop domain wh_timestamp; - drop domain wh_public_id; - drop domain wh_dim_id; - drop function wh_dim_id; - drop domain wh_bytes_transmitted; - drop domain wh_inet_port; - -commit; - -`), - }, - "migrations/60_wh_domain_types.up.sql": { - name: 
"60_wh_domain_types.up.sql", - bytes: []byte(` + 60: []byte(` begin; create extension if not exists "pgcrypto"; @@ -4330,22 +3924,7 @@ begin; commit; `), - }, - "migrations/62_wh_datetime.down.sql": { - name: "62_wh_datetime.down.sql", - bytes: []byte(` -begin; - - drop table wh_time_of_day_dimension; - drop table wh_date_dimension; - -commit; - -`), - }, - "migrations/62_wh_datetime.up.sql": { - name: "62_wh_datetime.up.sql", - bytes: []byte(` + 62: []byte(` begin; create table wh_date_dimension ( @@ -4445,26 +4024,7 @@ begin; commit; `), - }, - "migrations/65_wh_session_dimensions.down.sql": { - name: "65_wh_session_dimensions.down.sql", - bytes: []byte(` -begin; - - drop view whx_user_dimension_target; - drop view whx_user_dimension_source; - drop view whx_host_dimension_target; - drop view whx_host_dimension_source; - drop table wh_user_dimension; - drop table wh_host_dimension; - -commit; - -`), - }, - "migrations/65_wh_session_dimensions.up.sql": { - name: "65_wh_session_dimensions.up.sql", - bytes: []byte(` + 65: []byte(` begin; create table wh_host_dimension ( @@ -4696,22 +4256,7 @@ begin; commit; `), - }, - "migrations/66_wh_session_dimensions.down.sql": { - name: "66_wh_session_dimensions.down.sql", - bytes: []byte(` -begin; - - drop function wh_upsert_user; - drop function wh_upsert_host; - -commit; - -`), - }, - "migrations/66_wh_session_dimensions.up.sql": { - name: "66_wh_session_dimensions.up.sql", - bytes: []byte(` + 66: []byte(` begin; -- wh_upsert_host returns the wh_host_dimension id for p_host_id, @@ -4852,22 +4397,7 @@ begin; commit; `), - }, - "migrations/68_wh_session_facts.down.sql": { - name: "68_wh_session_facts.down.sql", - bytes: []byte(` -begin; - - drop table wh_session_connection_accumulating_fact; - drop table wh_session_accumulating_fact; - -commit; - -`), - }, - "migrations/68_wh_session_facts.up.sql": { - name: "68_wh_session_facts.up.sql", - bytes: []byte(` + 68: []byte(` begin; -- Column names for numeric fields that are not a measurement end in id or @@ -5056,36 +4586,7 @@ begin; commit; `), - }, - "migrations/69_wh_session_facts.down.sql": { - name: "69_wh_session_facts.down.sql", - bytes: []byte(` -begin; - - drop trigger wh_insert_session_connection_state on session_connection_state; - drop function wh_insert_session_connection_state; - - drop trigger wh_insert_session_state on session_state; - drop function wh_insert_session_state; - - drop trigger wh_update_session_connection on session_connection; - drop function wh_update_session_connection; - - drop trigger wh_insert_session_connection on session_connection; - drop function wh_insert_session_connection; - - drop trigger wh_insert_session on session; - drop function wh_insert_session; - - drop function wh_rollup_connections; - -commit; - -`), - }, - "migrations/69_wh_session_facts.up.sql": { - name: "69_wh_session_facts.up.sql", - bytes: []byte(` + 69: []byte(` begin; -- wh_rollup_connections calculates the aggregate values from @@ -5361,5 +4862,361 @@ begin; commit; `), - }, + }, + downMigrations: map[int][]byte{ + 1: []byte(` +begin; + +drop domain wt_timestamp; +drop domain wt_public_id; +drop domain wt_private_id; +drop domain wt_scope_id; +drop domain wt_user_id; +drop domain wt_version; + +drop function default_create_time; +drop function update_time_column; +drop function update_version_column; +drop function immutable_columns; + +commit; + +`), + 2: []byte(` +begin; + +drop table oplog_metadata cascade; +drop table oplog_ticket cascade; +drop table oplog_entry cascade; + +commit; + +`), 
+ 3: []byte(` +begin; + +drop table db_test_rental cascade; +drop table db_test_car cascade; +drop table db_test_user cascade; +drop table db_test_scooter cascade; + +commit; + +`), + 6: []byte(` +BEGIN; + +drop table iam_group cascade; +drop table iam_user cascade; +drop table iam_scope_project cascade; +drop table iam_scope_org cascade; +drop table iam_scope_global cascade; +drop table iam_scope cascade; +drop table iam_scope_type_enm cascade; +drop table iam_role cascade; +drop view iam_principal_role cascade; +drop table iam_group_role cascade; +drop table iam_user_role cascade; +drop table iam_group_member_user cascade; +drop view iam_group_member cascade; +drop table iam_role_grant cascade; + +drop function iam_sub_names cascade; +drop function iam_immutable_scope_type_func cascade; +drop function iam_sub_scopes_func cascade; +drop function iam_immutable_role_principal cascade; +drop function iam_user_role_scope_check cascade; +drop function iam_group_role_scope_check cascade; +drop function iam_group_member_scope_check cascade; +drop function iam_immutable_group_member cascade; +drop function get_scoped_member_id cascade; +drop function grant_scope_id_valid cascade; +drop function disallow_global_scope_deletion cascade; +drop function user_scope_id_valid cascade; +drop function iam_immutable_role_grant cascade; +drop function disallow_iam_predefined_user_deletion cascade; +drop function recovery_user_not_allowed cascade; + +COMMIT; + +`), + 7: []byte(` +begin; + + drop function update_iam_user_auth_account; + drop function insert_auth_account_subtype; + drop function insert_auth_method_subtype; + + drop table auth_account cascade; + drop table auth_method cascade; + +commit; + +`), + 8: []byte(` +begin; + + drop table server; + +commit; + +`), + 11: []byte(` +begin; + + drop view auth_token_account cascade; + drop table auth_token cascade; + + drop function update_last_access_time cascade; + drop function immutable_auth_token_columns cascade; + drop function expire_time_not_older_than_token cascade; + +commit; + +`), + 12: []byte(` +begin; + + drop table auth_password_credential; + drop table auth_password_conf cascade; + drop table if exists auth_password_account; + drop table if exists auth_password_method; + + drop function insert_auth_password_credential_subtype; + drop function insert_auth_password_conf_subtype; + +commit; + +`), + 13: []byte(` +begin; + + drop table auth_password_argon2_cred; + drop table auth_password_argon2_conf; + +commit; + +`), + 14: []byte(` +begin; + + drop view auth_password_current_conf; + drop view auth_password_conf_union; + +commit; + +`), + 20: []byte(` +begin; + + drop table host_set; + drop table host; + drop table host_catalog; + + drop function insert_host_set_subtype; + drop function insert_host_subtype; + drop function insert_host_catalog_subtype; + + delete + from oplog_ticket + where name in ( + 'host_catalog', + 'host', + 'host_set' + ); + +commit; + +`), + 22: []byte(` +begin; + + drop table static_host_set_member cascade; + drop table static_host_set cascade; + drop table static_host cascade; + drop table static_host_catalog cascade; + + delete + from oplog_ticket + where name in ( + 'static_host_catalog', + 'static_host', + 'static_host_set', + 'static_host_set_member' + ); + +commit; + +`), + 30: []byte(` +begin; + +drop function kms_version_column cascade; + +commit; + +`), + 31: []byte(` +begin; + +drop table kms_root_key cascade; +drop table kms_root_key_version cascade; +drop table kms_database_key cascade; +drop table 
kms_database_key_version cascade; +drop table kms_oplog_key cascade; +drop table kms_oplog_key_version cascade; +drop table kms_session_key cascade; +drop table kms_session_key_version cascade; + +commit; + +`), + 40: []byte(` +begin; + +drop function insert_target_subtype; +drop function delete_target_subtype; +drop function target_scope_valid; +drop function target_host_set_scope_valid + +commit; + +`), + 41: []byte(` +begin; + +drop table target cascade; +drop table target_host_set cascade; +drop table target_tcp; +drop view target_all_subtypes; +drop view target_host_set_catalog; + + +delete +from oplog_ticket +where name in ( + 'target_tcp' + ); + +commit; +`), + 50: []byte(` +begin; + + drop table session_state; + drop table session_state_enm; + drop table session; + drop table session_termination_reason_enm; + drop function insert_session_state; + drop function insert_new_session_state; + drop function insert_session; + drop function update_session_state_on_termination_reason; + drop function insert_session_state; + + + delete + from oplog_ticket + where name in ( + 'session' + ); + +commit; + +`), + 51: []byte(` +begin; + + drop table session_connection_state; + drop table session_connection_state_enm; + drop table session_connection; + drop table session_connection_closed_reason_enm; + drop function insert_session_connection_state; + drop function insert_new_connection_state; + drop function update_connection_state_on_closed_reason; +commit; + +`), + 60: []byte(` +begin; + + drop function wh_current_time_id; + drop function wh_current_date_id; + drop function wh_time_id; + drop function wh_date_id; + drop domain wh_dim_text; + drop domain wh_timestamp; + drop domain wh_public_id; + drop domain wh_dim_id; + drop function wh_dim_id; + drop domain wh_bytes_transmitted; + drop domain wh_inet_port; + +commit; + +`), + 62: []byte(` +begin; + + drop table wh_time_of_day_dimension; + drop table wh_date_dimension; + +commit; + +`), + 65: []byte(` +begin; + + drop view whx_user_dimension_target; + drop view whx_user_dimension_source; + drop view whx_host_dimension_target; + drop view whx_host_dimension_source; + drop table wh_user_dimension; + drop table wh_host_dimension; + +commit; + +`), + 66: []byte(` +begin; + + drop function wh_upsert_user; + drop function wh_upsert_host; + +commit; + +`), + 68: []byte(` +begin; + + drop table wh_session_connection_accumulating_fact; + drop table wh_session_accumulating_fact; + +commit; + +`), + 69: []byte(` +begin; + + drop trigger wh_insert_session_connection_state on session_connection_state; + drop function wh_insert_session_connection_state; + + drop trigger wh_insert_session_state on session_state; + drop function wh_insert_session_state; + + drop trigger wh_update_session_connection on session_connection; + drop function wh_update_session_connection; + + drop trigger wh_insert_session_connection on session_connection; + drop function wh_insert_session_connection; + + drop trigger wh_insert_session on session; + drop function wh_insert_session; + + drop function wh_rollup_connections; + +commit; + +`), + }, + } } diff --git a/internal/db/schema/schema.go b/internal/db/schema/schema.go new file mode 100644 index 0000000000..bdde84e184 --- /dev/null +++ b/internal/db/schema/schema.go @@ -0,0 +1,40 @@ +package schema + +import ( + "context" + "database/sql" + + "github.com/hashicorp/boundary/internal/errors" +) + +// InitStore executes the migrations needed to initialize the store. 
It +// returns true if migrations actually ran; false if the database is already current. +func InitStore(ctx context.Context, dialect string, url string) (bool, error) { + const op = "schema.InitStore" + d, err := sql.Open(dialect, url) + if err != nil { + return false, errors.Wrap(err, op) + } + + sMan, err := NewManager(ctx, dialect, d) + if err != nil { + return false, errors.Wrap(err, op) + } + + st, err := sMan.CurrentState(ctx) + if err != nil { + return false, errors.Wrap(err, op) + } + if st.Dirty { + return false, errors.New(errors.MigrationIntegrity, op, "db marked dirty") + } + + if st.InitializationStarted && st.DatabaseSchemaVersion == st.BinarySchemaVersion { + return false, nil + } + + if err := sMan.RollForward(ctx); err != nil { + return false, errors.Wrap(err, op) + } + return true, nil +} diff --git a/internal/db/schema/schema_test.go b/internal/db/schema/schema_test.go new file mode 100644 index 0000000000..bafb81c682 --- /dev/null +++ b/internal/db/schema/schema_test.go @@ -0,0 +1,65 @@ +package schema + +import ( + "context" + "database/sql" + "testing" + + "github.com/hashicorp/boundary/internal/docker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInitStore(t *testing.T) { + dialect := "postgres" + ctx := context.Background() + + c, u, _, err := docker.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + + // Set the possible migration state to be only part of the full migration + oState := migrationStates[dialect] + nState := createPartialMigrationState(oState, 8) + migrationStates[dialect] = nState + + ran, err := InitStore(ctx, dialect, u) + assert.NoError(t, err) + assert.True(t, ran) + ran, err = InitStore(ctx, dialect, u) + assert.NoError(t, err) + assert.False(t, ran) + + // Reset the possible migration state to contain everything + migrationStates[dialect] = oState + + ran, err = InitStore(ctx, dialect, u) + assert.NoError(t, err) + assert.True(t, ran) + ran, err = InitStore(ctx, dialect, u) + assert.NoError(t, err) + assert.False(t, ran) +} + +func TestInitStore_Dirty(t *testing.T) { + dialect := "postgres" + ctx := context.Background() + + c, u, _, err := docker.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + + // Mark the db as dirty indicating a previously run failed migration + db, err := sql.Open(dialect, u) + require.NoError(t, err) + m, err := NewManager(ctx, dialect, db) + require.NoError(t, err) + require.NoError(t, m.driver.SetVersion(ctx, -1, true)) + + b, err := InitStore(ctx, dialect, u) + assert.Error(t, err) + assert.False(t, b) +} diff --git a/internal/db/schema/state.go b/internal/db/schema/state.go new file mode 100644 index 0000000000..1c0531a233 --- /dev/null +++ b/internal/db/schema/state.go @@ -0,0 +1,54 @@ +package schema + +const nilVersion = -1 + +// migrationState is meant to be populated by the generated migration code and +// contains the internal representation of a schema in the current binary. +type migrationState struct { + // devMigration is true if the database schema that would be applied by + // InitStore would be from files in the /dev directory which indicates it would + // not be safe to run in a non dev environment. + devMigration bool + + // binarySchemaVersion provides the database schema version supported by + // this binary.
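+	// For example, the generated postgres state in this change sets this
+	// to 69, matching its highest embedded up migration.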
+var migrationStates = make(map[string]migrationState)
+
+func getUpMigration(dialect string) map[int][]byte {
+	ms, ok := migrationStates[dialect]
+	if !ok {
+		return nil
+	}
+	return ms.upMigrations
+}
+
+func getDownMigration(dialect string) map[int][]byte {
+	ms, ok := migrationStates[dialect]
+	if !ok {
+		return nil
+	}
+	return ms.downMigrations
+}
+
+// DevMigration returns true iff the provided dialect has changes which are
+// still in development.
+func DevMigration(dialect string) bool {
+	ms, ok := migrationStates[dialect]
+	return ok && ms.devMigration
+}
+
+// BinarySchemaVersion provides the schema version that this binary supports
+// for the provided dialect. If the binary doesn't support the dialect,
+// nilVersion (-1) is returned.
+func BinarySchemaVersion(dialect string) int {
+	ms, ok := migrationStates[dialect]
+	if !ok {
+		return nilVersion
+	}
+	return ms.binarySchemaVersion
+}
diff --git a/internal/db/schema/state_test.go b/internal/db/schema/state_test.go
new file mode 100644
index 0000000000..2031e66513
--- /dev/null
+++ b/internal/db/schema/state_test.go
@@ -0,0 +1,23 @@
+package schema
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBinarySchemaVersion(t *testing.T) {
+	dialect := "test_binaryschemaversion"
+	migrationStates[dialect] = migrationState{binarySchemaVersion: 3}
+	assert.Equal(t, 3, BinarySchemaVersion(dialect))
+	assert.Equal(t, nilVersion, BinarySchemaVersion("unknown_dialect"))
+}
+
+func TestDevMigration(t *testing.T) {
+	dialect := "test_devmigrations"
+	migrationStates[dialect] = migrationState{devMigration: true}
+	assert.True(t, DevMigration(dialect))
+	migrationStates[dialect] = migrationState{devMigration: false}
+	assert.False(t, DevMigration(dialect))
+	assert.False(t, DevMigration("unknown_dialect"))
+}
diff --git a/internal/db/schema/statement.go b/internal/db/schema/statement.go
new file mode 100644
index 0000000000..ea243e3e53
--- /dev/null
+++ b/internal/db/schema/statement.go
@@ -0,0 +1,59 @@
+package schema
+
+import (
+	"fmt"
+	"sort"
+
+	"github.com/hashicorp/boundary/internal/errors"
+)
+
+// statementProvider provides the migration statements in order. Next must be
+// called prior to calling Version() or ReadUp(); until then the sentinel
+// values (-1 and nil) are returned.
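+//
+// Intended iteration pattern (a sketch; error handling elided):
+//
+//	sp, err := newStatementProvider(dialect, curVer)
+//	for sp.Next() {
+//		// execute sp.ReadUp() against the db, then record sp.Version() as applied
+//	}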
down %d", len(qp.up), len(qp.down))) + } + for k := range qp.up { + if _, ok := qp.down[k]; !ok { + return nil, errors.New(errors.MigrationIntegrity, op, fmt.Sprintf("Up key %d doesn't exist in down %v", k, qp.down)) + } + qp.versions = append(qp.versions, k) + } + sort.Ints(qp.versions) + + for len(qp.versions) > 0 && qp.versions[0] <= curVer { + qp.versions = qp.versions[1:] + } + + return &qp, nil +} + +func (q *statementProvider) Next() bool { + q.pos++ + return len(q.versions) > q.pos +} + +func (q *statementProvider) Version() int { + if q.pos < 0 || q.pos >= len(q.versions) { + return -1 + } + return q.versions[q.pos] +} + +// ReadUp reads the current up migration +func (q *statementProvider) ReadUp() []byte { + if q.pos < 0 || q.pos >= len(q.versions) { + return nil + } + return q.up[q.versions[q.pos]] +} diff --git a/internal/db/schema/statement_test.go b/internal/db/schema/statement_test.go new file mode 100644 index 0000000000..3c6d5f256c --- /dev/null +++ b/internal/db/schema/statement_test.go @@ -0,0 +1,87 @@ +package schema + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStatementProvider(t *testing.T) { + testDialect := "test" + migrationStates[testDialect] = migrationState{ + binarySchemaVersion: 5, + upMigrations: map[int][]byte{ + 1: []byte("one"), + 2: []byte("two"), + 3: []byte("three"), + }, + downMigrations: map[int][]byte{ + 1: []byte("down one"), + 2: []byte("down two"), + 3: []byte("down three"), + }, + } + + st, err := newStatementProvider(testDialect, 1) + assert.NoError(t, err) + assert.Equal(t, -1, st.Version()) + assert.Equal(t, []byte(nil), st.ReadUp()) + + assert.True(t, st.Next()) + assert.Equal(t, 2, st.Version()) + assert.Equal(t, []byte("two"), st.ReadUp()) + + assert.True(t, st.Next()) + assert.Equal(t, 3, st.Version()) + assert.Equal(t, []byte("three"), st.ReadUp()) + + assert.False(t, st.Next()) + assert.Equal(t, -1, st.Version()) + assert.Equal(t, []byte(nil), st.ReadUp()) + + assert.False(t, st.Next()) + assert.Equal(t, -1, st.Version()) + assert.Equal(t, []byte(nil), st.ReadUp()) + + st, err = newStatementProvider("unknown_dialect", nilVersion) + assert.NoError(t, err) + assert.False(t, st.Next()) +} + +func TestStatementProvider_error(t *testing.T) { + cases := []struct { + name string + in migrationState + }{ + { + name: "mismatchLength", + in: migrationState{ + binarySchemaVersion: 5, + upMigrations: map[int][]byte{ + 1: []byte("one"), + }, + downMigrations: map[int][]byte{}, + }, + }, + { + name: "mismatchVersions", + in: migrationState{ + binarySchemaVersion: 5, + upMigrations: map[int][]byte{ + 1: []byte("one"), + }, + downMigrations: map[int][]byte{ + 2: []byte("two"), + }, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + migrationStates[tc.name] = tc.in + defer delete(migrationStates, tc.name) + _, err := newStatementProvider(tc.name, -1) + assert.Error(t, err) + }) + } +} diff --git a/internal/db/testing.go b/internal/db/testing.go index eb1958daa7..b410ae22f7 100644 --- a/internal/db/testing.go +++ b/internal/db/testing.go @@ -10,6 +10,7 @@ import ( _ "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/oplog/store" wrapping "github.com/hashicorp/go-kms-wrapping" @@ -23,6 +24,7 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin var cleanup func() 
 	var cleanup func() error
 	var url string
 	var err error
+	ctx := context.Background()
 
 	opts := getTestOpts(opt...)
@@ -36,13 +38,14 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin
 			assert.NoError(t, cleanup(), "Got error cleaning up db in docker.")
 		})
 	default:
-		cleanup = func() error { return nil }
 		url = opts.withTestDatabaseUrl
 	}
-	_, err = InitStore(dialect, cleanup, url)
+
+	_, err = schema.InitStore(ctx, dialect, url)
 	if err != nil {
-		t.Fatal(err)
+		t.Fatalf("Couldn't init store on existing db: %v", err)
 	}
+
 	db, err := gorm.Open(dialect, url)
 	if err != nil {
 		t.Fatal(err)
diff --git a/internal/docker/supported.go b/internal/docker/supported.go
index 290231dc42..f1601b8928 100644
--- a/internal/docker/supported.go
+++ b/internal/docker/supported.go
@@ -30,7 +30,7 @@ func startDbInDockerSupported(dialect string) (cleanup func() error, retURL, con
 		resource, err = pool.Run("postgres", "12", []string{"POSTGRES_PASSWORD=password", "POSTGRES_DB=boundary"})
 		url = "postgres://postgres:password@localhost:%s?sslmode=disable"
 		if err == nil {
-			url = fmt.Sprintf("postgres://postgres:password@%s?sslmode=disable", resource.GetHostPort("5432/tcp"))
+			url = fmt.Sprintf("postgres://postgres:password@%s/boundary?sslmode=disable", resource.GetHostPort("5432/tcp"))
 		}
 	default:
 		panic(fmt.Sprintf("unknown dialect %q", dialect))
diff --git a/internal/errors/code.go b/internal/errors/code.go
index 2dfd23f72e..6d94b8cdb4 100644
--- a/internal/errors/code.go
+++ b/internal/errors/code.go
@@ -56,10 +56,14 @@ const (
 	NotNull              Code = 1001 // NotNull represents a value must not be null error
 	NotUnique            Code = 1002 // NotUnique represents a value must be unique error
 	NotSpecificIntegrity Code = 1003 // NotSpecificIntegrity represents an integrity error that has no specific domain error code
-	MissingTable         Code = 1004 // Missing table represents an undefined table error
+	MissingTable         Code = 1004 // MissingTable represents an undefined table error
 	RecordNotFound       Code = 1100 // RecordNotFound represents that a record/row was not found matching the criteria
 	MultipleRecords      Code = 1101 // MultipleRecords represents that multiple records/rows were found matching the criteria
 	ColumnNotFound       Code = 1102 // ColumnNotFound represent that a column was not found in the underlying db
 	MaxRetries           Code = 1103 // MaxRetries represent that a db Tx hit max retires allowed
 	Exception            Code = 1104 // Exception represent that an underlying db exception was raised
+
+	// Migration setup errors are codes 2000-3000
+	MigrationIntegrity Code = 2000 // MigrationIntegrity represents an error with the generated migration-related code
+	MigrationLock      Code = 2001 // MigrationLock represents an error related to locking of the DB
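+
+	// Callers raise these like any other code via the package constructor,
+	// e.g. a sketch mirroring the schema package's usage:
+	//
+	//	errors.New(errors.MigrationIntegrity, op, "db marked dirty")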
"migration integrity", + Kind: Integrity, + }, + MigrationLock: { + Message: "bad db lock", + Kind: Integrity, + }, } diff --git a/internal/oplog/testing.go b/internal/oplog/testing.go index 79b0c653c5..d294a709d4 100644 --- a/internal/oplog/testing.go +++ b/internal/oplog/testing.go @@ -1,11 +1,12 @@ package oplog import ( + "context" "crypto/rand" + "database/sql" "testing" - "github.com/golang-migrate/migrate/v4" - "github.com/hashicorp/boundary/internal/db/migrations" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/oplog/oplog_test" wrapping "github.com/hashicorp/go-kms-wrapping" @@ -79,16 +80,12 @@ func testWrapper(t *testing.T) wrapping.Wrapper { // testInitStore will execute the migrations needed to initialize the store for tests func testInitStore(t *testing.T, cleanup func() error, url string) { t.Helper() - // run migrations - source, err := migrations.NewMigrationSource("postgres") - require.NoError(t, err, "Error creating migration source") - m, err := migrate.NewWithSourceInstance("postgres", source, url) - require.NoError(t, err, "Error creating migrations") + ctx := context.Background() + dialect := "postgres" - if err := m.Up(); err != nil && err != migrate.ErrNoChange { - if err := cleanup(); err != nil { - t.Fatalf("error cleaning up after migration failure: %v", err) - } - require.NoError(t, err, "Error running migrations") - } + d, err := sql.Open(dialect, url) + require.NoError(t, err) + sm, err := schema.NewManager(ctx, dialect, d) + require.NoError(t, err) + require.NoError(t, sm.RollForward(ctx)) } diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 751eb8f35d..330fc5a23c 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -14,7 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/authtoken" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/config" - "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/servers" @@ -324,10 +324,10 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { } // Base server - tc.b = base.NewServer(nil) - tc.b.Command = &base.Command{ + tc.b = base.NewServer(&base.Command{ + Context: ctx, ShutdownCh: make(chan struct{}), - } + }) // Get dev config, or use a provided one var err error @@ -412,7 +412,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { if opts.DatabaseUrl != "" { tc.b.DatabaseUrl = opts.DatabaseUrl - if _, err := db.InitStore("postgres", nil, tc.b.DatabaseUrl); err != nil { + if _, err := schema.InitStore(ctx, "postgres", tc.b.DatabaseUrl); err != nil { t.Fatal(err) } if err := tc.b.ConnectToDatabase("postgres"); err != nil { @@ -423,7 +423,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { t.Fatal(err) } if !opts.DisableInitialLoginRoleCreation { - if _, err := tc.b.CreateInitialLoginRole(context.Background()); err != nil { + if _, err := tc.b.CreateInitialLoginRole(ctx); err != nil { t.Fatal(err) } if !opts.DisableAuthMethodCreation { @@ -453,7 +453,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { if opts.DisableAuthMethodCreation { createOpts = append(createOpts, 
 		createOpts = append(createOpts, base.WithSkipAuthMethodCreation())
 	}
-	if err := tc.b.CreateDevDatabase("postgres", createOpts...); err != nil {
+	if err := tc.b.CreateDevDatabase(ctx, "postgres", createOpts...); err != nil {
 		t.Fatal(err)
 	}
 }