Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add universe resolution logic #2284

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions internal/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ package internal
import (
"context"
"crypto/tls"
"errors"
"net"
"net/url"
"os"
Expand All @@ -56,29 +57,30 @@ const (
)

// getClientCertificateSourceAndEndpoint is a convenience function that invokes
// getClientCertificateSource and getEndpoint sequentially and returns the client
// cert source and endpoint as a tuple.
func getClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, error) {
// getClientCertificateSource and getEndpointAndUniverse sequentially and returns the client
// cert source, endpoint, and universe as a tuple.
func getClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, string, error) {
clientCertSource, err := getClientCertificateSource(settings)
if err != nil {
return nil, "", err
return nil, "", "", err
}
endpoint, err := getEndpoint(settings, clientCertSource)
endpoint, universe, err := getEndpointAndUniverse(settings, clientCertSource, getMTLSMode())
if err != nil {
return nil, "", err
return nil, "", "", err
}
return clientCertSource, endpoint, nil
return clientCertSource, endpoint, universe, nil
}

type transportConfig struct {
clientCertSource cert.Source // The client certificate source.
endpoint string // The corresponding endpoint to use based on client certificate source.
universe string // The corresponding universe (suffix domain).
s2aAddress string // The S2A address if it can be used, otherwise an empty string.
s2aMTLSEndpoint string // The MTLS endpoint to use with S2A.
}

