Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: require login_token with token and enable_iam_login #1641

Merged
merged 2 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions cmd/cloud_sql_proxy/cloud_sql_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ dropped`,
)

// Settings for authentication.
token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.")
tokenFile = flag.String("credential_file", "",
token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.")
loginToken = flag.String("login_token", "", "Used in conjunction with --token and --enable_iam_login only")
tokenFile = flag.String("credential_file", "",
`If provided, this json file will be used to retrieve Service Account
credentials. You may set the GOOGLE_APPLICATION_CREDENTIALS environment
variable for the same effect.`,
Expand Down Expand Up @@ -381,13 +382,29 @@ func authenticatedClientFromPath(ctx context.Context, f string) (*http.Client, o
return oauth2.NewClient(ctx, cred.TokenSource), scoped.TokenSource, nil
}

var errLoginToken = errors.New("login_token must be used with token and enable_iam_login")

func authenticatedClient(ctx context.Context) (*http.Client, oauth2.TokenSource, error) {
if *tokenFile != "" {
return authenticatedClientFromPath(ctx, *tokenFile)
}
// If login token has been set, but there is no token or
// enable_iam_login has not been set, error.
if *loginToken != "" && (*token == "" || !(*enableIAMLogin)) {
return nil, nil, errLoginToken
}

if tok := *token; tok != "" {
src := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tok})
return oauth2.NewClient(ctx, src), src, nil
cl := oauth2.NewClient(ctx, src)
if *enableIAMLogin {
if *loginToken == "" {
return nil, nil, errLoginToken
}
lts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: *loginToken})
return cl, lts, nil
}
return cl, src, nil
}
if f := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"); f != "" {
return authenticatedClientFromPath(ctx, f)
Expand Down
69 changes: 69 additions & 0 deletions cmd/cloud_sql_proxy/cloud_sql_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"os"
"strings"
"testing"

"golang.org/x/net/context"
)

func TestVersionStripsNewline(t *testing.T) {
Expand All @@ -31,3 +33,70 @@ func TestVersionStripsNewline(t *testing.T) {
t.Fatalf("want = %q, got = %q", want, got)
}
}

func TestAuthenticatedClient(t *testing.T) {
tcs := []struct {
desc string
setup func() func()
wantErr bool
}{
{
desc: "when just token is set",
setup: func() func() {
*token = "MYTOKEN"
return func() {
*token = ""
}
},
wantErr: false,
},
{
desc: "when just login_token is set",
setup: func() func() {
*loginToken = "MYTOKEN"
return func() {
*loginToken = ""
}
},
wantErr: true,
},
{
desc: "when token and enable_iam_login are set",
setup: func() func() {
*token = "MYTOKEN"
*enableIAMLogin = true
return func() {
*token = ""
*enableIAMLogin = false
}
},
wantErr: true,
},
{
desc: "when token, login_token, and enable_iam_login are set",
setup: func() func() {
*token = "MYTOKEN"
*loginToken = "MYLOGINTOKEN"
*enableIAMLogin = true
return func() {
*token = ""
*loginToken = ""
*enableIAMLogin = false
}
},
wantErr: false,
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
cleanup := tc.setup()
defer cleanup()
_, _, err := authenticatedClient(context.Background())
gotErr := err != nil
if tc.wantErr != gotErr {
t.Fatalf("err: want = %v, got = %v", tc.wantErr, gotErr)
}
})
}
}