Skip to content

Commit

Permalink
feat(auth): add UniverseDomain to DetectOptions (#9536)
Browse files Browse the repository at this point in the history
* Enable universe domain mismatch checks in transport packages
  • Loading branch information
quartzmo authored Mar 20, 2024
1 parent cc64719 commit 3618d3f
Show file tree
Hide file tree
Showing 19 changed files with 496 additions and 111 deletions.
16 changes: 8 additions & 8 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,18 @@ type Error struct {
uri string
}

func (r *Error) Error() string {
if r.code != "" {
s := fmt.Sprintf("auth: %q", r.code)
if r.description != "" {
s += fmt.Sprintf(" %q", r.description)
func (e *Error) Error() string {
if e.code != "" {
s := fmt.Sprintf("auth: %q", e.code)
if e.description != "" {
s += fmt.Sprintf(" %q", e.description)
}
if r.uri != "" {
s += fmt.Sprintf(" %q", r.uri)
if e.uri != "" {
s += fmt.Sprintf(" %q", e.uri)
}
return s
}
return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", r.Response.StatusCode, r.Body)
return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
}

// Temporary returns true if the error is considered temporary and may be able
Expand Down
4 changes: 4 additions & 0 deletions auth/credentials/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) {
ProjectIDProvider: auth.CredentialsPropertyFunc(func(context.Context) (string, error) {
return metadata.ProjectID()
}),
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{},
}), nil
}

Expand Down Expand Up @@ -140,6 +141,9 @@ type DetectOptions struct {
// Client configures the underlying client used to make network requests
// when fetching tokens. Optional.
Client *http.Client
// UniverseDomain is the default service domain for a given Cloud universe.
// The default value is "googleapis.com". Optional.
UniverseDomain string
}

func (o *DetectOptions) validate() error {
Expand Down
110 changes: 109 additions & 1 deletion auth/credentials/detect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ func TestDefaultCredentials_ExternalAccountKey(t *testing.T) {
if want := "googleapis.com"; got != want {
t.Fatalf("got %q, want %q", got, want)
}
tok, err := creds.Token(context.Background())
tok, err := creds.Token(ctx)
if err != nil {
t.Fatalf("creds.Token() = %v", err)
}
Expand Down Expand Up @@ -720,3 +720,111 @@ func TestDefaultCredentials_Validate(t *testing.T) {
})
}
}

func TestDefaultCredentials_UniverseDomain(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
opts *DetectOptions
want string
}{
{
name: "user json",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/user.json",
TokenURL: "example.com",
},
want: "googleapis.com",
},
{
name: "user json with file universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/user_universe_domain.json",
TokenURL: "example.com",
},
want: "googleapis.com",
},
{
name: "service account token URL json",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/sa.json",
},
want: "googleapis.com",
},
{
name: "external account json",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/exaccount_user.json",
UseSelfSignedJWT: true,
},
want: "googleapis.com",
},
{
name: "service account impersonation json",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/imp.json",
UseSelfSignedJWT: true,
},
want: "googleapis.com",
},
{
name: "service account json with file universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/sa_universe_domain.json",
UseSelfSignedJWT: true,
},
want: "example.com",
},
{
name: "service account json with options universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/sa.json",
UseSelfSignedJWT: true,
UniverseDomain: "foo.com",
},
want: "foo.com",
},
{
name: "service account json with file and options universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/sa_universe_domain.json",
UseSelfSignedJWT: true,
UniverseDomain: "bar.com",
},
want: "bar.com",
},
{
name: "external account json with options universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/exaccount_user.json",
UseSelfSignedJWT: true,
UniverseDomain: "foo.com",
},
want: "foo.com",
},
{
name: "impersonated service account json with options universe domain",
opts: &DetectOptions{
CredentialsFile: "../internal/testdata/imp.json",
UseSelfSignedJWT: true,
UniverseDomain: "foo.com",
},
want: "foo.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
creds, err := DetectDefault(tt.opts)
if err != nil {
t.Fatalf("%s: %v", tt.name, err)
}
ud, err := creds.UniverseDomain(ctx)
if err != nil {
t.Fatalf("%s: %v", tt.name, err)
}
if ud != tt.want {
t.Fatalf("%s: got %q, want %q", tt.name, ud, tt.want)
}
})
}
}
3 changes: 3 additions & 0 deletions auth/credentials/filetypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ func fileCredentials(b []byte, opts *DetectOptions) (*auth.Credentials, error) {
default:
return nil, fmt.Errorf("detect: unsupported filetype %q", fileType)
}
if opts.UniverseDomain != "" {
universeDomain = opts.UniverseDomain
}
return auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
ExpireEarly: opts.EarlyTokenRefresh,
Expand Down
36 changes: 18 additions & 18 deletions auth/credentials/internal/externalaccount/aws_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,17 @@ func (sp *awsSubjectProvider) providerType() string {
return awsProviderType
}

