diff --git a/.envrc.example b/.envrc.example index 63468994..0f425b08 100644 --- a/.envrc.example +++ b/.envrc.example @@ -4,3 +4,7 @@ export ALLOYDB_PASS="postgres-password" export ALLOYDB_DB="postgres-db-name" export GOOGLE_APPLICATION_CREDENTIALS="path/to/credentials" + +# Requires the impersonating IAM principal to have +# roles/iam.serviceAccountTokenCreator +export IMPERSONATED_USER="some-user-with-db-access@example.com" diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 6d9e5847..7c9c1f41 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -90,6 +90,7 @@ jobs: secrets: |- ALLOYDB_CONN_NAME:${{ secrets.GOOGLE_CLOUD_PROJECT }}/ALLOYDB_CONN_NAME ALLOYDB_CLUSTER_PASS:${{ secrets.GOOGLE_CLOUD_PROJECT }}/ALLOYDB_CLUSTER_PASS + IMPERSONATED_USER:${{ secrets.GOOGLE_CLOUD_PROJECT }}/IMPERSONATED_USER - name: Run tests env: @@ -97,6 +98,7 @@ jobs: ALLOYDB_USER: 'postgres' ALLOYDB_PASS: '${{ steps.secrets.outputs.ALLOYDB_CLUSTER_PASS }}' ALLOYDB_CONNECTION_NAME: '${{ steps.secrets.outputs.ALLOYDB_CONN_NAME }}' + IMPERSONATED_USER: '${{ steps.secrets.outputs.IMPERSONATED_USER }}' # specifying bash shell ensures a failure in a piped process isn't lost by using `set -eo pipefail` shell: bash run: | diff --git a/cmd/root.go b/cmd/root.go index a22ea6cb..c466080a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -94,6 +94,13 @@ type Command struct { healthCheck bool httpAddress string httpPort string + + // impersonationChain is a comma separated list of one or more service + // accounts. The last entry in the chain is the impersonation target. Any + // additional service accounts before the target are delegates. The + // roles/iam.serviceAccountTokenCreator must be configured for each account + // that will be impersonated. + impersonationChain string } // Option is a function that configures a Command. @@ -183,6 +190,9 @@ the maximum time has passed. Defaults to 0s.`) cmd.PersistentFlags().StringVar(&c.conf.FUSETempDir, "fuse-tmp-dir", filepath.Join(os.TempDir(), "csql-tmp"), "Temp dir for Unix sockets created with FUSE") + cmd.PersistentFlags().StringVar(&c.impersonationChain, "impersonate-service-account", "", + `Comma separated list of service accounts to impersonate. Last value ++is the target account.`) cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "", "Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.") @@ -274,7 +284,10 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { if userHasSet("alloydbadmin-api-endpoint") { _, err := url.Parse(conf.APIEndpointURL) if err != nil { - return newBadCommandError(fmt.Sprintf("provided value for --alloydbadmin-api-endpoint is not a valid url, %v", conf.APIEndpointURL)) + return newBadCommandError(fmt.Sprintf( + "provided value for --alloydbadmin-api-endpoint is not a valid url, %v", + conf.APIEndpointURL, + )) } // Remove trailing '/' if included @@ -298,6 +311,19 @@ func parseConfig(cmd *Command, conf *proxy.Config, args []string) error { cmd.logger.Infof("Ignoring --disable-traces as --telemetry-project was not set") } + if cmd.impersonationChain != "" { + accts := strings.Split(cmd.impersonationChain, ",") + conf.ImpersonateTarget = accts[0] + // Assign delegates if the chain is more than one account. Delegation + // goes from last back towards target, e.g., With sa1,sa2,sa3, sa3 + // delegates to sa2, which impersonates the target sa1. + if l := len(accts); l > 1 { + for i := l - 1; i > 0; i-- { + conf.ImpersonateDelegates = append(conf.ImpersonateDelegates, accts[i]) + } + } + } + var ics []proxy.InstanceConnConfig for _, a := range args { // Assume no query params initially diff --git a/cmd/root_test.go b/cmd/root_test.go index 2c50dab4..865305d9 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -221,6 +221,19 @@ func TestNewCommandArguments(t *testing.T) { CredentialsJSON: `{"json":"goes-here"}`, }), }, + { + desc: "", + args: []string{"--impersonate-service-account", + "sv1@developer.gserviceaccount.com,sv2@developer.gserviceaccount.com,sv3@developer.gserviceaccount.com", + "projects/proj/locations/region/clusters/clust/instances/inst"}, + want: withDefaults(&proxy.Config{ + ImpersonateTarget: "sv1@developer.gserviceaccount.com", + ImpersonateDelegates: []string{ + "sv3@developer.gserviceaccount.com", + "sv2@developer.gserviceaccount.com", + }, + }), + }, } for _, tc := range tcs { diff --git a/go.mod b/go.mod index 5f15c3d1..dae59646 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( go.uber.org/zap v1.24.0 golang.org/x/oauth2 v0.2.0 golang.org/x/sys v0.3.0 + google.golang.org/api v0.103.0 ) require ( @@ -57,7 +58,6 @@ require ( golang.org/x/sync v0.1.0 // indirect golang.org/x/text v0.4.0 // indirect golang.org/x/time v0.2.0 // indirect - google.golang.org/api v0.103.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 // indirect google.golang.org/grpc v1.51.0 // indirect diff --git a/go.sum b/go.sum index 2f030dce..f349565c 100644 --- a/go.sum +++ b/go.sum @@ -1979,6 +1979,7 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 2b0bf6ef..95b7e97f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -31,6 +31,9 @@ import ( "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/gcloud" "golang.org/x/oauth2" + "google.golang.org/api/impersonate" + "google.golang.org/api/option" + "google.golang.org/api/sqladmin/v1" ) // InstanceConnConfig holds the configuration for an individual instance @@ -104,43 +107,102 @@ type Config struct { // regardless of any open connections. WaitOnClose time.Duration + // ImpersonateTarget is the service account to impersonate. The IAM + // principal doing the impersonation must have the + // roles/iam.serviceAccountTokenCreator role. + ImpersonateTarget string + // ImpersonateDelegates are the intermediate service accounts through which + // the impersonation is achieved. Each delegate must have the + // roles/iam.serviceAccountTokenCreator role. + ImpersonateDelegates []string + // StructuredLogs sets all output to use JSON in the LogEntry format. // See https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry StructuredLogs bool } -// DialerOptions builds appropriate list of options from the Config -// values for use by alloydbconn.NewClient() -func (c *Config) DialerOptions(l alloydb.Logger) ([]alloydbconn.Option, error) { - opts := []alloydbconn.Option{ - alloydbconn.WithUserAgent(c.UserAgent), +func (c *Config) credentialsOpt(l alloydb.Logger) (alloydbconn.Option, error) { + // If service account impersonation is configured, set up an impersonated + // credentials token source. + if c.ImpersonateTarget != "" { + var iopts []option.ClientOption + switch { + case c.Token != "": + l.Infof("Impersonating service account with OAuth2 token") + iopts = append(iopts, option.WithTokenSource( + oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), + )) + case c.CredentialsFile != "": + l.Infof("Impersonating service account with the credentials file at %q", c.CredentialsFile) + iopts = append(iopts, option.WithCredentialsFile(c.CredentialsFile)) + case c.CredentialsJSON != "": + l.Infof("Impersonating service account with JSON credentials environment variable") + iopts = append(iopts, option.WithCredentialsJSON([]byte(c.CredentialsJSON))) + case c.GcloudAuth: + l.Infof("Impersonating service account with gcloud user credentials") + ts, err := gcloud.TokenSource() + if err != nil { + return nil, err + } + iopts = append(iopts, option.WithTokenSource(ts)) + default: + l.Infof("Impersonating service account with Application Default Credentials") + } + ts, err := impersonate.CredentialsTokenSource( + context.Background(), + impersonate.CredentialsConfig{ + TargetPrincipal: c.ImpersonateTarget, + Delegates: c.ImpersonateDelegates, + Scopes: []string{sqladmin.SqlserviceAdminScope}, + }, + iopts..., + ) + if err != nil { + return nil, err + } + return alloydbconn.WithTokenSource(ts), nil } - opts = append(opts, alloydbconn.WithAdminAPIEndpoint(c.APIEndpointURL)) + // Otherwise, configure credentials as usual. switch { case c.Token != "": - l.Infof("Authorizing with the -token flag") - opts = append(opts, alloydbconn.WithTokenSource( + l.Infof("Authorizing with OAuth2 token") + return alloydbconn.WithTokenSource( oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), - )) + ), nil case c.CredentialsFile != "": l.Infof("Authorizing with the credentials file at %q", c.CredentialsFile) - opts = append(opts, alloydbconn.WithCredentialsFile( - c.CredentialsFile, - )) + return alloydbconn.WithCredentialsFile(c.CredentialsFile), nil + case c.CredentialsJSON != "": + l.Infof("Authorizing with JSON credentials environment variable") + return alloydbconn.WithCredentialsJSON([]byte(c.CredentialsJSON)), nil case c.GcloudAuth: l.Infof("Authorizing with gcloud user credentials") ts, err := gcloud.TokenSource() if err != nil { return nil, err } - opts = append(opts, alloydbconn.WithTokenSource(ts)) - case c.CredentialsJSON != "": - l.Infof("Authorizing with JSON credentials environment variable") - opts = append(opts, alloydbconn.WithCredentialsJSON( - []byte(c.CredentialsJSON), - )) + return alloydbconn.WithTokenSource(ts), nil default: l.Infof("Authorizing with Application Default Credentials") + // Return no-op options to avoid having to handle nil in caller code + return alloydbconn.WithOptions(), nil + } +} + +// DialerOptions builds appropriate list of options from the Config +// values for use by alloydbconn.NewClient() +func (c *Config) DialerOptions(l alloydb.Logger) ([]alloydbconn.Option, error) { + opts := []alloydbconn.Option{ + alloydbconn.WithUserAgent(c.UserAgent), + } + co, err := c.credentialsOpt(l) + if err != nil { + return nil, err + } + opts = append(opts, co) + + if c.APIEndpointURL != "" { + opts = append(opts, alloydbconn.WithAdminAPIEndpoint(c.APIEndpointURL)) } return opts, nil diff --git a/tests/alloydb_test.go b/tests/alloydb_test.go index d3aa94db..60e5c037 100644 --- a/tests/alloydb_test.go +++ b/tests/alloydb_test.go @@ -26,10 +26,15 @@ import ( ) var ( - alloydbConnName = flag.String("alloydb_conn_name", os.Getenv("ALLOYDB_CONNECTION_NAME"), "AlloyDB instance connection name, in the form of 'project:region:instance'.") - alloydbUser = flag.String("alloydb_user", os.Getenv("ALLOYDB_USER"), "Name of database user.") - alloydbPass = flag.String("alloydb_pass", os.Getenv("ALLOYDB_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") - alloydbDB = flag.String("alloydb_db", os.Getenv("ALLOYDB_DB"), "Name of the database to connect to.") + alloydbConnName = flag.String("alloydb_conn_name", os.Getenv("ALLOYDB_CONNECTION_NAME"), "AlloyDB instance connection name, in the form of 'project:region:instance'.") + alloydbUser = flag.String("alloydb_user", os.Getenv("ALLOYDB_USER"), "Name of database user.") + alloydbPass = flag.String("alloydb_pass", os.Getenv("ALLOYDB_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).") + alloydbDB = flag.String("alloydb_db", os.Getenv("ALLOYDB_DB"), "Name of the database to connect to.") + impersonatedUser = flag.String( + "impersonated_user", + os.Getenv("IMPERSONATED_USER"), + "Name of the service account that supports impersonation (impersonator must have roles/iam.serviceAccountTokenCreator)", + ) ) func requirePostgresVars(t *testing.T) { diff --git a/tests/connection_test.go b/tests/connection_test.go index 1ec27cff..b079639e 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -34,7 +34,9 @@ const connTestTimeout = time.Minute func removeAuthEnvVar(t *testing.T, wantToken bool) (*oauth2.Token, string, func()) { var tok *oauth2.Token if wantToken { - ts, err := google.DefaultTokenSource(context.Background()) + ts, err := google.DefaultTokenSource(context.Background(), + "https://www.googleapis.com/auth/cloud-platform", + ) if err != nil { t.Errorf("failed to resolve token source: %v", err) }