func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
clientCertSource, endpoint, err := getClientCertificateSourceAndEndpoint(settings)
clientCertSource, endpoint, universe, err := getClientCertificateSourceAndEndpoint(settings)
if err != nil {
return &transportConfig{
clientCertSource: nil, endpoint: "", s2aAddress: "", s2aMTLSEndpoint: "",
Expand All @@ -87,6 +89,7 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
defaultTransportConfig := transportConfig{
clientCertSource: clientCertSource,
endpoint: endpoint,
universe: universe,
s2aAddress: "",
s2aMTLSEndpoint: "",
}
Expand Down Expand Up @@ -138,39 +141,78 @@ func isClientCertificateEnabled() bool {
return strings.ToLower(useClientCert) == "true"
}

// getEndpoint returns the endpoint for the service, taking into account the
// user-provided endpoint override "settings.Endpoint".
// getUniverse returns the effective universe.
func getUniverse(settings *DialSettings) string {
if settings.UniverseDomain != "" {
return settings.UniverseDomain
}
return getDefaultUniverse(settings)
}

// getDefaultUniverse returns the specified default universe, or the implicit default
// googleapis.com.
//
// If no endpoint override is specified, we will either return the default endpoint or
// the default mTLS endpoint if a client certificate is available.
// TODO: Once code generators supply WithDefaultUniverse as part of default options,
// this utility method may be removed, though this may need to be used for resolving
// env-based universe.
func getDefaultUniverse(settings *DialSettings) string {
if settings.DefaultUniverseDomain != "" {
return settings.DefaultUniverseDomain
}
return gdUniverse
}

var (
universePatternToken = "%%UNIVERSE%%"
gdUniverse = "googleapis.com"
ErrMTLSUniverse = errors.New("mTLS is not supported in any universe other than googleapis.com")
)

// getEndpointAndUniverse returns the endpoint for the service as well as the universe
// domain, taking in to account the various overrides the user may have provided.
//
// You can override the default endpoint choice (mtls vs. regular) by setting the
// GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
// This method will also select a default endpoint based on MTLS settings, controlled by
// the GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
//
// If the endpoint override is an address (host:port) rather than full base
// URL (ex. https://...), then the user-provided address will be merged into
// the default endpoint. For example, WithEndpoint("myhost:8000") and
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
func getEndpointAndUniverse(settings *DialSettings, clientCertSource cert.Source, mtlsMode string) (string, string, error) {
// parameterize the default endpoints with the default universe.
defUniverse := getDefaultUniverse(settings)
defEndpoint := strings.Replace(settings.DefaultEndpoint, defUniverse, universePatternToken, 1)
defMTLSEndpoint := strings.Replace(settings.DefaultMTLSEndpoint, defUniverse, universePatternToken, 1)

universe := getUniverse(settings)
if settings.Endpoint == "" {
mtlsMode := getMTLSMode()
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
return settings.DefaultMTLSEndpoint, nil
if universe != gdUniverse {
return "", "", ErrMTLSUniverse
}
return mergeDefaultEndpointUniverse(defMTLSEndpoint, universe), universe, nil
shollyman marked this conversation as resolved.
Show resolved Hide resolved
}
return settings.DefaultEndpoint, nil
return mergeDefaultEndpointUniverse(defEndpoint, universe), universe, nil
}
if strings.Contains(settings.Endpoint, "://") {
// User passed in a full URL path, use it verbatim.
return settings.Endpoint, nil
// user supplied an explicit endpoint with a full URL.
return settings.Endpoint, universe, nil
}
if settings.DefaultEndpoint == "" {
// If DefaultEndpoint is not configured, use the user provided endpoint verbatim.
// This allows a naked "host[:port]" URL to be used with GRPC Direct Path.
return settings.Endpoint, nil
if defEndpoint == "" {
// The default endpoint isn't configured, so use the use provided endpoint without
// normalizing.
return settings.Endpoint, universe, nil
}
merged, err := mergeEndpoints(settings.DefaultEndpoint, settings.Endpoint)
if err != nil {
return "", "", err
}
return merged, universe, nil
}

// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
return mergeEndpoints(settings.DefaultEndpoint, settings.Endpoint)
// mergeDefaultEndpointUniverse handles replaceing a parameterized default endpoint with a universe value.
func mergeDefaultEndpointUniverse(endpoint, universe string) string {
return strings.Replace(endpoint, universePatternToken, universe, 1)
}

func getMTLSMode() string {
Expand Down
142 changes: 105 additions & 37 deletions internal/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ package internal

import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"os"
"testing"
"time"

"google.golang.org/api/internal/cert"
)

const (
Expand All @@ -20,59 +24,123 @@ const (

var dummyClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }

func TestGetEndpoint(t *testing.T) {
func TestGetEndpointAndUniverse(t *testing.T) {

fakeCertSource := func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("invalid source")
}
testCases := []struct {
UserEndpoint string
DefaultEndpoint string
Want string
WantErr bool
desc string
settings *DialSettings
clientCertSource cert.Source
mtlsMode string
wantEnd string
wantUni string
wantErr error
}{
{
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://foo.googleapis.com/bar/baz",
desc: "simple default",
settings: &DialSettings{
DefaultEndpoint: "https://foo.googleapis.com",
},
wantEnd: "https://foo.googleapis.com",
wantUni: gdUniverse,
},
{
UserEndpoint: "myhost:3999",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://myhost:3999/bar/baz",
desc: "simple endpoint override",
settings: &DialSettings{
Endpoint: "https://bar.googleapis.com",
},
wantEnd: "https://bar.googleapis.com",
wantUni: gdUniverse,
},
{
UserEndpoint: "https://host/path/to/bar",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://host/path/to/bar",
desc: "default + mtlsModeAuto + nocert",
settings: &DialSettings{
DefaultEndpoint: "https://foo.googleapis.com",
DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com",
},
mtlsMode: mTLSModeAuto,
wantEnd: "https://foo.googleapis.com",
wantUni: gdUniverse,
},
{
UserEndpoint: "host:123",
DefaultEndpoint: "",
Want: "host:123",
desc: "default + mtlsModeAuto + cert",
settings: &DialSettings{
DefaultEndpoint: "https://foo.googleapis.com",
DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com",
},
clientCertSource: fakeCertSource,
mtlsMode: mTLSModeAuto,
wantEnd: "https://foo.mtls.googleapis.com",
wantUni: gdUniverse,
},
{
desc: "default + mtlsModeAlways",
settings: &DialSettings{
DefaultEndpoint: "https://foo.googleapis.com",
DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com",
},
mtlsMode: mTLSModeAlways,
wantEnd: "https://foo.mtls.googleapis.com",
wantUni: gdUniverse,
},
{
UserEndpoint: "host:123",
DefaultEndpoint: "default:443",
Want: "host:123",
desc: "custom uni + mtlsModeAlways",
settings: &DialSettings{
UniverseDomain: "blah.com",
DefaultEndpoint: "https://foo.googleapis.com",
DefaultMTLSEndpoint: "https://foo.mtls.googleapis.com",
},
mtlsMode: mTLSModeAlways,
wantErr: ErrMTLSUniverse,
},
{
UserEndpoint: "host:123",
DefaultEndpoint: "default:443/bar/baz",
Want: "host:123/bar/baz",
desc: "partial endpoint + default",
settings: &DialSettings{
Endpoint: "myhost:3999",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
},
wantEnd: "https://myhost:3999/bar/baz",
wantUni: gdUniverse,
},
{
desc: "partial endpoint + default + custom uni",
settings: &DialSettings{
Endpoint: "myhost:3999",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
UniverseDomain: "bar.com",
},
wantEnd: "https://myhost:3999/bar/baz",
wantUni: "bar.com",
},
{
desc: "partial endpoint + no default",
settings: &DialSettings{
Endpoint: "myhost:3999",
},
wantEnd: "myhost:3999",
wantUni: gdUniverse,
},
}

for _, tc := range testCases {
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
}, nil)
if tc.WantErr && err == nil {
t.Errorf("want err, got nil err")
continue
mtlsMode := mTLSModeAuto
if tc.mtlsMode != "" {
mtlsMode = tc.mtlsMode
}
if !tc.WantErr && err != nil {
t.Errorf("want nil err, got %v", err)
gotEnd, gotUni, gotErr := getEndpointAndUniverse(tc.settings, tc.clientCertSource, mtlsMode)
if tc.wantErr != nil {
if !errors.Is(gotErr, tc.wantErr) {
t.Errorf("%q: error mismatch, got %v want %v", tc.desc, gotErr, tc.wantErr)
}
continue
}
if tc.Want != got {
t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
} else {
if gotEnd != tc.wantEnd {
t.Errorf("%q: endpoint mismatch, got %q want %q", tc.desc, gotEnd, tc.wantEnd)
}
if gotUni != tc.wantUni {
t.Errorf("%q: universe mismatch, got %q want %q", tc.desc, gotUni, tc.wantUni)
}
}
}
}
Expand Down Expand Up @@ -114,11 +182,11 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
}

for _, tc := range testCases {
got, err := getEndpoint(&DialSettings{
got, _, err := getEndpointAndUniverse(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
}, dummyClientCertSource)
}, dummyClientCertSource, getMTLSMode())
if tc.WantErr && err == nil {
t.Errorf("want err, got nil err")
continue
Expand Down
2 changes: 1 addition & 1 deletion internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g

// Determine configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
clientCertSource, oauth2Endpoint, err := getClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
clientCertSource, oauth2Endpoint, _, err := getClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
if err != nil {
return nil, err
}
Expand Down