func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
if cs.IMDSv2SessionTokenURL == "" {
func (sp *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, error) {
if sp.IMDSv2SessionTokenURL == "" {
return "", nil
}
req, err := http.NewRequestWithContext(ctx, "PUT", cs.IMDSv2SessionTokenURL, nil)
req, err := http.NewRequestWithContext(ctx, "PUT", sp.IMDSv2SessionTokenURL, nil)
if err != nil {
return "", err
}
req.Header.Set(awsIMDSv2SessionTTLHeader, awsIMDSv2SessionTTL)

resp, err := cs.Client.Do(req)
resp, err := sp.Client.Do(req)
if err != nil {
return "", err
}
Expand All @@ -199,19 +199,19 @@ func (cs *awsSubjectProvider) getAWSSessionToken(ctx context.Context) (string, e
return string(respBody), nil
}

func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
func (sp *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]string) (string, error) {
if canRetrieveRegionFromEnvironment() {
if envAwsRegion := getenv(awsRegionEnvVar); envAwsRegion != "" {
return envAwsRegion, nil
}
return getenv(awsDefaultRegionEnvVar), nil
}

if cs.RegionURL == "" {
if sp.RegionURL == "" {
return "", errors.New("detect: unable to determine AWS region")
}

req, err := http.NewRequestWithContext(ctx, "GET", cs.RegionURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", sp.RegionURL, nil)
if err != nil {
return "", err
}
Expand All @@ -220,7 +220,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]
req.Header.Add(name, value)
}

resp, err := cs.Client.Do(req)
resp, err := sp.Client.Do(req)
if err != nil {
return "", err
}
Expand All @@ -244,7 +244,7 @@ func (cs *awsSubjectProvider) getRegion(ctx context.Context, headers map[string]
return string(respBody[:bodyLen-1]), nil
}

func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) {
func (sp *awsSubjectProvider) getSecurityCredentials(ctx context.Context, headers map[string]string) (result awsSecurityCredentials, err error) {
if canRetrieveSecurityCredentialFromEnvironment() {
return awsSecurityCredentials{
AccessKeyID: getenv(awsAccessKeyIDEnvVar),
Expand All @@ -253,11 +253,11 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header
}, nil
}

roleName, err := cs.getMetadataRoleName(ctx, headers)
roleName, err := sp.getMetadataRoleName(ctx, headers)
if err != nil {
return
}
credentials, err := cs.getMetadataSecurityCredentials(ctx, roleName, headers)
credentials, err := sp.getMetadataSecurityCredentials(ctx, roleName, headers)
if err != nil {
return
}
Expand All @@ -272,18 +272,18 @@ func (cs *awsSubjectProvider) getSecurityCredentials(ctx context.Context, header
return credentials, nil
}

func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) {
func (sp *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context, roleName string, headers map[string]string) (awsSecurityCredentials, error) {
var result awsSecurityCredentials

req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/%s", sp.CredVerificationURL, roleName), nil)
if err != nil {
return result, err
}
for name, value := range headers {
req.Header.Add(name, value)
}

resp, err := cs.Client.Do(req)
resp, err := sp.Client.Do(req)
if err != nil {
return result, err
}
Expand All @@ -300,19 +300,19 @@ func (cs *awsSubjectProvider) getMetadataSecurityCredentials(ctx context.Context
return result, err
}

func (cs *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
if cs.CredVerificationURL == "" {
func (sp *awsSubjectProvider) getMetadataRoleName(ctx context.Context, headers map[string]string) (string, error) {
if sp.CredVerificationURL == "" {
return "", errors.New("detect: unable to determine the AWS metadata server security credentials endpoint")
}
req, err := http.NewRequestWithContext(ctx, "GET", cs.CredVerificationURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", sp.CredVerificationURL, nil)
if err != nil {
return "", err
}
for name, value := range headers {
req.Header.Add(name, value)
}

resp, err := cs.Client.Do(req)
resp, err := sp.Client.Do(req)
if err != nil {
return "", err
}
Expand Down
30 changes: 15 additions & 15 deletions auth/credentials/internal/externalaccount/executable_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ type executableResponse struct {
Message string `json:"message,omitempty"`
}

func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) {
func (sp *executableSubjectProvider) parseSubjectTokenFromSource(response []byte, source string, now int64) (string, error) {
var result executableResponse
if err := json.Unmarshal(response, &result); err != nil {
return "", jsonParsingError(source, string(response))
Expand All @@ -143,7 +143,7 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte
if result.Version > executableSupportedMaxVersion || result.Version < 0 {
return "", unsupportedVersionError(source, result.Version)
}
if result.ExpirationTime == 0 && cs.OutputFile != "" {
if result.ExpirationTime == 0 && sp.OutputFile != "" {
return "", missingFieldError(source, "expiration_time")
}
if result.TokenType == "" {
Expand All @@ -169,24 +169,24 @@ func (cs *executableSubjectProvider) parseSubjectTokenFromSource(response []byte
}
}

func (cs *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) {
if token, err := cs.getTokenFromOutputFile(); token != "" || err != nil {
func (sp *executableSubjectProvider) subjectToken(ctx context.Context) (string, error) {
if token, err := sp.getTokenFromOutputFile(); token != "" || err != nil {
return token, err
}
return cs.getTokenFromExecutableCommand(ctx)
return sp.getTokenFromExecutableCommand(ctx)
}

func (cs *executableSubjectProvider) providerType() string {
func (sp *executableSubjectProvider) providerType() string {
return executableProviderType
}

func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) {
if cs.OutputFile == "" {
func (sp *executableSubjectProvider) getTokenFromOutputFile() (token string, err error) {
if sp.OutputFile == "" {
// This ExecutableCredentialSource doesn't use an OutputFile.
return "", nil
}

file, err := os.Open(cs.OutputFile)
file, err := os.Open(sp.OutputFile)
if err != nil {
// No OutputFile found. Hasn't been created yet, so skip it.
return "", nil
Expand All @@ -199,7 +199,7 @@ func (cs *executableSubjectProvider) getTokenFromOutputFile() (token string, err
return "", nil
}

token, err = cs.parseSubjectTokenFromSource(data, outputFileSource, cs.env.now().Unix())
token, err = sp.parseSubjectTokenFromSource(data, outputFileSource, sp.env.now().Unix())
if err != nil {
if _, ok := err.(nonCacheableError); ok {
// If the cached token is expired we need a new token,
Expand Down Expand Up @@ -231,20 +231,20 @@ func (sp *executableSubjectProvider) executableEnvironment() []string {
return result
}

func (cs *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) {
func (sp *executableSubjectProvider) getTokenFromExecutableCommand(ctx context.Context) (string, error) {
// For security reasons, we need our consumers to set this environment variable to allow executables to be run.
if cs.env.getenv(allowExecutablesEnvVar) != "1" {
if sp.env.getenv(allowExecutablesEnvVar) != "1" {
return "", errors.New("detect: executables need to be explicitly allowed (set GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES to '1') to run")
}

ctx, cancel := context.WithDeadline(ctx, cs.env.now().Add(cs.Timeout))
ctx, cancel := context.WithDeadline(ctx, sp.env.now().Add(sp.Timeout))
defer cancel()

output, err := cs.env.run(ctx, cs.Command, cs.executableEnvironment())
output, err := sp.env.run(ctx, sp.Command, sp.executableEnvironment())
if err != nil {
return "", err
}
return cs.parseSubjectTokenFromSource(output, executableSource, cs.env.now().Unix())
return sp.parseSubjectTokenFromSource(output, executableSource, sp.env.now().Unix())
}

func missingFieldError(source, field string) error {
Expand Down
Loading

0 comments on commit 3618d3f

Please sign in to comment.