Skip to content

Commit

Permalink
chore(auth): add a bunch of missing validation logic (#8718)
Browse files Browse the repository at this point in the history
- made the look of validation consistent across packages
- moved refresh token into the options struct for 3lo (this is the flagged breaking change.)
  • Loading branch information
codyoss authored Oct 16, 2023
1 parent 3a4ec65 commit a8c842f
Show file tree
Hide file tree
Showing 15 changed files with 371 additions and 82 deletions.
35 changes: 27 additions & 8 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -224,17 +225,17 @@ type Options2LO struct {
// contents of a PEM file that contains a private key. It is used to sign
// the JWT created.
PrivateKey []byte
// TokenURL is th URL the JWT is sent to. Required.
TokenURL string
// PrivateKeyID is the ID of the key used to sign the JWT. It is used as the
// "kid" in the JWT header.
// "kid" in the JWT header. Optional.
PrivateKeyID string
// Subject is the used for to impersonate a user. It is used as the "sub" in
// the JWT.m Optional.
Subject string
// Scopes specifies requested permissions for the token. Optional.
Scopes []string
// TokenURL is th URL the JWT is sent to.
TokenURL string
// Expires specifies the lifetime of the token.
// Expires specifies the lifetime of the token. Optional.
Expires time.Duration
// Audience specifies the "aud" in the JWT. Optional.
Audience string
Expand All @@ -249,16 +250,34 @@ type Options2LO struct {
UseIDToken bool
}

func (c *Options2LO) client() *http.Client {
if c.Client != nil {
return c.Client
func (o *Options2LO) client() *http.Client {
if o.Client != nil {
return o.Client
}
return internal.CloneDefaultClient()
}

func (o *Options2LO) validate() error {
if o == nil {
return errors.New("auth: options must be provided")
}
if o.Email == "" {
return errors.New("auth: email must be provided")
}
if len(o.PrivateKey) == 0 {
return errors.New("auth: private key must be provided")
}
if o.TokenURL == "" {
return errors.New("auth: token URL must be provided")
}
return nil
}

// New2LOTokenProvider returns a [TokenProvider] from the provided options.
func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
// TODO(codyoss): add validation
if err := opts.validate(); err != nil {
return nil, err
}
return tokenProvider2LO{opts: opts, Client: opts.client()}, nil
}

Expand Down
51 changes: 45 additions & 6 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestError_Error(t *testing.T) {
}
}

func TestConfigJWT2LO_JSONResponse(t *testing.T) {
func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
Expand Down Expand Up @@ -221,7 +221,7 @@ func TestConfigJWT2LO_JSONResponse(t *testing.T) {
}
}

func TestConfigJWT2LO_BadResponse(t *testing.T) {
func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
Expand Down Expand Up @@ -259,7 +259,7 @@ func TestConfigJWT2LO_BadResponse(t *testing.T) {
}
}

func TestConfigJWT2LO_BadResponseType(t *testing.T) {
func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
Expand All @@ -283,7 +283,7 @@ func TestConfigJWT2LO_BadResponseType(t *testing.T) {
}
}

func TestConfigJWT2LO_Assertion(t *testing.T) {
func TestNew2LOTokenProvider_Assertion(t *testing.T) {
var assertion string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestConfigJWT2LO_Assertion(t *testing.T) {
}
}

func TestConfigJWT2LO_AssertionPayload(t *testing.T) {
func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
var assertion string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
Expand Down Expand Up @@ -436,7 +436,7 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) {
}
}

func TestConfigJWT2LO_TokenError(t *testing.T) {
func TestNew2LOTokenProvider_TokenError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-type", "application/json")
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -467,3 +467,42 @@ func TestConfigJWT2LO_TokenError(t *testing.T) {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
}

func TestNew2LOTokenProvider_Validate(t *testing.T) {
tests := []struct {
name string
opts *Options2LO
}{
{
name: "missing options",
},
{
name: "missing email",
opts: &Options2LO{
PrivateKey: []byte("key"),
TokenURL: "url",
},
},
{
name: "missing key",
opts: &Options2LO{
Email: "email",
TokenURL: "url",
},
},
{
name: "missing URL",
opts: &Options2LO{
Email: "email",
PrivateKey: []byte("key"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := New2LOTokenProvider(tt.opts); err == nil {
t.Error("got nil, want an error")
}
})
}
}
33 changes: 27 additions & 6 deletions auth/detect/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package detect

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -119,7 +120,9 @@ func OnGCE() bool {
// runtimes, and Google App Engine flexible environment, it fetches
// credentials from the metadata server.
func DefaultCredentials(opts *Options) (*Credentials, error) {
// TODO(codyoss): add some validation logic here.
if err := opts.validate(); err != nil {
return nil, err
}
if opts.CredentialsJSON != nil {
return readCredentialsFileJSON(opts.CredentialsJSON, opts)
}
Expand All @@ -145,7 +148,8 @@ func DefaultCredentials(opts *Options) (*Credentials, error) {
// Options provides configuration for [DefaultCredentials].
type Options struct {
// Scopes that credentials tokens should have. Example:
// https://www.googleapis.com/auth/cloud-platform
// https://www.googleapis.com/auth/cloud-platform. Required if Audience is
// not provided.
Scopes []string
// Audience that credentials tokens should have. Only applicable for 2LO
// flows with service accounts. If specified, scopes should not be provided.
Expand All @@ -168,10 +172,12 @@ type Options struct {
// Currently this only used for GDCH auth flow, for which it is required.
STSAudience string
// CredentialsFile overrides detection logic and sources a credential file
// from the provided filepath. Optional.
// from the provided filepath. If provided, CredentialsJSON must not be.
// Optional.
CredentialsFile string
// CredentialsJSON overrides detection logic and uses the JSON bytes as the
// source for the credential. Optional.
// source for the credential. If provided, CredentialsFile must not be.
// Optional.
CredentialsJSON []byte
// UseSelfSignedJWT directs service account based credentials to create a
// self-signed JWT with the private key found in the file, skipping any
Expand All @@ -182,6 +188,19 @@ type Options struct {
Client *http.Client
}

func (o *Options) validate() error {
if o == nil {
return errors.New("detect: options must be provided")
}
if len(o.Scopes) > 0 && o.Audience != "" {
return errors.New("detect: both scopes and audience were provided")
}
if len(o.CredentialsJSON) > 0 && o.CredentialsFile != "" {
return errors.New("detect: both credentials file and JSON were provided")
}
return nil
}

func (o *Options) tokenURL() string {
if o.TokenURL != "" {
return o.TokenURL
Expand Down Expand Up @@ -214,7 +233,10 @@ func readCredentialsFileJSON(b []byte, opts *Options) (*Credentials, error) {
// attempt to parse jsonData as a Google Developers Console client_credentials.json.
config := clientCredConfigFromJSON(b, opts)
if config != nil {
tp, err := auth.New3LOTokenProvider("", config)
if config.AuthHandlerOpts == nil {
return nil, errors.New("detect: auth handler must be specified for this credential filetype")
}
tp, err := auth.New3LOTokenProvider(config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -242,7 +264,6 @@ func clientCredConfigFromJSON(b []byte, opts *Options) *auth.Options3LO {
}
var handleOpts *auth.AuthorizationHandlerOptions
if opts.AuthHandlerOptions != nil {
// TODO(codyoss): these have to be here for this flow, validate that
handleOpts = &auth.AuthorizationHandlerOptions{
Handler: opts.AuthHandlerOptions.Handler,
State: opts.AuthHandlerOptions.State,
Expand Down
33 changes: 33 additions & 0 deletions auth/detect/detect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,36 @@ func TestDefaultCredentials_BadFiletype(t *testing.T) {
t.Fatal("got nil, want non-nil err")
}
}

func TestDefaultCredentials_Validate(t *testing.T) {
tests := []struct {
name string
opts *Options
}{
{
name: "missing options",
},
{
name: "scope and audience provided",
opts: &Options{
Scopes: []string{"scope"},
Audience: "aud",
},
},
{
name: "file and json provided",
opts: &Options{
Scopes: []string{"scope"},
CredentialsFile: "path",
CredentialsJSON: []byte(`{"some":"json"}`),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := DefaultCredentials(tt.opts); err == nil {
t.Error("got nil, want an error")
}
})
}
}
3 changes: 2 additions & 1 deletion auth/detect/filetypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ func handleUserCredential(f *internaldetect.UserCredentialsFile, opts *Options)
TokenURL: opts.tokenURL(),
AuthStyle: auth.StyleInParams,
EarlyTokenExpiry: opts.EarlyTokenRefresh,
RefreshToken: f.RefreshToken,
}
return auth.New3LOTokenProvider(f.RefreshToken, opts3LO)
return auth.New3LOTokenProvider(opts3LO)
}

func handleExternalAccount(f *internaldetect.ExternalAccountFile, opts *Options) (auth.TokenProvider, error) {
Expand Down
1 change: 0 additions & 1 deletion auth/detect/internal/externalaccount/sts_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ type stsTokenExchangeResponse struct {
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
// TODO(codyoss): original impl parsed but did not use a refresh token here, do we need it?
}

// clientAuthentication represents an OAuth client ID and secret and the
Expand Down
15 changes: 14 additions & 1 deletion auth/detect/internal/impersonate/impersonate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -46,7 +47,9 @@ type impersonateTokenResponse struct {
// NewTokenProvider uses a source credential, stored in Ts, to request an access token to the provided URL.
// Scopes can be defined when the access token is requested.
func NewTokenProvider(opts *Options) (auth.TokenProvider, error) {
// TODO(codyoss): add validation
if err := opts.validate(); err != nil {
return nil, err
}
return opts, nil
}

Expand All @@ -73,6 +76,16 @@ type Options struct {
Client *http.Client
}

func (o *Options) validate() error {
if o.Tp == nil {
return errors.New("detect: missing required 'source_credentials' field in impersonated credentials")
}
if o.URL == "" {
return errors.New("detect: missing required 'service_account_impersonation_url' field in impersonated credentials")
}
return nil
}

// Token performs the exchange to get a temporary service account token to allow access to GCP.
func (tp *Options) Token(ctx context.Context) (*auth.Token, error) {
lifetime := defaultTokenLifetime
Expand Down
28 changes: 28 additions & 0 deletions auth/detect/internal/impersonate/impersonate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,34 @@ func (tp mockProvider) Token(context.Context) (*auth.Token, error) {
}, nil
}

func TestNewImpersonatedTokenProvider_Validation(t *testing.T) {
tests := []struct {
name string
opt *Options
}{
{
name: "missing source creds",
opt: &Options{
URL: "some-url",
},
},
{
name: "missing url",
opt: &Options{
Tp: &Options{},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewTokenProvider(tt.opt)
if err == nil {
t.Errorf("got nil, want an error")
}
})
}
}

func TestNewImpersonatedTokenProvider(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), "Bearer fake_token_base"; got != want {
Expand Down
Loading

0 comments on commit a8c842f

Please sign in to comment.