From 3a1cc43a9874d4d8fa204e90f821db72b0b42789 Mon Sep 17 00:00:00 2001 From: Adriana Date: Tue, 3 Jan 2023 11:06:04 -0500 Subject: [PATCH] Address requested changes --- idtoken/idtoken.go | 42 +++++++++++++------ idtoken/idtoken_test.go | 80 ------------------------------------- idtoken/integration_test.go | 55 ++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 93 deletions(-) delete mode 100644 idtoken/idtoken_test.go diff --git a/idtoken/idtoken.go b/idtoken/idtoken.go index 152ac9c2007..6ed936e7fcb 100644 --- a/idtoken/idtoken.go +++ b/idtoken/idtoken.go @@ -29,6 +29,14 @@ import ( // ClientOption is for configuring a Google API client or transport. type ClientOption = option.ClientOption +type credentialsType int + +const ( + unknownCredType credentialsType = iota + serviceAccount + impersonatedServiceAccount +) + // NewClient creates a HTTP Client that automatically adds an ID token to each // request via an Authorization header. The token will have the audience // provided and be configured with the supplied options. The parameter audience @@ -71,7 +79,6 @@ func NewClient(ctx context.Context, audience string, opts ...ClientOption) (*htt // provided and configured with the supplied options. The parameter audience may // not be empty. func NewTokenSource(ctx context.Context, audience string, opts ...ClientOption) (oauth2.TokenSource, error) { - option.WithScopes() if audience == "" { return nil, fmt.Errorf("idtoken: must supply a non-empty audience") } @@ -112,7 +119,8 @@ func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds if err != nil { return nil, err } - if allowedType == "service_account" { + switch allowedType { + case serviceAccount: cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...) if err != nil { return nil, err @@ -132,8 +140,7 @@ func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds return nil, err } return oauth2.ReuseTokenSource(tok, ts), nil - } else { - // if allowedType is "impersonated_service_account": + case impersonatedServiceAccount: type url struct { ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` } @@ -154,26 +161,37 @@ func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds log.Println(err) } return ts, nil + default: + return nil, fmt.Errorf("idtoken: unsupported credentials type") } } -// isOfAllowedType returns the credentials type as a string, and an error. +// getAllowedType returns the credentials type of type credentialsType, and an error. // allowed types are "service_account" and "impersonated_service_account" -func getAllowedType(data []byte) (string, error) { +func getAllowedType(data []byte) (credentialsType, error) { + var t credentialsType if len(data) == 0 { - return "", fmt.Errorf("idtoken: credential provided is 0 bytes") + return t, fmt.Errorf("idtoken: credential provided is 0 bytes") } var f struct { Type string `json:"type"` } - // if not service account return an error if err := json.Unmarshal(data, &f); err != nil { - return "", err + return t, err } - if f.Type != "service_account" && f.Type != "impersonated_service_account" { - return "", fmt.Errorf("idtoken: credential must be service_account or impersonated_service_account, found %q", f.Type) + t = parseCredType(f.Type) + return t, nil +} + +func parseCredType(typeString string) credentialsType { + switch typeString { + case "service_account": + return serviceAccount + case "impersonated_service_account": + return impersonatedServiceAccount + default: + return unknownCredType } - return f.Type, nil } // WithCustomClaims optionally specifies custom private claims for an ID token. diff --git a/idtoken/idtoken_test.go b/idtoken/idtoken_test.go deleted file mode 100644 index 367f12d8710..00000000000 --- a/idtoken/idtoken_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2020 Google LLC. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package idtoken - -import ( - "context" - "reflect" - "testing" - - "golang.org/x/oauth2" - "google.golang.org/api/internal" -) - -var TokenSource oauth2.TokenSource - -func TestNewTokenSource(t *testing.T) { - tests := []struct { - name string - ctx context.Context - audience string - want oauth2.TokenSource - wantErr bool - }{ - { - name: "works", - ctx: context.Background(), - audience: "https://apikeys.googleapis.com", - want: TokenSource, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NewTokenSource(tt.ctx, tt.audience) - if (err != nil) != tt.wantErr { - t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr) - return - } - tok, err := got.Token() - if (err != nil) != tt.wantErr { - t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr) - return - } - _, err = Validate(tt.ctx, tok.AccessToken, tt.audience) - if err != nil { - t.Errorf("NewTokenSource() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_newTokenSource(t *testing.T) { - type args struct { - ctx context.Context - audience string - ds *internal.DialSettings - } - tests := []struct { - name string - args args - want oauth2.TokenSource - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := newTokenSource(tt.args.ctx, tt.args.audience, tt.args.ds) - if (err != nil) != tt.wantErr { - t.Errorf("newTokenSource() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("newTokenSource() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/idtoken/integration_test.go b/idtoken/integration_test.go index 32281fa8507..34f9a43c131 100644 --- a/idtoken/integration_test.go +++ b/idtoken/integration_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 Google LLC. +// Copyright 2023 Google LLC. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -6,6 +6,7 @@ package idtoken_test import ( "context" + "encoding/json" "net/http" "os" "strings" @@ -49,6 +50,58 @@ func TestNewTokenSource(t *testing.T) { } } +func TestNewTokenSource_WithImpersonatedServiceAccountCreds(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + type SourceCreds struct { + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RefreshToken string `json:"refresh_token"` + Type string `json:"type"` + } + type Creds struct { + Type string `json:"type"` + ImpersonationUrl string `json:"service_account_impersonation_url"` + SourceCredentials SourceCreds `json:"source_credentials"` + } + sc := SourceCreds{ + "", + "", + "", + "", + } + c := Creds{ + "", + "", + sc, + } + creds, err := json.Marshal(c) + if err != nil { + t.Fatalf("unable to marshal credentials: %v", err) + } + ts, err := idtoken.NewTokenSource(context.Background(), "http://example.com", option.WithCredentialsJSON(creds)) + if err != nil { + t.Fatalf("unable to create TokenSource: %v", err) + } + tok, err := ts.Token() + if err != nil { + t.Fatalf("unable to retrieve Token: %v", err) + } + req := &http.Request{Header: make(http.Header)} + tok.SetAuthHeader(req) + if !strings.HasPrefix(req.Header.Get("Authorization"), "Bearer ") { + t.Fatalf("token should sign requests with Bearer Authorization header") + } + validTok, err := idtoken.Validate(context.Background(), tok.AccessToken, aud) + if err != nil { + t.Fatalf("token validation failed: %v", err) + } + if validTok.Audience != aud { + t.Fatalf("got %q, want %q", validTok.Audience, aud) + } +} + func TestNewClient_WithCredentialFile(t *testing.T) { if testing.Short() { t.Skip("skipping integration test")