Skip to content

Commit

Permalink
fix(auth): get s2a logic up to date (#10093)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmenxk authored May 3, 2024
1 parent 2b576ab commit 4fe9ae4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 100 deletions.
2 changes: 2 additions & 0 deletions auth/grpctransport/grpctransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
if io := opts.InternalOptions; io != nil {
tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate
tOpts.DefaultMTLSEndpoint = io.DefaultMTLSEndpoint
tOpts.EnableDirectPath = io.EnableDirectPath
tOpts.EnableDirectPathXds = io.EnableDirectPathXds
}
transportCreds, endpoint, err := transport.GetGRPCTransportCredsAndEndpoint(tOpts)
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions auth/internal/transport/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ type Options struct {
ClientCertProvider cert.Provider
Client *http.Client
UniverseDomain string
EnableDirectPath bool
EnableDirectPathXds bool
}

// getUniverseDomain returns the default service domain for a given Cloud
Expand Down Expand Up @@ -195,10 +197,7 @@ func getTransportConfig(opts *Options) (*transportConfig, error) {
}

s2aMTLSEndpoint := opts.DefaultMTLSEndpoint
// If there is endpoint override, honor it.
if opts.Endpoint != "" {
s2aMTLSEndpoint = endpoint
}

s2aAddress := GetS2AAddress()
if s2aAddress == "" {
return &defaultTransportConfig, nil
Expand Down
149 changes: 66 additions & 83 deletions auth/internal/transport/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,70 +249,83 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {

func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
testCases := []struct {
name string
opts *Options
s2ARespFn func() (string, error)
mtlsEnabledFn func() bool
want string
name string
opts *Options
s2ARespFn func() (string, error)
want string
}{
{
name: "no client cert, endpoint is MTLS enabled, S2A address not empty",
name: "has client cert",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
ClientCertProvider: fakeClientCertSource,
},
s2ARespFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testMTLSEndpoint,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
},
{
name: "has client cert",
name: "no client cert, S2A address not empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
ClientCertProvider: fakeClientCertSource,
},
s2ARespFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testMTLSEndpoint,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
},
{
name: "no client cert, endpoint is not MTLS enabled",
name: "no client cert, S2A address not empty, EnableDirectPath == true",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
EnableDirectPath: true,
},
s2ARespFn: validConfigResp,
mtlsEnabledFn: func() bool { return false },
want: testRegularEndpoint,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
},
{
name: "no client cert, endpoint is MTLS enabled, S2A address empty",
name: "no client cert, S2A address not empty, EnableDirectPathXds == true",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
EnableDirectPathXds: true,
},
s2ARespFn: invalidConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testRegularEndpoint,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
},
{
name: "no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint",
name: "no client cert, S2A address empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2ARespFn: invalidConfigResp,
want: testRegularEndpoint,
},
{
name: "no client cert, S2A address not empty, override endpoint",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
Endpoint: testOverrideEndpoint,
},
s2ARespFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testOverrideEndpoint,
s2ARespFn: validConfigResp,
want: testOverrideEndpoint,
},
{
"no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
&Options{
DefaultMTLSEndpoint: "",
DefaultEndpointTemplate: testEndpointTemplate,
},
validConfigResp,
testRegularEndpoint,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2ARespFn
mtlsEndpointEnabledForS2A = tc.mtlsEnabledFn
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -330,107 +343,79 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {

func TestGetHTTPTransportConfig_S2a(t *testing.T) {
testCases := []struct {
name string
opts *Options
s2aFn func() (string, error)
mtlsEnabledFn func() bool
want string
isDialFnNil bool
name string
opts *Options
s2aFn func() (string, error)
want string
isDialFnNil bool
}{
{
name: "no client cert, endpoint is MTLS enabled, S2A address not empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testMTLSEndpoint,
},
{
name: "has client cert",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
ClientCertProvider: fakeClientCertSource,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testMTLSEndpoint,
isDialFnNil: true,
s2aFn: validConfigResp,
want: testMTLSEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, endpoint is not MTLS enabled",
name: "no client cert, S2A address not empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return false },
want: testRegularEndpoint,
isDialFnNil: true,
s2aFn: validConfigResp,
want: testMTLSEndpoint,
},
{
name: "no client cert, endpoint is MTLS enabled, S2A address empty",
name: "no client cert, S2A address empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: invalidConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testRegularEndpoint,
isDialFnNil: true,
s2aFn: invalidConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint",
name: "no client cert, S2A address not empty, override endpoint",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
Endpoint: testOverrideEndpoint,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testOverrideEndpoint,
s2aFn: validConfigResp,
want: testOverrideEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
opts: &Options{
DefaultMTLSEndpoint: "",
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testRegularEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, S2A address not empty, override endpoint is set",
opts: &Options{
DefaultMTLSEndpoint: "",
Endpoint: testOverrideEndpoint,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testOverrideEndpoint,
s2aFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, endpoint is MTLS enabled, S2A address not empty, custom HTTP client",
name: "no client cert, S2A address not empty, custom HTTP client",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
Client: http.DefaultClient,
},
s2aFn: validConfigResp,
mtlsEnabledFn: func() bool { return true },
want: testRegularEndpoint,
isDialFnNil: true,
s2aFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2aFn
mtlsEndpointEnabledForS2A = tc.mtlsEnabledFn
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -450,7 +435,6 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
}

func setupTest(t *testing.T) func() {
oldDefaultMTLSEnabled := mtlsEndpointEnabledForS2A
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry

Expand All @@ -459,7 +443,6 @@ func setupTest(t *testing.T) func() {

return func() {
httpGetMetadataMTLSConfig = oldHTTPGet
mtlsEndpointEnabledForS2A = oldDefaultMTLSEnabled
configExpiry = oldExpiry
}
}
Expand Down
17 changes: 4 additions & 13 deletions auth/internal/transport/s2a.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ var (
// The period an MTLS config can be reused before needing refresh.
configExpiry = time.Hour

// mtlsEndpointEnabledForS2A checks if the endpoint is indeed MTLS-enabled, so that we can use S2A for MTLS connection.
mtlsEndpointEnabledForS2A = func() bool {
// TODO(xmenxk): determine this via discovery config.
return true
}

// mdsMTLSAutoConfigSource is an instance of reuseMTLSConfigSource, with metadataMTLSAutoConfig as its config source.
mtlsOnce sync.Once
)
Expand Down Expand Up @@ -165,19 +159,16 @@ func shouldUseS2A(clientCertSource cert.Provider, opts *Options) bool {
if !isGoogleS2AEnabled() {
return false
}
// If DefaultMTLSEndpoint is not set and no endpoint override, skip S2A.
if opts.DefaultMTLSEndpoint == "" && opts.Endpoint == "" {
return false
}
// If MTLS is not enabled for this endpoint, skip S2A.
if !mtlsEndpointEnabledForS2A() {
// If DefaultMTLSEndpoint is not set or has endpoint override, skip S2A.
if opts.DefaultMTLSEndpoint == "" || opts.Endpoint != "" {
return false
}
// If custom HTTP client is provided, skip S2A.
if opts.Client != nil {
return false
}
return true
// If directPath is enabled, skip S2A.
return !opts.EnableDirectPath && !opts.EnableDirectPathXds
}

func isGoogleS2AEnabled() bool {
Expand Down

0 comments on commit 4fe9ae4

Please sign in to comment.