From 48a2fe5f2c0e5487d3d5ccfa29023ac7d4a77600 Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Wed, 8 Feb 2023 14:01:26 -0700 Subject: [PATCH] fix: require login_token with token and enable_iam_login (#1641) When token and enable_iam_login are provided, the caller must also provide a login token to ensure only a minimally scoped token in embedded into the ephemeral certificate. --- cmd/cloud_sql_proxy/cloud_sql_proxy.go | 23 ++++++- cmd/cloud_sql_proxy/cloud_sql_proxy_test.go | 69 +++++++++++++++++++++ 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/cmd/cloud_sql_proxy/cloud_sql_proxy.go b/cmd/cloud_sql_proxy/cloud_sql_proxy.go index f59c9d430..fb8a9c43e 100644 --- a/cmd/cloud_sql_proxy/cloud_sql_proxy.go +++ b/cmd/cloud_sql_proxy/cloud_sql_proxy.go @@ -101,8 +101,9 @@ dropped`, ) // Settings for authentication. - token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.") - tokenFile = flag.String("credential_file", "", + token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.") + loginToken = flag.String("login_token", "", "Used in conjunction with --token and --enable_iam_login only") + tokenFile = flag.String("credential_file", "", `If provided, this json file will be used to retrieve Service Account credentials. You may set the GOOGLE_APPLICATION_CREDENTIALS environment variable for the same effect.`, @@ -381,13 +382,29 @@ func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, o return oauth2.NewClient(ctx, cred.TokenSource), scoped.TokenSource, nil } +var errLoginToken = errors.New("login_token must be used with token and enable_iam_login") + func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) { if *tokenFile != "" { return authenticatedClientFromPath(ctx, *tokenFile) } + // If login token has been set, but there is no token or + // enable_iam_login has not been set, error. + if *loginToken != "" && (*token == "" || !(*enableIAMLogin)) { + return nil, nil, errLoginToken + } + if tok := *token; tok != "" { src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok}) - return oauth2.NewClient(ctx, src), src, nil + cl := oauth2.NewClient(ctx, src) + if *enableIAMLogin { + if *loginToken == "" { + return nil, nil, errLoginToken + } + lts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: *loginToken}) + return cl, lts, nil + } + return cl, src, nil } if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" { return authenticatedClientFromPath(ctx, f) diff --git a/cmd/cloud_sql_proxy/cloud_sql_proxy_test.go b/cmd/cloud_sql_proxy/cloud_sql_proxy_test.go index 4319379bc..f9cde6b76 100644 --- a/cmd/cloud_sql_proxy/cloud_sql_proxy_test.go +++ b/cmd/cloud_sql_proxy/cloud_sql_proxy_test.go @@ -18,6 +18,8 @@ import ( "os" "strings" "testing" + + "golang.org/x/net/context" ) func TestVersionStripsNewline(t *testing.T) { @@ -31,3 +33,70 @@ func TestVersionStripsNewline(t *testing.T) { t.Fatalf("want = %q, got = %q", want, got) } } + +func TestAuthenticatedClient(t *testing.T) { + tcs := []struct { + desc string + setup func() func() + wantErr bool + }{ + { + desc: "when just token is set", + setup: func() func() { + *token = "MYTOKEN" + return func() { + *token = "" + } + }, + wantErr: false, + }, + { + desc: "when just login_token is set", + setup: func() func() { + *loginToken = "MYTOKEN" + return func() { + *loginToken = "" + } + }, + wantErr: true, + }, + { + desc: "when token and enable_iam_login are set", + setup: func() func() { + *token = "MYTOKEN" + *enableIAMLogin = true + return func() { + *token = "" + *enableIAMLogin = false + } + }, + wantErr: true, + }, + { + desc: "when token, login_token, and enable_iam_login are set", + setup: func() func() { + *token = "MYTOKEN" + *loginToken = "MYLOGINTOKEN" + *enableIAMLogin = true + return func() { + *token = "" + *loginToken = "" + *enableIAMLogin = false + } + }, + wantErr: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + cleanup := tc.setup() + defer cleanup() + _, _, err := authenticatedClient(context.Background()) + gotErr := err != nil + if tc.wantErr != gotErr { + t.Fatalf("err: want = %v, got = %v", tc.wantErr, gotErr) + } + }) + } +}