From d0f8e55ccb390b9c1f803b3a6c4f2e7874f40337 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Fri, 14 Oct 2022 12:45:34 -0600 Subject: [PATCH] feat: add support for impersonation (#1460) Fixes #417 --- .envrc.example | 4 ++ .github/workflows/tests.yaml | 2 + README.md | 1 + cmd/root.go | 28 +++++++- cmd/root_test.go | 13 ++++ internal/proxy/proxy.go | 90 ++++++++++++++++++++---- tests/common_test.go | 9 +++ tests/connection_test.go | 5 +- tests/mysql_test.go | 132 +++++++++++++++++++++++------------ tests/postgres_test.go | 120 ++++++++++++++++++++----------- tests/sqlserver_test.go | 107 +++++++++++++++++++++------- 11 files changed, 380 insertions(+), 131 deletions(-) diff --git a/.envrc.example b/.envrc.example index 8bf5e9e3d..6c31e2d55 100644 --- a/.envrc.example +++ b/.envrc.example @@ -17,3 +17,7 @@ export SQLSERVER_PASS="sqlserver-password" export SQLSERVER_DB="sqlserver-db-name" export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json + +# Requires the impersonating IAM principal to have +# roles/iam.serviceAccountTokenCreator +export IMPERSONATED_USER="some-user-with-db-access@example.com" diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f86f22a7d..661bacf7e 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -108,6 +108,7 @@ jobs: SQLSERVER_USER:${{ secrets.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ secrets.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS SQLSERVER_DB:${{ secrets.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_DB + IMPERSONATED_USER:${{ secrets.GOOGLE_CLOUD_PROJECT }}/IMPERSONATED_USER - name: Enable fuse config (Linux) if: runner.os == 'Linux' @@ -130,6 +131,7 @@ jobs: SQLSERVER_USER: '${{ steps.secrets.outputs.SQLSERVER_USER }}' SQLSERVER_PASS: '${{ steps.secrets.outputs.SQLSERVER_PASS }}' SQLSERVER_DB: '${{ steps.secrets.outputs.SQLSERVER_DB }}' + IMPERSONATED_USER: '${{ steps.secrets.outputs.IMPERSONATED_USER }}' TMPDIR: "/tmp" TMP: '${{ runner.temp }}' # specifying bash shell ensures a failure in a piped process isn't lost by using `set -eo pipefail` diff --git a/README.md b/README.md index 4f8dc4ee3..a55eee456 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ The Cloud SQL Auth proxy has support for: - [Automatic IAM Authentication][iam-auth] (Postgres only) - Metrics ([Cloud Monitoring][], [Cloud Trace][], and [Prometheus][]) - [HTTP Healthchecks][health-check-example] +- Service account impersonation - Separate Dialer functionality released as the [Cloud SQL Go Connector][go connector] - Fully POSIX-compliant flags diff --git a/cmd/root.go b/cmd/root.go index 20f707123..06cc52997 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -82,6 +82,13 @@ type Command struct { healthCheck bool httpAddress string httpPort string + + // impersonationChain is a comma separated list of one or more service + // accounts. The last entry in the chain is the impersonation target. Any + // additional service accounts before the target are delegates. The + // roles/iam.serviceAccountTokenCreator must be configured for each account + // that will be impersonated. + impersonationChain string } // Option is a function that configures a Command. @@ -253,6 +260,9 @@ https://cloud.google.com/storage/docs/requester-pays`) cmd.PersistentFlags().StringVar(&c.conf.FUSETempDir, "fuse-tmp-dir", filepath.Join(os.TempDir(), "csql-tmp"), "Temp dir for Unix sockets created with FUSE") + cmd.PersistentFlags().StringVar(&c.impersonationChain, "impersonate-service-account", "", + `Comma separated list of service accounts to impersonate. Last value +is the target account.`) // Global and per instance flags cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1", @@ -338,7 +348,10 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { if userHasSet("sqladmin-api-endpoint") && conf.APIEndpointURL != "" { _, err := url.Parse(conf.APIEndpointURL) if err != nil { - return newBadCommandError(fmt.Sprintf("the value provided for --sqladmin-api-endpoint is not a valid URL, %v", conf.APIEndpointURL)) + return newBadCommandError(fmt.Sprintf( + "the value provided for --sqladmin-api-endpoint is not a valid URL, %v", + conf.APIEndpointURL, + )) } // add a trailing '/' if omitted @@ -347,6 +360,19 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { } } + if cmd.impersonationChain != "" { + accts := strings.Split(cmd.impersonationChain, ",") + conf.ImpersonateTarget = accts[0] + // Assign delegates if the chain is more than one account. Delegation + // goes from last back towards target, e.g., With sa1,sa2,sa3, sa3 + // delegates to sa2, which impersonates the target sa1. + if l := len(accts); l > 1 { + for i := l - 1; i > 0; i-- { + conf.ImpersonateDelegates = append(conf.ImpersonateDelegates, accts[i]) + } + } + } + var ics []proxy.InstanceConnConfig for _, a := range args { // Assume no query params initially diff --git a/cmd/root_test.go b/cmd/root_test.go index 7c751c3f5..66879144a 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -279,6 +279,19 @@ func TestNewCommandArguments(t *testing.T) { QuotaProject: "proj", }), }, + { + desc: "", + args: []string{"--impersonate-service-account", + "sv1@developer.gserviceaccount.com,sv2@developer.gserviceaccount.com,sv3@developer.gserviceaccount.com", + "proj:region:inst"}, + want: withDefaults(&proxy.Config{ + ImpersonateTarget: "sv1@developer.gserviceaccount.com", + ImpersonateDelegates: []string{ + "sv3@developer.gserviceaccount.com", + "sv2@developer.gserviceaccount.com", + }, + }), + }, } for _, tc := range tcs { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index c5cce383c..b5b2d2f65 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -30,6 +30,9 @@ import ( "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/cloudsql" "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/gcloud" "golang.org/x/oauth2" + "google.golang.org/api/impersonate" + "google.golang.org/api/option" + "google.golang.org/api/sqladmin/v1" ) var ( @@ -160,6 +163,15 @@ type Config struct { // API request quotas. QuotaProject string + // ImpersonateTarget is the service account to impersonate. The IAM + // principal doing the impersonation must have the + // roles/iam.serviceAccountTokenCreator role. + ImpersonateTarget string + // ImpersonateDelegates are the intermediate service accounts through which + // the impersonation is achieved. Each delegate must have the + // roles/iam.serviceAccountTokenCreator role. + ImpersonateDelegates []string + // StructuredLogs sets all output to use JSON in the LogEntry format. // See https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry StructuredLogs bool @@ -187,38 +199,86 @@ func (c *Config) DialOptions(i InstanceConnConfig) []cloudsqlconn.DialOption { return opts } -// DialerOptions builds appropriate list of options from the Config -// values for use by cloudsqlconn.NewClient() -func (c *Config) DialerOptions(l cloudsql.Logger) ([]cloudsqlconn.Option, error) { - opts := []cloudsqlconn.Option{ - cloudsqlconn.WithUserAgent(c.UserAgent), +func (c *Config) credentialsOpt(l cloudsql.Logger) (cloudsqlconn.Option, error) { + // If service account impersonation is configured, set up an impersonated + // credentials token source. + if c.ImpersonateTarget != "" { + var iopts []option.ClientOption + switch { + case c.Token != "": + l.Infof("Impersonating service account with OAuth2 token") + iopts = append(iopts, option.WithTokenSource( + oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), + )) + case c.CredentialsFile != "": + l.Infof("Impersonating service account with the credentials file at %q", c.CredentialsFile) + iopts = append(iopts, option.WithCredentialsFile(c.CredentialsFile)) + case c.CredentialsJSON != "": + l.Infof("Impersonating service account with JSON credentials environment variable") + iopts = append(iopts, option.WithCredentialsJSON([]byte(c.CredentialsJSON))) + case c.GcloudAuth: + l.Infof("Impersonating service account with gcloud user credentials") + ts, err := gcloud.TokenSource() + if err != nil { + return nil, err + } + iopts = append(iopts, option.WithTokenSource(ts)) + default: + l.Infof("Impersonating service account with Application Default Credentials") + } + ts, err := impersonate.CredentialsTokenSource( + context.Background(), + impersonate.CredentialsConfig{ + TargetPrincipal: c.ImpersonateTarget, + Delegates: c.ImpersonateDelegates, + Scopes: []string{sqladmin.SqlserviceAdminScope}, + }, + iopts..., + ) + if err != nil { + return nil, err + } + return cloudsqlconn.WithTokenSource(ts), nil } + + // Otherwise, configure credentials as usual. switch { case c.Token != "": - l.Infof("Authorizing with the -token flag") - opts = append(opts, cloudsqlconn.WithTokenSource( + l.Infof("Authorizing with OAuth2 token") + return cloudsqlconn.WithTokenSource( oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), - )) + ), nil case c.CredentialsFile != "": l.Infof("Authorizing with the credentials file at %q", c.CredentialsFile) - opts = append(opts, cloudsqlconn.WithCredentialsFile( - c.CredentialsFile, - )) + return cloudsqlconn.WithCredentialsFile(c.CredentialsFile), nil case c.CredentialsJSON != "": l.Infof("Authorizing with JSON credentials environment variable") - opts = append(opts, cloudsqlconn.WithCredentialsJSON( - []byte(c.CredentialsJSON), - )) + return cloudsqlconn.WithCredentialsJSON([]byte(c.CredentialsJSON)), nil case c.GcloudAuth: l.Infof("Authorizing with gcloud user credentials") ts, err := gcloud.TokenSource() if err != nil { return nil, err } - opts = append(opts, cloudsqlconn.WithTokenSource(ts)) + return cloudsqlconn.WithTokenSource(ts), nil default: l.Infof("Authorizing with Application Default Credentials") + // Return no-op options to avoid having to handle nil in caller code + return cloudsqlconn.WithOptions(), nil + } +} + +// DialerOptions builds appropriate list of options from the Config +// values for use by cloudsqlconn.NewClient() +func (c *Config) DialerOptions(l cloudsql.Logger) ([]cloudsqlconn.Option, error) { + opts := []cloudsqlconn.Option{ + cloudsqlconn.WithUserAgent(c.UserAgent), + } + co, err := c.credentialsOpt(l) + if err != nil { + return nil, err } + opts = append(opts, co) if c.APIEndpointURL != "" { opts = append(opts, cloudsqlconn.WithAdminAPIEndpoint(c.APIEndpointURL)) diff --git a/tests/common_test.go b/tests/common_test.go index 063d344cf..71c5a2c60 100644 --- a/tests/common_test.go +++ b/tests/common_test.go @@ -25,6 +25,7 @@ import ( "bytes" "context" "errors" + "flag" "fmt" "io" "os" @@ -34,6 +35,14 @@ import ( "github.com/GoogleCloudPlatform/cloud-sql-proxy/v2/internal/log" ) +var ( + impersonatedUser = flag.String( + "impersonated_user", + os.Getenv("IMPERSONATED_USER"), + "Name of the service account that supports impersonation (impersonator must have roles/iam.serviceAccountTokenCreator)", + ) +) + // ProxyExec represents an execution of the Cloud SQL proxy. type ProxyExec struct { Out io.ReadCloser diff --git a/tests/connection_test.go b/tests/connection_test.go index f0aafa4e4..e216b033f 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -26,7 +26,6 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" - "google.golang.org/api/sqladmin/v1" ) const connTestTimeout = time.Minute @@ -35,7 +34,9 @@ const connTestTimeout = time.Minute // and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function // that restores the original setup. func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) { - ts, err := google.DefaultTokenSource(context.Background(), sqladmin.SqlserviceAdminScope) + ts, err := google.DefaultTokenSource(context.Background(), + "https://www.googleapis.com/auth/cloud-platform", + ) if err != nil { t.Errorf("failed to resolve token source: %v", err) } diff --git a/tests/mysql_test.go b/tests/mysql_test.go index 9a290b9e2..d3ec74cbf 100644 --- a/tests/mysql_test.go +++ b/tests/mysql_test.go @@ -44,11 +44,7 @@ func requireMySQLVars(t *testing.T) { } } -func TestMySQLTCP(t *testing.T) { - if testing.Short() { - t.Skip("skipping MySQL integration tests") - } - requireMySQLVars(t) +func mysqlDSN() string { cfg := mysql.Config{ User: *mysqlUser, Passwd: *mysqlPass, @@ -57,7 +53,15 @@ func TestMySQLTCP(t *testing.T) { Addr: "127.0.0.1:3306", Net: "tcp", } - proxyConnTest(t, []string{*mysqlConnName}, "mysql", cfg.FormatDSN()) + return cfg.FormatDSN() +} + +func TestMySQLTCP(t *testing.T) { + if testing.Short() { + t.Skip("skipping MySQL integration tests") + } + requireMySQLVars(t) + proxyConnTest(t, []string{*mysqlConnName}, "mysql", mysqlDSN()) } func TestMySQLUnix(t *testing.T) { @@ -82,66 +86,102 @@ func TestMySQLUnix(t *testing.T) { []string{"--unix-socket", tmpDir, *mysqlConnName}, "mysql", cfg.FormatDSN()) } -func TestMySQLAuthWithToken(t *testing.T) { +func TestMySQLImpersonation(t *testing.T) { if testing.Short() { t.Skip("skipping MySQL integration tests") } requireMySQLVars(t) - tok, _, cleanup := removeAuthEnvVar(t) - defer cleanup() - cfg := mysql.Config{ - User: *mysqlUser, - Passwd: *mysqlPass, - DBName: *mysqlDB, - AllowNativePasswords: true, - Addr: "127.0.0.1:3306", - Net: "tcp", - } - proxyConnTest(t, - []string{"--token", tok.AccessToken, *mysqlConnName}, - "mysql", cfg.FormatDSN()) + proxyConnTest(t, []string{ + "--impersonate-service-account", *impersonatedUser, + *mysqlConnName}, + "mysql", mysqlDSN()) } -func TestMySQLAuthWithCredentialsFile(t *testing.T) { +func TestMySQLAuthentication(t *testing.T) { if testing.Short() { t.Skip("skipping MySQL integration tests") } requireMySQLVars(t) - _, path, cleanup := removeAuthEnvVar(t) + + creds := keyfile(t) + tok, path, cleanup := removeAuthEnvVar(t) defer cleanup() - cfg := mysql.Config{ - User: *mysqlUser, - Passwd: *mysqlPass, - DBName: *mysqlDB, - AllowNativePasswords: true, - Addr: "127.0.0.1:3306", - Net: "tcp", + tcs := []struct { + desc string + args []string + }{ + { + desc: "with token", + args: []string{"--token", tok.AccessToken, *mysqlConnName}, + }, + { + desc: "with token and impersonation", + args: []string{ + "--token", tok.AccessToken, + "--impersonate-service-account", *impersonatedUser, + *mysqlConnName}, + }, + { + desc: "with credentials file", + args: []string{"--credentials-file", path, *mysqlConnName}, + }, + { + desc: "with credentials file and impersonation", + args: []string{ + "--credentials-file", path, + "--impersonate-service-account", *impersonatedUser, + *mysqlConnName}, + }, + { + desc: "with credentials JSON", + args: []string{"--json-credentials", string(creds), *mysqlConnName}, + }, + { + desc: "with credentials JSON and impersonation", + args: []string{ + "--json-credentials", string(creds), + "--impersonate-service-account", *impersonatedUser, + *mysqlConnName}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "mysql", mysqlDSN()) + }) } - proxyConnTest(t, - []string{"--credentials-file", path, *mysqlConnName}, - "mysql", cfg.FormatDSN()) } -func TestMySQLAuthWithCredentialsJSON(t *testing.T) { +func TestMySQLGcloudAuth(t *testing.T) { if testing.Short() { t.Skip("skipping MySQL integration tests") } requireMySQLVars(t) - creds := keyfile(t) - _, _, cleanup := removeAuthEnvVar(t) - defer cleanup() - cfg := mysql.Config{ - User: *mysqlUser, - Passwd: *mysqlPass, - DBName: *mysqlDB, - AllowNativePasswords: true, - Addr: "127.0.0.1:3306", - Net: "tcp", + tcs := []struct { + desc string + args []string + }{ + { + desc: "gcloud user authentication", + args: []string{"--gcloud-auth", *mysqlConnName}, + }, + { + desc: "gcloud user authentication with impersonation", + args: []string{ + "--gcloud-auth", + "--impersonate-service-account", *impersonatedUser, + *mysqlConnName}, + }, } - proxyConnTest(t, - []string{"--json-credentials", creds, *mysqlConnName}, - "mysql", cfg.FormatDSN()) + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "mysql", mysqlDSN()) + }) + } +} + +func TestMySQLHealthCheck(t *testing.T) { + testHealthCheck(t, *mysqlConnName) } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 9ba9692c9..fcab1d79c 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -31,8 +31,7 @@ var ( postgresUser = flag.String("postgres_user", os.Getenv("POSTGRES_USER"), "Name of database user.") postgresPass = flag.String("postgres_pass", os.Getenv("POSTGRES_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") postgresDB = flag.String("postgres_db", os.Getenv("POSTGRES_DB"), "Name of the database to connect to.") - - postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.") + postgresIAMUser = flag.String("postgres_user_iam", os.Getenv("POSTGRES_USER_IAM"), "Name of database user configured with IAM DB Authentication.") ) func requirePostgresVars(t *testing.T) { @@ -48,15 +47,18 @@ func requirePostgresVars(t *testing.T) { } } +func postgresDSN() string { + return fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", + *postgresUser, *postgresPass, *postgresDB) +} + func TestPostgresTCP(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") } requirePostgresVars(t) - dsn := fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", - *postgresUser, *postgresPass, *postgresDB) - proxyConnTest(t, []string{*postgresConnName}, "pgx", dsn) + proxyConnTest(t, []string{*postgresConnName}, "pgx", postgresDSN()) } func TestPostgresUnix(t *testing.T) { @@ -89,63 +91,101 @@ func createTempDir(t *testing.T) (string, func()) { } } -func TestPostgresAuthWithToken(t *testing.T) { +func TestPostgresImpersonation(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") } requirePostgresVars(t) - tok, _, cleanup := removeAuthEnvVar(t) - defer cleanup() - dsn := fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", - *postgresUser, *postgresPass, *postgresDB) - proxyConnTest(t, - []string{"--token", tok.AccessToken, *postgresConnName}, - "pgx", dsn) + proxyConnTest(t, []string{ + "--impersonate-service-account", *impersonatedUser, + *postgresConnName}, + "pgx", postgresDSN()) } -func TestPostgresAuthWithCredentialsFile(t *testing.T) { +func TestPostgresAuthentication(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") } requirePostgresVars(t) - _, path, cleanup := removeAuthEnvVar(t) - defer cleanup() - - dsn := fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", - *postgresUser, *postgresPass, *postgresDB) - proxyConnTest(t, - []string{"--credentials-file", path, *postgresConnName}, - "pgx", dsn) -} -func TestPostgresAuthWithCredentialsJSON(t *testing.T) { - if testing.Short() { - t.Skip("skipping Postgres integration tests") - } - requirePostgresVars(t) creds := keyfile(t) - _, _, cleanup := removeAuthEnvVar(t) + tok, path, cleanup := removeAuthEnvVar(t) defer cleanup() - dsn := fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", - *postgresUser, *postgresPass, *postgresDB) - proxyConnTest(t, - []string{"--json-credentials", string(creds), *postgresConnName}, - "pgx", dsn) + tcs := []struct { + desc string + args []string + }{ + { + desc: "with token", + args: []string{"--token", tok.AccessToken, *postgresConnName}, + }, + { + desc: "with token and impersonation", + args: []string{ + "--token", tok.AccessToken, + "--impersonate-service-account", *impersonatedUser, + *postgresConnName}, + }, + { + desc: "with credentials file", + args: []string{"--credentials-file", path, *postgresConnName}, + }, + { + desc: "with credentials file and impersonation", + args: []string{ + "--credentials-file", path, + "--impersonate-service-account", *impersonatedUser, + *postgresConnName}, + }, + { + desc: "with credentials JSON", + args: []string{"--json-credentials", string(creds), *postgresConnName}, + }, + { + desc: "with credentials JSON and impersonation", + args: []string{ + "--json-credentials", string(creds), + "--impersonate-service-account", *impersonatedUser, + *postgresConnName}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "pgx", postgresDSN()) + }) + } } -func TestAuthWithGcloudAuth(t *testing.T) { +func TestPostgresGcloudAuth(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests") } requirePostgresVars(t) - dsn := fmt.Sprintf("host=localhost user=%s password=%s database=%s sslmode=disable", - *postgresUser, *postgresPass, *postgresDB) - proxyConnTest(t, - []string{"--gcloud-auth", *postgresConnName}, - "pgx", dsn) + tcs := []struct { + desc string + args []string + }{ + { + desc: "gcloud user authentication", + args: []string{"--gcloud-auth", *postgresConnName}, + }, + { + desc: "gcloud user authentication with impersonation", + args: []string{ + "--gcloud-auth", + "--impersonate-service-account", *impersonatedUser, + *postgresConnName}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "pgx", postgresDSN()) + }) + } + } func TestPostgresIAMDBAuthn(t *testing.T) { diff --git a/tests/sqlserver_test.go b/tests/sqlserver_test.go index 36d2c16dc..6e82f7507 100644 --- a/tests/sqlserver_test.go +++ b/tests/sqlserver_test.go @@ -44,61 +44,114 @@ func requireSQLServerVars(t *testing.T) { } } +func sqlserverDSN() string { + return fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", + *sqlserverUser, *sqlserverPass, *sqlserverDB) +} + func TestSQLServerTCP(t *testing.T) { if testing.Short() { t.Skip("skipping SQL Server integration tests") } requireSQLServerVars(t) - dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", - *sqlserverUser, *sqlserverPass, *sqlserverDB) - proxyConnTest(t, []string{*sqlserverConnName}, "sqlserver", dsn) + proxyConnTest(t, []string{*sqlserverConnName}, "sqlserver", sqlserverDSN()) } -func TestSQLServerAuthWithToken(t *testing.T) { +func TestSQLServerImpersonation(t *testing.T) { if testing.Short() { t.Skip("skipping SQL Server integration tests") } requireSQLServerVars(t) - tok, _, cleanup := removeAuthEnvVar(t) - defer cleanup() - dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", - *sqlserverUser, *sqlserverPass, *sqlserverDB) - proxyConnTest(t, - []string{"--token", tok.AccessToken, *sqlserverConnName}, - "sqlserver", dsn) + proxyConnTest(t, []string{ + "--impersonate-service-account", *impersonatedUser, + *sqlserverConnName}, + "sqlserver", sqlserverDSN()) } -func TestSQLServerAuthWithCredentialsFile(t *testing.T) { +func TestSQLServerAuthentication(t *testing.T) { if testing.Short() { t.Skip("skipping SQL Server integration tests") } requireSQLServerVars(t) - _, path, cleanup := removeAuthEnvVar(t) + + creds := keyfile(t) + tok, path, cleanup := removeAuthEnvVar(t) defer cleanup() - dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", - *sqlserverUser, *sqlserverPass, *sqlserverDB) - proxyConnTest(t, - []string{"--credentials-file", path, *sqlserverConnName}, - "sqlserver", dsn) + tcs := []struct { + desc string + args []string + }{ + { + desc: "with token", + args: []string{"--token", tok.AccessToken, *sqlserverConnName}, + }, + { + desc: "with token and impersonation", + args: []string{ + "--token", tok.AccessToken, + "--impersonate-service-account", *impersonatedUser, + *sqlserverConnName}, + }, + { + desc: "with credentials file", + args: []string{"--credentials-file", path, *sqlserverConnName}, + }, + { + desc: "with credentials file and impersonation", + args: []string{ + "--credentials-file", path, + "--impersonate-service-account", *impersonatedUser, + *sqlserverConnName}, + }, + { + desc: "with credentials JSON", + args: []string{"--json-credentials", string(creds), *sqlserverConnName}, + }, + { + desc: "with credentials JSON and impersonation", + args: []string{ + "--json-credentials", string(creds), + "--impersonate-service-account", *impersonatedUser, + *sqlserverConnName}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "sqlserver", sqlserverDSN()) + }) + } } -func TestSQLServerAuthWithCredentialsJSON(t *testing.T) { +func TestSQLServerGcloudAuth(t *testing.T) { if testing.Short() { t.Skip("skipping SQL Server integration tests") } requireSQLServerVars(t) - creds := keyfile(t) - _, _, cleanup := removeAuthEnvVar(t) - defer cleanup() - dsn := fmt.Sprintf("sqlserver://%s:%s@127.0.0.1?database=%s", - *sqlserverUser, *sqlserverPass, *sqlserverDB) - proxyConnTest(t, - []string{"--json-credentials", creds, *sqlserverConnName}, - "sqlserver", dsn) + tcs := []struct { + desc string + args []string + }{ + { + desc: "gcloud user authentication", + args: []string{"--gcloud-auth", *sqlserverConnName}, + }, + { + desc: "gcloud user authentication with impersonation", + args: []string{ + "--gcloud-auth", + "--impersonate-service-account", *impersonatedUser, + *sqlserverConnName}, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + proxyConnTest(t, tc.args, "sqlserver", sqlserverDSN()) + }) + } } func TestSQLServerHealthCheck(t *testing.T) {