Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
BigTailWolf committed Sep 28, 2023
1 parent 399b52f commit c50beac
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 93 deletions.
31 changes: 4 additions & 27 deletions google/internal/externalaccount/basecredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
"golang.org/x/oauth2/google/internal/stsexchange"
)

// now aliases time.Now for testing
Expand Down Expand Up @@ -64,31 +62,10 @@ type Config struct {
WorkforcePoolUserProject string
}

// Each element consists of a list of patterns. validateURLs checks for matches
// that include all elements in a given list, in that order.

var (
validWorkforceAudiencePattern *regexp.Regexp = regexp.MustCompile(`//iam\.googleapis\.com/locations/[^/]+/workforcePools/`)
)

func validateURL(input string, patterns []*regexp.Regexp, scheme string) bool {
parsed, err := url.Parse(input)
if err != nil {
return false
}
if !strings.EqualFold(parsed.Scheme, scheme) {
return false
}
toTest := parsed.Host

for _, pattern := range patterns {
if pattern.MatchString(toTest) {
return true
}
}
return false
}

func validateWorkforceAudience(input string) bool {
return validWorkforceAudiencePattern.MatchString(input)
}
Expand Down Expand Up @@ -231,7 +208,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
stsRequest := sts_exchange.StsTokenExchangeRequest{
stsRequest := stsexchange.StsTokenExchangeRequest{
GrantType: "urn:ietf:params:oauth:grant-type:token-exchange",
Audience: conf.Audience,
Scope: conf.Scopes,
Expand All @@ -242,7 +219,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
header := make(http.Header)
header.Add("Content-Type", "application/x-www-form-urlencoded")
header.Add("x-goog-api-client", getMetricsHeaderValue(conf, credSource))
clientAuth := sts_exchange.ClientAuthentication{
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
Expand All @@ -255,7 +232,7 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
"userProject": conf.WorkforcePoolUserProject,
}
}
stsResp, err := sts_exchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
stsResp, err := stsexchange.ExchangeToken(ts.ctx, conf.TokenURL, &stsRequest, clientAuth, header, options)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
"golang.org/x/oauth2/google/internal/stsexchange"
)

// now aliases time.Now for testing.
Expand Down Expand Up @@ -87,13 +87,13 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
return nil, errors.New("oauth2/google: The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret.")
}

clientAuth := sts_exchange.ClientAuthentication{
clientAuth := stsexchange.ClientAuthentication{
AuthStyle: oauth2.AuthStyleInHeader,
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}

stsResponse, err := sts_exchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
stsResponse, err := stsexchange.RefreshAccessToken(ts.ctx, conf.TokenURL, conf.RefreshToken, clientAuth, nil)
if err != nil {
return nil, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google/internal/sts_exchange"
"golang.org/x/oauth2/google/internal/stsexchange"
)

const expiryDelta = 10 * time.Second
Expand All @@ -33,59 +33,11 @@ type testRefreshTokenServer struct {
Authorization string
ContentType string
Body string
ResponsePayload *sts_exchange.Response
ResponsePayload *stsexchange.Response
Response string
server *httptest.Server
}

func (trts *testRefreshTokenServer) Run(t *testing.T) (string, error) {
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}

func (trts *testRefreshTokenServer) Close() error {
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}

// Tests

func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expand All @@ -111,18 +63,18 @@ func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
RefreshToken: "CCCCCCC",
},
}

url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)

config := &Config{
RefreshToken: "BBBBBBBBB",
Expand Down Expand Up @@ -153,17 +105,17 @@ func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}

url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)

config := &Config{
RefreshToken: "BBBBBBBBB",
Expand Down Expand Up @@ -191,17 +143,17 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
ContentType: "application/x-www-form-urlencoded",
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
ResponsePayload: &sts_exchange.Response{
ResponsePayload: &stsexchange.Response{
ExpiresIn: 3600,
AccessToken: "AAAAAAA",
},
}

url, err := server.Run(t)
url, err := server.run(t)
if err != nil {
t.Fatalf("Error starting server")
}
defer server.Close()
defer server.close(t)
testCases := []struct {
name string
config Config
Expand Down Expand Up @@ -257,3 +209,51 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
})
}
}

func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
t.Helper()
if trts.server != nil {
return "", errors.New("Server is already running")
}
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.URL.String(), trts.URL; got != want {
t.Errorf("URL.String(): got %v but want %v", got, want)
}
headerAuth := r.Header.Get("Authorization")
if got, want := headerAuth, trts.Authorization; got != want {
t.Errorf("got %v but want %v", got, want)
}
headerContentType := r.Header.Get("Content-Type")
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
if got, want := string(body), trts.Body; got != want {
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
}
w.Header().Set("Content-Type", "application/json")
if trts.ResponsePayload != nil {
content, err := json.Marshal(trts.ResponsePayload)
if err != nil {
t.Fatalf("unable to marshall response JSON")
}
w.Write(content)
} else {
w.Write([]byte(trts.Response))
}
}))
return trts.server.URL, nil
}

func (trts *testRefreshTokenServer) close(t *testing.T) error {
t.Helper()
if trts.server == nil {
return errors.New("No server is running")
}
trts.server.Close()
trts.server = nil
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sts_exchange
package stsexchange

import (
"encoding/base64"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sts_exchange
package stsexchange

import (
"net/http"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sts_exchange
package stsexchange

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sts_exchange
package stsexchange

import (
"context"
Expand Down

0 comments on commit c50beac

Please sign in to comment.