From dc416ce77197d63c6be2e0ce22396b17ffe47cc8 Mon Sep 17 00:00:00 2001 From: FPiety0521 Date: Mon, 14 Oct 2019 03:13:57 -0700 Subject: [PATCH] Only populate db config values if it's not specified See also: https://github.com/golang-migrate/migrate/issues/262 --- database/cockroachdb/cockroachdb.go | 20 ++++++++------- database/mysql/mysql.go | 20 ++++++++------- database/redshift/redshift.go | 20 ++++++++------- database/sqlserver/sqlserver.go | 38 ++++++++++++++++------------- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 8d6c470..ab547e2 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -61,17 +61,19 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - query := `SELECT current_database()` - var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if config.DatabaseName == "" { + query := `SELECT current_database()` + var databaseName string + if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } - if len(databaseName) == 0 { - return nil, ErrNoDatabaseName - } + if len(databaseName) == 0 { + return nil, ErrNoDatabaseName + } - config.DatabaseName = databaseName + config.DatabaseName = databaseName + } if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 27accc3..2c0e0ae 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -63,17 +63,19 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - query := `SELECT DATABASE()` - var databaseName sql.NullString - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if config.DatabaseName == "" { + query := `SELECT DATABASE()` + var databaseName sql.NullString + if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } - if len(databaseName.String) == 0 { - return nil, ErrNoDatabaseName - } + if len(databaseName.String) == 0 { + return nil, ErrNoDatabaseName + } - config.DatabaseName = databaseName.String + config.DatabaseName = databaseName.String + } if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 3d018bd..f8f1644 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -53,17 +53,19 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - query := `SELECT CURRENT_DATABASE()` - var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if config.DatabaseName == "" { + query := `SELECT CURRENT_DATABASE()` + var databaseName string + if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } - if len(databaseName) == 0 { - return nil, ErrNoDatabaseName - } + if len(databaseName) == 0 { + return nil, ErrNoDatabaseName + } - config.DatabaseName = databaseName + config.DatabaseName = databaseName + } if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index ef99291..3d0514a 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -65,30 +65,34 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, err } - query := `SELECT DB_NAME()` - var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if config.DatabaseName == "" { + query := `SELECT DB_NAME()` + var databaseName string + if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(databaseName) == 0 { + return nil, ErrNoDatabaseName + } - if len(databaseName) == 0 { - return nil, ErrNoDatabaseName + config.DatabaseName = databaseName } - config.DatabaseName = databaseName + if config.SchemaName == "" { + query := `SELECT SCHEMA_NAME()` + var schemaName string + if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } - query = `SELECT SCHEMA_NAME()` - var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { - return nil, &database.Error{OrigErr: err, Query: []byte(query)} - } + if len(schemaName) == 0 { + return nil, ErrNoSchema + } - if len(schemaName) == 0 { - return nil, ErrNoSchema + config.SchemaName = schemaName } - config.SchemaName = schemaName - if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable }