Skip to content

Commit

Permalink
Merge branch 'main' into auth-slices
Browse files Browse the repository at this point in the history
  • Loading branch information
gcf-merge-on-green[bot] authored Jul 2, 2024
2 parents ba65ae2 + 58e3df4 commit 297bd79
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 132 deletions.
8 changes: 4 additions & 4 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,12 +538,12 @@ func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) {
v := url.Values{}
v.Set("grant_type", defaultGrantType)
v.Set("assertion", payload)
resp, err := tp.Client.PostForm(tp.opts.TokenURL, v)
req, err := http.NewRequestWithContext(ctx, "POST", tp.opts.TokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
return nil, err
}
defer resp.Body.Close()
body, err := internal.ReadAll(resp.Body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, body, err := internal.DoRequest(tp.Client, req)
if err != nil {
return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions auth/credentials/downscope/downscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,21 @@ func (dts *downscopedTokenProvider) Token(ctx context.Context) (*auth.Token, err
form.Add("subject_token", tok.Value)
form.Add("options", string(b))

resp, err := dts.Client.PostForm(dts.identityBindingEndpoint, form)
req, err := http.NewRequestWithContext(ctx, "POST", dts.identityBindingEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
defer resp.Body.Close()
respBody, err := internal.ReadAll(resp.Body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, body, err := internal.DoRequest(dts.Client, req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("downscope: unable to exchange token, %v: %s", resp.StatusCode, respBody)
return nil, fmt.Errorf("downscope: unable to exchange token, %v: %s", resp.StatusCode, body)
}

var tresp downscopedTokenResponse
err = json.Unmarshal(respBody, &tresp)
err = json.Unmarshal(body, &tresp)
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions auth/credentials/idtoken/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"strings"
"sync"
"time"

"cloud.google.com/go/auth/internal"
)

type cachingClient struct {
Expand Down Expand Up @@ -52,22 +54,20 @@ func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse,
if response, ok := c.get(url); ok {
return response, nil
}
req, err := http.NewRequest(http.MethodGet, url, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
resp, err := c.client.Do(req)
resp, body, err := internal.DoRequest(c.client, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
}

certResp := &certResponse{}
if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
if err := json.Unmarshal(body, &certResp); err != nil {
return nil, err

}
Expand Down
2 changes: 1 addition & 1 deletion auth/credentials/idtoken/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error)
v.Set("licenses", "TRUE")
}
urlSuffix := identitySuffix + "?" + v.Encode()
res, err := c.client.Get(urlSuffix)
res, err := c.client.GetWithContext(ctx, urlSuffix)
if err != nil {
return nil, err
}
Expand Down
9 changes: 2 additions & 7 deletions auth/credentials/impersonate/idtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,15 @@ func (i impersonatedIDTokenProvider) Token(ctx context.Context) (*auth.Token, er
}

url := fmt.Sprintf("%s/v1/%s:generateIdToken", iamCredentialsEndpoint, formatIAMServiceAccountName(i.targetPrincipal))
req, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := i.client.Do(req)
resp, body, err := internal.DoRequest(i.client, req)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to generate ID token: %w", err)
}
defer resp.Body.Close()
body, err := internal.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %w", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}
Expand Down
10 changes: 2 additions & 8 deletions auth/credentials/impersonate/impersonate.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,21 +238,15 @@ func (i impersonatedTokenProvider) Token(ctx context.Context) (*auth.Token, erro
return nil, fmt.Errorf("impersonate: unable to marshal request: %w", err)
}
url := fmt.Sprintf("%s/v1/%s:generateAccessToken", iamCredentialsEndpoint, formatIAMServiceAccountName(i.targetPrincipal))
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := i.client.Do(req)
resp, body, err := internal.DoRequest(i.client, req)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to generate access token: %w", err)
}
defer resp.Body.Close()
body, err := internal.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %w", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}
Expand Down
22 changes: 7 additions & 15 deletions auth/credentials/impersonate/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ type userTokenProvider struct {
}

func (u userTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
signedJWT, err := u.signJWT()
signedJWT, err := u.signJWT(ctx)
if err != nil {
return nil, err
}
return u.exchangeToken(ctx, signedJWT)
}

func (u userTokenProvider) signJWT() (string, error) {
func (u userTokenProvider) signJWT(ctx context.Context) (string, error) {
now := time.Now()
exp := now.Add(u.lifetime)
claims := claimSet{
Expand All @@ -124,20 +124,16 @@ func (u userTokenProvider) signJWT() (string, error) {
return "", fmt.Errorf("impersonate: unable to marshal request: %w", err)
}
reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentialsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return "", fmt.Errorf("impersonate: unable to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
rawResp, err := u.client.Do(req)
resp, body, err := internal.DoRequest(u.client, req)
if err != nil {
return "", fmt.Errorf("impersonate: unable to sign JWT: %w", err)
}
body, err := internal.ReadAll(rawResp.Body)
if err != nil {
return "", fmt.Errorf("impersonate: unable to read body: %w", err)
}
if c := rawResp.StatusCode; c < 200 || c > 299 {
if c := resp.StatusCode; c < 200 || c > 299 {
return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
}

Expand All @@ -157,15 +153,11 @@ func (u userTokenProvider) exchangeToken(ctx context.Context, signedJWT string)
if err != nil {
return nil, err
}
rawResp, err := u.client.Do(req)
resp, body, err := internal.DoRequest(u.client, req)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to exchange token: %w", err)
}
body, err := internal.ReadAll(rawResp.Body)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %w", err)
}
if c := rawResp.StatusCode; c < 200 || c > 299 {
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}

Expand Down
59 changes: 17 additions & 42 deletions auth/credentials/internal/externalaccount/aws_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (sp *awsSubjectProvider) subjectToken(ctx context.Context) (string, error)

// Generate the signed request to AWS STS GetCallerIdentity API.
// Use the required regional endpoint. Otherwise, the request will fail.
req, err := http.NewRequest("POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
req, err := http.NewRequestWithContext(ctx, "POST", strings.Replace(sp.RegionalCredVerificationURL, "{region}", sp.region, 1), nil)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -194,20 +194,14 @@ func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, e
}
req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)

resp, err := sp.Client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

respBody, err := internal.ReadAll(resp.Body)
resp, body, err := internal.DoRequest(sp.Client, req)
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", respBody)
return "", fmt.Errorf("credentials: unable to retrieve AWS session token: %s", body)
}
return string(respBody), nil
return string(body), nil
}

func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
Expand All @@ -233,29 +227,21 @@ func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]
for name, value := range headers {
req.Header.Add(name, value)
}

resp, err := sp.Client.Do(req)
resp, body, err := internal.DoRequest(sp.Client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()

respBody, err := internal.ReadAll(resp.Body)
if err != nil {
return "", err
}

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", respBody)
return "", fmt.Errorf("credentials: unable to retrieve AWS region - %s", body)
}

// This endpoint will return the region in format: us-east-2b.
// Only the us-east-2 part should be used.
bodyLen := len(respBody)
bodyLen := len(body)
if bodyLen == 0 {
return "", nil
}
return string(respBody[:bodyLen-1]), nil
return string(body[:bodyLen-1]), nil
}

func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result *AwsSecurityCredentials, err error) {
Expand Down Expand Up @@ -299,22 +285,17 @@ func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context
for name, value := range headers {
req.Header.Add(name, value)
}

resp, err := sp.Client.Do(req)
if err != nil {
return result, err
}
defer resp.Body.Close()

respBody, err := internal.ReadAll(resp.Body)
resp, body, err := internal.DoRequest(sp.Client, req)
if err != nil {
return result, err
}
if resp.StatusCode != http.StatusOK {
return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", respBody)
return result, fmt.Errorf("credentials: unable to retrieve AWS security credentials - %s", body)
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
err = json.Unmarshal(respBody, &result)
return result, err
return result, nil
}

func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
Expand All @@ -329,20 +310,14 @@ func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers m
req.Header.Add(name, value)
}

resp, err := sp.Client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

respBody, err := internal.ReadAll(resp.Body)
resp, body, err := internal.DoRequest(sp.Client, req)
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", respBody)
return "", fmt.Errorf("credentials: unable to retrieve AWS role name - %s", body)
}
return string(respBody), nil
return string(body), nil
}

// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
Expand Down
16 changes: 5 additions & 11 deletions auth/credentials/internal/externalaccount/url_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,21 @@ func (sp *urlSubjectProvider) subjectToken(ctx context.Context) (string, error)
for key, val := range sp.Headers {
req.Header.Add(key, val)
}
resp, err := sp.Client.Do(req)
resp, body, err := internal.DoRequest(sp.Client, req)
if err != nil {
return "", fmt.Errorf("credentials: invalid response when retrieving subject token: %w", err)
}
defer resp.Body.Close()

respBody, err := internal.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("credentials: invalid body in subject token URL query: %w", err)
}
if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
return "", fmt.Errorf("credentials: status code %d: %s", c, respBody)
return "", fmt.Errorf("credentials: status code %d: %s", c, body)
}

if sp.Format == nil {
return string(respBody), nil
return string(body), nil
}
switch sp.Format.Type {
case "json":
jsonData := make(map[string]interface{})
err = json.Unmarshal(respBody, &jsonData)
err = json.Unmarshal(body, &jsonData)
if err != nil {
return "", fmt.Errorf("credentials: failed to unmarshal subject token file: %w", err)
}
Expand All @@ -82,7 +76,7 @@ func (sp *urlSubjectProvider) subjectToken(ctx context.Context) (string, error)
}
return token, nil
case fileTypeText:
return string(respBody), nil
return string(body), nil
default:
return "", errors.New("credentials: invalid credential_source file format type: " + sp.Format.Type)
}
Expand Down
10 changes: 6 additions & 4 deletions auth/credentials/internal/gdch/gdch.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net/http"
"net/url"
"os"
"strings"
"time"

"cloud.google.com/go/auth"
Expand Down Expand Up @@ -129,12 +130,13 @@ func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
v.Set("requested_token_type", requestTokenType)
v.Set("subject_token", payload)
v.Set("subject_token_type", subjectTokenType)
resp, err := g.client.PostForm(g.tokenURL, v)

req, err := http.NewRequestWithContext(ctx, "POST", g.tokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
return nil, err
}
defer resp.Body.Close()
body, err := internal.ReadAll(resp.Body)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, body, err := internal.DoRequest(g.client, req)
if err != nil {
return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
}
Expand Down
7 changes: 1 addition & 6 deletions auth/credentials/internal/impersonate/impersonate.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,10 @@ func (o *Options) Token(ctx context.Context) (*auth.Token, error) {
if err := setAuthHeader(ctx, o.Tp, req); err != nil {
return nil, err
}
resp, err := o.Client.Do(req)
resp, body, err := internal.DoRequest(o.Client, req)
if err != nil {
return nil, fmt.Errorf("credentials: unable to generate access token: %w", err)
}
defer resp.Body.Close()
body, err := internal.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("credentials: unable to read body: %w", err)
}
if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
return nil, fmt.Errorf("credentials: status code %d: %s", c, body)
}
Expand Down
Loading

0 comments on commit 297bd79

Please sign in to comment.