diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 29bb9a276..4a1f6bf91 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -14,15 +14,10 @@ import ( nurl "net/url" "strconv" "strings" -) -import ( "github.com/go-sql-driver/mysql" - "github.com/hashicorp/go-multierror" -) - -import ( "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" ) func init() { diff --git a/database/sqlserver/README.md b/database/sqlserver/README.md index 86a0b79f7..c4ef5a3a3 100644 --- a/database/sqlserver/README.md +++ b/database/sqlserver/README.md @@ -16,6 +16,7 @@ | `dial+timeout` | | in seconds (default is 15), set to 0 for no timeout. | | `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. | | `app+name` || The application name (default is go-mssqldb). | +| `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported. See https://github.com/denisenkom/go-mssqldb for full parameter list. diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 0f8252f3e..024001871 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -4,11 +4,15 @@ import ( "context" "database/sql" "fmt" - "go.uber.org/atomic" "io" "io/ioutil" nurl "net/url" + "strconv" + "strings" + + "go.uber.org/atomic" + "github.com/Azure/go-autorest/autorest/adal" mssql "github.com/denisenkom/go-mssqldb" // mssql support "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" @@ -23,10 +27,11 @@ func init() { var DefaultMigrationsTable = "schema_migrations" var ( - ErrNilConfig = fmt.Errorf("no config") - ErrNoDatabaseName = fmt.Errorf("no database name") - ErrNoSchema = fmt.Errorf("no schema") - ErrDatabaseDirty = fmt.Errorf("database is dirty") + ErrNilConfig = fmt.Errorf("no config") + ErrNoDatabaseName = fmt.Errorf("no database name") + ErrNoSchema = fmt.Errorf("no schema") + ErrDatabaseDirty = fmt.Errorf("database is dirty") + ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") ) var lockErrorMap = map[mssql.ReturnStatus]string{ @@ -117,16 +122,49 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return ss, nil } -// Open a connection to the database +// Open a connection to the database. func (ss *SQLServer) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { return nil, err } - db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String()) - if err != nil { - return nil, err + useMsiParam := purl.Query().Get("useMsi") + useMsi := false + if len(useMsiParam) > 0 { + useMsi, err = strconv.ParseBool(useMsiParam) + if err != nil { + return nil, err + } + } + + if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { + return nil, ErrMultipleAuthOptionsPassed + } + + filteredURL := migrate.FilterCustomQuery(purl).String() + + var db *sql.DB + if useMsi { + resource := getAADResourceFromServerUri(purl) + tokenProvider, err := getMSITokenProvider(resource) + if err != nil { + return nil, err + } + + connector, err := mssql.NewAccessTokenConnector( + filteredURL, tokenProvider) + if err != nil { + return nil, err + } + + db = sql.OpenDB(connector) + + } else { + db, err = sql.Open("sqlserver", filteredURL) + if err != nil { + return nil, err + } } migrationsTable := purl.Query().Get("x-migrations-table") @@ -339,3 +377,26 @@ func (ss *SQLServer) ensureVersionTable() (err error) { return nil } + +func getMSITokenProvider(resource string) (func() (string, error), error) { + msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) + if err != nil { + return nil, err + } + + return func() (string, error) { + err := msi.EnsureFresh() + if err != nil { + return "", err + } + token := msi.OAuthToken() + return token, nil + }, nil +} + +// The sql server resource can change across clouds so get it +// dynamically based on the server uri. +// ex. .database.windows.net -> https://database.windows.net +func getAADResourceFromServerUri(purl *nurl.URL) string { + return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) +} diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index 7bf393759..ad0dc79ed 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -38,6 +38,14 @@ func msConnectionString(host, port string) string { return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port) } +func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string { + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master&useMsi=%t", saPassword, host, port, useMsi) +} + +func msConnectionStringMsi(host, port string, useMsi bool) string { + return fmt.Sprintf("sqlserver://sa@%v:%v?database=master&useMsi=%t", host, port, useMsi) +} + func isReady(ctx context.Context, c dktest.ContainerInfo) bool { ip, port, err := c.Port(defaultPort) if err != nil { @@ -218,3 +226,66 @@ func TestLockWorks(t *testing.T) { } }) } + +func TestMsiTrue(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionStringMsi(ip, port, true) + p := &SQLServer{} + _, err = p.Open(addr) + if err == nil { + t.Fatal("MSI should fail when not running in an Azure context.") + } + }) +} + +func TestOpenWithPasswordAndMSI(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionStringMsiWithPassword(ip, port, true) + p := &SQLServer{} + _, err = p.Open(addr) + if err == nil { + t.Fatal("Open should fail when both password and useMsi=true are passed.") + } + + addr = msConnectionStringMsiWithPassword(ip, port, false) + p = &SQLServer{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + dt.Test(t, d, []byte("SELECT 1")) + }) +} + +func TestMsiFalse(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionStringMsi(ip, port, false) + p := &SQLServer{} + _, err = p.Open(addr) + if err == nil { + t.Fatal("Open should fail since no password was passed and useMsi is false.") + } + }) +} diff --git a/go.mod b/go.mod index 4e0a1ae89..df6b3cca8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/golang-migrate/migrate/v4 require ( cloud.google.com/go/spanner v1.18.0 cloud.google.com/go/storage v1.10.0 + github.com/Azure/go-autorest/autorest/adal v0.9.14 github.com/ClickHouse/clickhouse-go v1.4.3 github.com/apache/arrow/go/arrow v0.0.0-20210521153258-78c88a9f517b // indirect github.com/aws/aws-sdk-go v1.17.7 @@ -12,7 +13,7 @@ require ( github.com/cenkalti/backoff/v4 v4.0.2 github.com/cockroachdb/cockroach-go v0.0.0-20190925194419-606b3d062051 github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369 // indirect - github.com/denisenkom/go-mssqldb v0.0.0-20200620013148-b91950f658ec + github.com/denisenkom/go-mssqldb v0.10.0 github.com/dhui/dktest v0.3.4 github.com/docker/docker v17.12.0-ce-rc1.0.20210128214336-420b1d36250f+incompatible github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect diff --git a/go.sum b/go.sum index 3517c1da9..141223772 100644 --- a/go.sum +++ b/go.sum @@ -48,11 +48,15 @@ github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7O github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= -github.com/Azure/go-autorest/autorest/adal v0.9.2 h1:Aze/GQeAN1RRbGmnUJvUj+tFGBzFdIg3293/A9rbxC4= github.com/Azure/go-autorest/autorest/adal v0.9.2/go.mod h1:/3SMAM86bP6wC9Ev35peQDUeqFZBMH07vvUOmg4z/fE= +github.com/Azure/go-autorest/autorest/adal v0.9.14 h1:G8hexQdV5D4khOXrWG2YuLCFKhWYmWD8bHYaXN5ophk= +github.com/Azure/go-autorest/autorest/adal v0.9.14/go.mod h1:W/MM4U6nLxnIskrw4UwWzlHfGjwUS50aOsc/I3yuU8M= github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= +github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= +github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= +github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -136,9 +140,8 @@ github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369/go.mod h1:e6NPNENfs github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20200620013148-b91950f658ec h1:NfhRXXFDPxcF5Cwo06DzeIaE7uuJtAUhsDwH3LNsjos= -github.com/denisenkom/go-mssqldb v0.0.0-20200620013148-b91950f658ec/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dhui/dktest v0.3.4 h1:VbUEcaSP+U2/yUr9d2JhSThXYEnDlGabRSHe2rIE46E= github.com/dhui/dktest v0.3.4/go.mod h1:4m4n6lmXlmVfESth7mzdcv8nBI5mOb5UROPqjM02csU= @@ -506,6 +509,7 @@ golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=