diff --git a/auth/credentials/impersonate/impersonate.go b/auth/credentials/impersonate/impersonate.go index 91b42bc3f7f3..3af236f7d07d 100644 --- a/auth/credentials/impersonate/impersonate.go +++ b/auth/credentials/impersonate/impersonate.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "cloud.google.com/go/auth" @@ -30,11 +31,13 @@ import ( ) var ( - iamCredentialsEndpoint = "https://iamcredentials.googleapis.com" + universeDomainPlaceholder = "UNIVERSE_DOMAIN" + iamCredentialsEndpoint = "https://iamcredentials.UNIVERSE_DOMAIN" oauth2Endpoint = "https://oauth2.googleapis.com" errMissingTargetPrincipal = errors.New("impersonate: target service account must be provided") errMissingScopes = errors.New("impersonate: scopes must be provided") errLifetimeOverMax = errors.New("impersonate: max lifetime is 12 hours") + errClientAndCredentials = errors.New("impersonate: client and credentials must not both be provided") errUniverseNotSupportedDomainWideDelegation = errors.New("impersonate: service account user is configured for the credential. " + "Domain-wide delegation is not supported in universes other than googleapis.com") ) @@ -62,55 +65,49 @@ func NewCredentials(opts *CredentialsOptions) (*auth.Credentials, error) { var client *http.Client var creds *auth.Credentials - if opts.Client == nil && opts.Credentials == nil { + if opts.Client == nil { var err error - creds, err = credentials.DetectDefault(&credentials.DetectOptions{ - Scopes: []string{defaultScope}, - UseSelfSignedJWT: true, - }) - if err != nil { - return nil, err + if opts.Credentials == nil { + creds, err = credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{defaultScope}, + UseSelfSignedJWT: true, + }) + if err != nil { + return nil, err + } + } else { + creds = opts.Credentials } client, err = httptransport.NewClient(&httptransport.Options{ - Credentials: creds, + Credentials: creds, + UniverseDomain: opts.UniverseDomain, }) if err != nil { return nil, err } - } else if opts.Credentials != nil { - creds = opts.Credentials - client = internal.DefaultClient() - if err := httptransport.AddAuthorizationMiddleware(client, opts.Credentials); err != nil { - return nil, err - } } else { client = opts.Client } + universeDomainProvider := resolveUniverseDomainProvider(creds) // If a subject is specified a domain-wide delegation auth-flow is initiated // to impersonate as the provided subject (user). if opts.Subject != "" { - if !opts.isUniverseDomainGDU() { - return nil, errUniverseNotSupportedDomainWideDelegation - } - tp, err := user(opts, client, lifetime, isStaticToken) + tp, err := user(opts, client, lifetime, isStaticToken, universeDomainProvider) if err != nil { return nil, err } - var udp auth.CredentialsPropertyProvider - if creds != nil { - udp = auth.CredentialsPropertyFunc(creds.UniverseDomain) - } return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: tp, - UniverseDomainProvider: udp, + UniverseDomainProvider: universeDomainProvider, }), nil } its := impersonatedTokenProvider{ - client: client, - targetPrincipal: opts.TargetPrincipal, - lifetime: fmt.Sprintf("%.fs", lifetime.Seconds()), + client: client, + targetPrincipal: opts.TargetPrincipal, + lifetime: fmt.Sprintf("%.fs", lifetime.Seconds()), + universeDomainProvider: universeDomainProvider, } for _, v := range opts.Delegates { its.delegates = append(its.delegates, formatIAMServiceAccountName(v)) @@ -125,16 +122,23 @@ func NewCredentials(opts *CredentialsOptions) (*auth.Credentials, error) { } } - var udp auth.CredentialsPropertyProvider - if creds != nil { - udp = auth.CredentialsPropertyFunc(creds.UniverseDomain) - } return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: auth.NewCachedTokenProvider(its, tpo), - UniverseDomainProvider: udp, + UniverseDomainProvider: universeDomainProvider, }), nil } +// resolveUniverseDomainProvider returns the default service domain for a given +// Cloud universe. This is the universe domain configured for the credentials, +// which will be used in endpoint(s), and compared to the universe domain that +// is separately configured for the client. +func resolveUniverseDomainProvider(creds *auth.Credentials) auth.CredentialsPropertyProvider { + if creds != nil { + return auth.CredentialsPropertyFunc(creds.UniverseDomain) + } + return internal.StaticCredentialsProperty(internal.DefaultUniverseDomain) +} + // CredentialsOptions for generating an impersonated credential token. type CredentialsOptions struct { // TargetPrincipal is the email address of the service account to @@ -163,11 +167,13 @@ type CredentialsOptions struct { // will try to be detected from the environment. Optional. Credentials *auth.Credentials // Client configures the underlying client used to make network requests - // when fetching tokens. If provided the client should provide it's own + // when fetching tokens. If provided the client should provide its own // credentials at call time. Optional. Client *http.Client // UniverseDomain is the default service domain for a given Cloud universe. - // The default value is "googleapis.com". Optional. + // The default value is "googleapis.com". This is the universe domain + // configured for the client, which will be compared to the universe domain + // that is separately configured for the credentials. Optional. UniverseDomain string } @@ -184,22 +190,10 @@ func (o *CredentialsOptions) validate() error { if o.Lifetime.Hours() > 12 { return errLifetimeOverMax } - return nil -} - -// getUniverseDomain is the default service domain for a given Cloud universe. -// The default value is "googleapis.com". -func (o *CredentialsOptions) getUniverseDomain() string { - if o.UniverseDomain == "" { - return internal.DefaultUniverseDomain + if o.Client != nil && o.Credentials != nil { + return errClientAndCredentials } - return o.UniverseDomain -} - -// isUniverseDomainGDU returns true if the universe domain is the default Google -// universe. -func (o *CredentialsOptions) isUniverseDomainGDU() bool { - return o.getUniverseDomain() == internal.DefaultUniverseDomain + return nil } func formatIAMServiceAccountName(name string) string { @@ -218,7 +212,8 @@ type generateAccessTokenResponse struct { } type impersonatedTokenProvider struct { - client *http.Client + client *http.Client + universeDomainProvider auth.CredentialsPropertyProvider targetPrincipal string lifetime string @@ -237,7 +232,12 @@ func (i impersonatedTokenProvider) Token(ctx context.Context) (*auth.Token, erro if err != nil { return nil, fmt.Errorf("impersonate: unable to marshal request: %w", err) } - url := fmt.Sprintf("%s/v1/%s:generateAccessToken", iamCredentialsEndpoint, formatIAMServiceAccountName(i.targetPrincipal)) + universeDomain, err := i.universeDomainProvider.GetProperty(ctx) + if err != nil { + return nil, err + } + endpoint := strings.Replace(iamCredentialsEndpoint, universeDomainPlaceholder, universeDomain, 1) + url := fmt.Sprintf("%s/v1/%s:generateAccessToken", endpoint, formatIAMServiceAccountName(i.targetPrincipal)) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b)) if err != nil { return nil, fmt.Errorf("impersonate: unable to create request: %w", err) diff --git a/auth/credentials/impersonate/impersonate_test.go b/auth/credentials/impersonate/impersonate_test.go index cd468a43a1f2..777b3a1e02c6 100644 --- a/auth/credentials/impersonate/impersonate_test.go +++ b/auth/credentials/impersonate/impersonate_test.go @@ -24,15 +24,18 @@ import ( "testing" "time" + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/internal" "github.com/google/go-cmp/cmp" ) func TestNewCredentials_serviceAccount(t *testing.T) { ctx := context.Background() tests := []struct { - name string - config CredentialsOptions - wantErr error + name string + config CredentialsOptions + wantErr error + wantUniverseDomain string }{ { name: "missing targetPrincipal", @@ -54,23 +57,55 @@ func TestNewCredentials_serviceAccount(t *testing.T) { }, wantErr: errLifetimeOverMax, }, + { + name: "credentials and client", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Client: &http.Client{}, + Credentials: staticCredentials("googleapis.com"), + }, + wantErr: errClientAndCredentials, + }, { name: "works", config: CredentialsOptions{ TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", Scopes: []string{"scope"}, }, - wantErr: nil, + wantErr: nil, + wantUniverseDomain: "googleapis.com", }, { - name: "universe domain", + name: "universe domain from options", config: CredentialsOptions{ TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", Scopes: []string{"scope"}, - Subject: "admin@example.com", UniverseDomain: "example.com", }, - wantErr: errUniverseNotSupportedDomainWideDelegation, + wantErr: nil, + wantUniverseDomain: "googleapis.com", // From creds, not CredentialsOptions.UniverseDomain + }, + { + name: "universe domain from options and credentials", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + UniverseDomain: "NOT.example.com", + Credentials: staticCredentials("example.com"), + }, + wantErr: nil, + wantUniverseDomain: "example.com", // From creds, not CredentialsOptions.UniverseDomain + }, + { + name: "universe domain from credentials", + config: CredentialsOptions{ + TargetPrincipal: "foo@project-id.iam.gserviceaccount.com", + Scopes: []string{"scope"}, + Credentials: staticCredentials("example.com"), + }, + wantErr: nil, + wantUniverseDomain: "example.com", }, } @@ -80,53 +115,66 @@ func TestNewCredentials_serviceAccount(t *testing.T) { saTok := "sa-token" client := &http.Client{ Transport: RoundTripFn(func(req *http.Request) *http.Response { - if strings.Contains(req.URL.Path, "generateAccessToken") { - defer req.Body.Close() - b, err := io.ReadAll(req.Body) - if err != nil { - t.Error(err) - } - var r generateAccessTokenRequest - if err := json.Unmarshal(b, &r); err != nil { - t.Error(err) - } - if !cmp.Equal(r.Scope, tt.config.Scopes) { - t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes) - } - if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) { - t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal) - } + if !strings.Contains(req.URL.Path, "generateAccessToken") { + t.Fatal("path must contain 'generateAccessToken'") + } + defer req.Body.Close() + b, err := io.ReadAll(req.Body) + if err != nil { + t.Error(err) + } + var r generateAccessTokenRequest + if err := json.Unmarshal(b, &r); err != nil { + t.Error(err) + } + if !cmp.Equal(r.Scope, tt.config.Scopes) { + t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes) + } + if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) { + t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal) + } + if !strings.Contains(req.URL.Hostname(), tt.wantUniverseDomain) { + t.Errorf("got %q, want %q", req.URL.Hostname(), tt.wantUniverseDomain) + } - resp := generateAccessTokenResponse{ - AccessToken: saTok, - ExpireTime: time.Now().Format(time.RFC3339), - } - b, err = json.Marshal(&resp) - if err != nil { - t.Fatalf("unable to marshal response: %v", err) - } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(b)), - Header: http.Header{}, - } + resp := generateAccessTokenResponse{ + AccessToken: saTok, + ExpireTime: time.Now().Format(time.RFC3339), + } + b, err = json.Marshal(&resp) + if err != nil { + t.Fatalf("unable to marshal response: %v", err) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(b)), + Header: http.Header{}, } - return nil }), } - tt.config.Client = client - ts, err := NewCredentials(&tt.config) + if tt.config.Credentials == nil { + tt.config.Client = client + } + creds, err := NewCredentials(&tt.config) if err != nil { if err != tt.wantErr { t.Fatalf("err: %v", err) } + } else if tt.config.Credentials != nil { + // config.Credentials is invalid for Token request, just assert universe domain. + if got, _ := creds.UniverseDomain(ctx); got != tt.wantUniverseDomain { + t.Errorf("got %q, want %q", got, tt.wantUniverseDomain) + } } else { - tok, err := ts.Token(ctx) + tok, err := creds.Token(ctx) if err != nil { - t.Fatal(err) + t.Error(err) } if tok.Value != saTok { - t.Fatalf("got %q, want %q", tok.Value, saTok) + t.Errorf("got %q, want %q", tok.Value, saTok) + } + if got, _ := creds.UniverseDomain(ctx); got != tt.wantUniverseDomain { + t.Errorf("got %q, want %q", got, tt.wantUniverseDomain) } } }) @@ -137,44 +185,15 @@ type RoundTripFn func(req *http.Request) *http.Response func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil } -func TestCredentialsOptions_UniverseDomain(t *testing.T) { - testCases := []struct { - name string - opts *CredentialsOptions - wantUniverseDomain string - wantIsGDU bool - }{ - { - name: "empty", - opts: &CredentialsOptions{}, - wantUniverseDomain: "googleapis.com", - wantIsGDU: true, - }, - { - name: "defaults", - opts: &CredentialsOptions{ - UniverseDomain: "googleapis.com", - }, - wantUniverseDomain: "googleapis.com", - wantIsGDU: true, - }, - { - name: "non-GDU", - opts: &CredentialsOptions{ - UniverseDomain: "example.com", - }, - wantUniverseDomain: "example.com", - wantIsGDU: false, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain { - t.Errorf("got %v, want %v", got, tc.wantUniverseDomain) - } - if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU { - t.Errorf("got %v, want %v", got, tc.wantIsGDU) - } - }) - } +func staticCredentials(universeDomain string) *auth.Credentials { + return auth.NewCredentials(&auth.CredentialsOptions{ + TokenProvider: staticTokenProvider("base credentials Token should never be called"), + UniverseDomainProvider: internal.StaticCredentialsProperty(universeDomain), + }) +} + +type staticTokenProvider string + +func (s staticTokenProvider) Token(context.Context) (*auth.Token, error) { + return &auth.Token{Value: string(s)}, nil } diff --git a/auth/credentials/impersonate/integration_test.go b/auth/credentials/impersonate/integration_test.go index fd12b4fffc75..2637fc24fae6 100644 --- a/auth/credentials/impersonate/integration_test.go +++ b/auth/credentials/impersonate/integration_test.go @@ -58,15 +58,28 @@ func TestMain(m *testing.M) { readerEmail = os.Getenv(envReaderEmail) writerEmail = os.Getenv(envWriterEmail) - if !testing.Short() && (baseKeyFile == "" || - readerKeyFile == "" || - readerEmail == "" || - writerEmail == "" || - projectID == "") { - log.Println("required environment variable not set, skipping") - os.Exit(0) + if !testing.Short() { + missing := []string{} + if baseKeyFile == "" { + missing = append(missing, credsfile.GoogleAppCredsEnvVar) + } + if projectID == "" { + missing = append(missing, envProjectID) + } + if readerKeyFile == "" { + missing = append(missing, envReaderCreds) + } + if readerEmail == "" { + missing = append(missing, envReaderEmail) + } + if writerEmail == "" { + missing = append(missing, envWriterEmail) + } + if len(missing) > 0 { + log.Printf("skipping, required environment variable(s) not set: %s\n", missing) + os.Exit(0) + } } - os.Exit(m.Run()) } diff --git a/auth/credentials/impersonate/user.go b/auth/credentials/impersonate/user.go index 1acaaa922d9d..b5e5fc8f6645 100644 --- a/auth/credentials/impersonate/user.go +++ b/auth/credentials/impersonate/user.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -30,12 +31,16 @@ import ( // user provides an auth flow for domain-wide delegation, setting // CredentialsConfig.Subject to be the impersonated user. -func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool) (auth.TokenProvider, error) { +func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool, universeDomainProvider auth.CredentialsPropertyProvider) (auth.TokenProvider, error) { + if opts.Subject == "" { + return nil, errors.New("CredentialsConfig.Subject must not be empty") + } u := userTokenProvider{ - client: client, - targetPrincipal: opts.TargetPrincipal, - subject: opts.Subject, - lifetime: lifetime, + client: client, + targetPrincipal: opts.TargetPrincipal, + subject: opts.Subject, + lifetime: lifetime, + universeDomainProvider: universeDomainProvider, } u.delegates = make([]string, len(opts.Delegates)) for i, v := range opts.Delegates { @@ -84,14 +89,25 @@ type exchangeTokenResponse struct { type userTokenProvider struct { client *http.Client - targetPrincipal string - subject string - scopes []string - lifetime time.Duration - delegates []string + targetPrincipal string + subject string + scopes []string + lifetime time.Duration + delegates []string + universeDomainProvider auth.CredentialsPropertyProvider } func (u userTokenProvider) Token(ctx context.Context) (*auth.Token, error) { + // Because a subject is specified a domain-wide delegation auth-flow is initiated + // to impersonate as the provided subject (user). + // Return error if users try to use domain-wide delegation in a non-GDU universe. + ud, err := u.universeDomainProvider.GetProperty(ctx) + if err != nil { + return nil, err + } + if ud != internal.DefaultUniverseDomain { + return nil, errUniverseNotSupportedDomainWideDelegation + } signedJWT, err := u.signJWT(ctx) if err != nil { return nil, err diff --git a/auth/credentials/impersonate/user_test.go b/auth/credentials/impersonate/user_test.go index adb4612d5eca..87b897a1c248 100644 --- a/auth/credentials/impersonate/user_test.go +++ b/auth/credentials/impersonate/user_test.go @@ -37,6 +37,7 @@ func TestNewCredentials_user(t *testing.T) { lifetime time.Duration subject string wantErr bool + wantTokenErr bool universeDomain string }{ { @@ -60,14 +61,13 @@ func TestNewCredentials_user(t *testing.T) { targetPrincipal: "foo@project-id.iam.gserviceaccount.com", scopes: []string{"scope"}, subject: "admin@example.com", - wantErr: false, }, { name: "universeDomain", targetPrincipal: "foo@project-id.iam.gserviceaccount.com", scopes: []string{"scope"}, subject: "admin@example.com", - wantErr: true, + wantTokenErr: true, // Non-GDU Universe Domain should result in error if // CredentialsConfig.Subject is present for domain-wide delegation. universeDomain: "example.com", @@ -152,6 +152,9 @@ func TestNewCredentials_user(t *testing.T) { t.Fatal(err) } tok, err := ts.Token(ctx) + if tt.wantTokenErr && err != nil { + return + } if err != nil { t.Fatal(err) } diff --git a/auth/httptransport/httptransport.go b/auth/httptransport/httptransport.go index 30fedf9562f9..38e8c99399bb 100644 --- a/auth/httptransport/httptransport.go +++ b/auth/httptransport/httptransport.go @@ -155,6 +155,8 @@ type InternalOptions struct { // transport that sets the Authorization header with the value produced by the // provided [cloud.google.com/go/auth.Credentials]. An error is returned only // if client or creds is nil. +// +// This function does not support setting a universe domain value on the client. func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) error { if client == nil || creds == nil { return fmt.Errorf("httptransport: client and tp must not be nil") @@ -173,7 +175,6 @@ func AddAuthorizationMiddleware(client *http.Client, creds *auth.Credentials) er client.Transport = &authTransport{ creds: creds, base: base, - // TODO(quartzmo): Somehow set clientUniverseDomain from impersonate calls. } return nil } diff --git a/spanner/client.go b/spanner/client.go index 95bb7f7a606d..889fb43aa6da 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -25,7 +25,6 @@ import ( "regexp" "strconv" "strings" - "sync" "time" "cloud.google.com/go/internal/trace" @@ -651,29 +650,19 @@ func metricsInterceptor() grpc.UnaryClientInterceptor { // wrappedStream wraps around the embedded grpc.ClientStream, and intercepts the RecvMsg and // SendMsg method call. type wrappedStream struct { - sync.Mutex - isFirstRecv bool - method string - target string + method string + target string grpc.ClientStream } func (w *wrappedStream) RecvMsg(m any) error { - attempt := &attemptTracer{} - attempt.setStartTime(time.Now()) err := w.ClientStream.RecvMsg(m) - statusCode, _ := status.FromError(err) ctx := w.ClientStream.Context() mt, ok := ctx.Value(metricsTracerKey).(*builtinMetricsTracer) - if !ok || !w.isFirstRecv { + if !ok { return err } - w.Lock() - w.isFirstRecv = false - w.Unlock() mt.method = w.method - mt.currOp.incrementAttemptCount() - mt.currOp.currAttempt = attempt if strings.HasPrefix(w.target, "google-c2p") { mt.currOp.setDirectPathEnabled(true) } @@ -687,9 +676,9 @@ func (w *wrappedStream) RecvMsg(m any) error { } } } - mt.currOp.currAttempt.setStatus(statusCode.Code().String()) - mt.currOp.currAttempt.setDirectPathUsed(isDirectPathUsed) - recordAttemptCompletion(mt) + if mt.currOp.currAttempt != nil { + mt.currOp.currAttempt.setDirectPathUsed(isDirectPathUsed) + } return err } @@ -698,7 +687,7 @@ func (w *wrappedStream) SendMsg(m any) error { } func newWrappedStream(s grpc.ClientStream, method, target string) grpc.ClientStream { - return &wrappedStream{ClientStream: s, method: method, target: target, isFirstRecv: true} + return &wrappedStream{ClientStream: s, method: method, target: target} } // metricsInterceptor is a gRPC stream client interceptor that records metrics for stream RPCs. diff --git a/spanner/metrics_test.go b/spanner/metrics_test.go index e05cf8cba5be..605fc2501d4f 100644 --- a/spanner/metrics_test.go +++ b/spanner/metrics_test.go @@ -193,21 +193,26 @@ func TestNewBuiltinMetricsTracerFactory(t *testing.T) { // Get new CreateServiceTimeSeriesRequests gotCreateTSCalls := monitoringServer.CreateServiceTimeSeriesRequests() var gotExpectedMethods []string - gotOTELValues := make(map[string]map[string]int64) + gotOTELCountValues := make(map[string]map[string]int64) + gotOTELLatencyValues := make(map[string]map[string]float64) for _, gotCreateTSCall := range gotCreateTSCalls { gotMetricTypesPerMethod := make(map[string][]string) for _, ts := range gotCreateTSCall.TimeSeries { gotMetricTypesPerMethod[ts.Metric.GetLabels()["method"]] = append(gotMetricTypesPerMethod[ts.Metric.GetLabels()["method"]], ts.Metric.Type) - if _, ok := gotOTELValues[ts.Metric.GetLabels()["method"]]; !ok { - gotOTELValues[ts.Metric.GetLabels()["method"]] = make(map[string]int64) + if _, ok := gotOTELCountValues[ts.Metric.GetLabels()["method"]]; !ok { + gotOTELCountValues[ts.Metric.GetLabels()["method"]] = make(map[string]int64) + gotOTELLatencyValues[ts.Metric.GetLabels()["method"]] = make(map[string]float64) gotExpectedMethods = append(gotExpectedMethods, ts.Metric.GetLabels()["method"]) } if ts.MetricKind == metric.MetricDescriptor_CUMULATIVE && ts.GetValueType() == metric.MetricDescriptor_INT64 { - gotOTELValues[ts.Metric.GetLabels()["method"]][ts.Metric.Type] = ts.Points[0].Value.GetInt64Value() + gotOTELCountValues[ts.Metric.GetLabels()["method"]][ts.Metric.Type] = ts.Points[0].Value.GetInt64Value() } else { for _, p := range ts.Points { - if p.Value.GetInt64Value() > int64(elapsedTime) { - t.Errorf("Value %v is greater than elapsed time %v", p.Value.GetInt64Value(), elapsedTime) + if _, ok := gotOTELCountValues[ts.Metric.GetLabels()["method"]][ts.Metric.Type]; !ok { + gotOTELLatencyValues[ts.Metric.GetLabels()["method"]][ts.Metric.Type] = p.Value.GetDistributionValue().Mean + } else { + // sum up all attempt latencies + gotOTELLatencyValues[ts.Metric.GetLabels()["method"]][ts.Metric.Type] += p.Value.GetDistributionValue().Mean } } } @@ -216,7 +221,7 @@ func TestNewBuiltinMetricsTracerFactory(t *testing.T) { sort.Strings(gotMetricTypes) sort.Strings(test.wantOTELMetrics[method]) if !testutil.Equal(gotMetricTypes, test.wantOTELMetrics[method]) { - t.Errorf("Metric types missing in req. %s got: %v, want: %v", method, gotMetricTypes, wantMetricTypesGCM) + t.Errorf("Metric types missing in req. %s got: %v, want: %v", method, gotMetricTypes, test.wantOTELMetrics[method]) } } } @@ -226,10 +231,22 @@ func TestNewBuiltinMetricsTracerFactory(t *testing.T) { } for method, wantOTELValues := range test.wantOTELValue { for metricName, wantValue := range wantOTELValues { - if gotOTELValues[method][metricName] != wantValue { - t.Errorf("OTEL value for %s, %s: got: %v, want: %v", method, metricName, gotOTELValues[method][metricName], wantValue) + if gotOTELCountValues[method][metricName] != wantValue { + t.Errorf("OTEL value for %s, %s: got: %v, want: %v", method, metricName, gotOTELCountValues[method][metricName], wantValue) } } + // For StreamingRead, verify operation latency includes all attempt latencies + opLatency := gotOTELLatencyValues[method][nativeMetricsPrefix+metricNameOperationLatencies] + attemptLatency := gotOTELLatencyValues[method][nativeMetricsPrefix+metricNameAttemptLatencies] + // expect opLatency and attemptLatency to be non-zero + if opLatency == 0 || attemptLatency == 0 { + t.Errorf("Operation and attempt latencies should be non-zero for %s: operation_latency=%v, attempt_latency=%v", + method, opLatency, attemptLatency) + } + if opLatency <= attemptLatency { + t.Errorf("Operation latency should be greater than attempt latency for %s: operation_latency=%v, attempt_latency=%v", + method, opLatency, attemptLatency) + } } gotCreateTSCallsCount := len(gotCreateTSCalls) if gotCreateTSCallsCount < test.wantCreateTSCallsCount { diff --git a/spanner/read.go b/spanner/read.go index 34af289004dc..2752b0ec93e0 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -30,6 +30,7 @@ import ( "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" proto3 "google.golang.org/protobuf/types/known/structpb" ) @@ -84,12 +85,13 @@ func streamWithReplaceSessionFunc( ctx, cancel := context.WithCancel(ctx) ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator") return &RowIterator{ - streamd: newResumableStreamDecoder(ctx, logger, meterTracerFactory, rpc, replaceSession), - rowd: &partialResultSetDecoder{}, - setTransactionID: setTransactionID, - setTimestamp: setTimestamp, - release: release, - cancel: cancel, + meterTracerFactory: meterTracerFactory, + streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession), + rowd: &partialResultSetDecoder{}, + setTransactionID: setTransactionID, + setTimestamp: setTimestamp, + release: release, + cancel: cancel, } } @@ -120,15 +122,17 @@ type RowIterator struct { // RowIterator.Next() returned an error that is not equal to iterator.Done. Metadata *sppb.ResultSetMetadata - streamd *resumableStreamDecoder - rowd *partialResultSetDecoder - setTransactionID func(transactionID) - setTimestamp func(time.Time) - release func(error) - cancel func() - err error - rows []*Row - sawStats bool + ctx context.Context + meterTracerFactory *builtinMetricsTracerFactory + streamd *resumableStreamDecoder + rowd *partialResultSetDecoder + setTransactionID func(transactionID) + setTimestamp func(time.Time) + release func(error) + cancel func() + err error + rows []*Row + sawStats bool } // this is for safety from future changes to RowIterator making sure that it implements rowIterator interface. @@ -138,10 +142,31 @@ var _ rowIterator = (*RowIterator)(nil) // there are no more results. Once Next returns Done, all subsequent calls // will return Done. func (r *RowIterator) Next() (*Row, error) { + mt := r.meterTracerFactory.createBuiltinMetricsTracer(r.ctx) if r.err != nil { return nil, r.err } - for len(r.rows) == 0 && r.streamd.next() { + // Start new attempt + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } + defer func() { + // when mt method is not empty, it means the RPC was sent to backend and native metrics attributes were captured in interceptor + if mt.method != "" { + statusCode, _ := convertToGrpcStatusErr(r.err) + // record the attempt completion + mt.currOp.currAttempt.setStatus(statusCode.String()) + recordAttemptCompletion(&mt) + mt.currOp.setStatus(statusCode.String()) + // Record operation completion. + // Operational_latencies metric captures the full picture of all attempts including retries. + recordOperationCompletion(&mt) + mt.currOp.currAttempt = nil + } + }() + + for len(r.rows) == 0 && r.streamd.next(&mt) { prs := r.streamd.get() if r.setTransactionID != nil { // this is when Read/Query is executed using ReadWriteTransaction @@ -406,20 +431,17 @@ type resumableStreamDecoder struct { // backoff is used for the retry settings backoff gax.Backoff - - meterTracerFactory *builtinMetricsTracerFactory } // newResumableStreamDecoder creates a new resumeableStreamDecoder instance. // Parameter rpc should be a function that creates a new stream beginning at the // restartToken if non-nil. -func newResumableStreamDecoder(ctx context.Context, logger *log.Logger, meterTracerFactory *builtinMetricsTracerFactory, rpc func(ct context.Context, restartToken []byte) (streamingReceiver, error), replaceSession func(ctx context.Context) error) *resumableStreamDecoder { +func newResumableStreamDecoder(ctx context.Context, logger *log.Logger, rpc func(ct context.Context, restartToken []byte) (streamingReceiver, error), replaceSession func(ctx context.Context) error) *resumableStreamDecoder { return &resumableStreamDecoder{ ctx: ctx, logger: logger, rpc: rpc, replaceSessionFunc: replaceSession, - meterTracerFactory: meterTracerFactory, maxBytesBetweenResumeTokens: atomic.LoadInt32(&maxBytesBetweenResumeTokens), backoff: DefaultRetryBackoff, } @@ -503,25 +525,18 @@ var ( maxBytesBetweenResumeTokens = int32(128 * 1024 * 1024) ) -func (d *resumableStreamDecoder) next() bool { - mt := d.meterTracerFactory.createBuiltinMetricsTracer(d.ctx) - defer func() { - if mt.method != "" { - statusCode, _ := convertToGrpcStatusErr(d.lastErr()) - mt.currOp.setStatus(statusCode.String()) - recordOperationCompletion(&mt) - } - }() +func (d *resumableStreamDecoder) next(mt *builtinMetricsTracer) bool { retryer := onCodes(d.backoff, codes.Unavailable, codes.ResourceExhausted, codes.Internal) for { switch d.state { case unConnected: // If no gRPC stream is available, try to initiate one. - d.stream, d.err = d.rpc(context.WithValue(d.ctx, metricsTracerKey, &mt), d.resumeToken) + d.stream, d.err = d.rpc(context.WithValue(d.ctx, metricsTracerKey, mt), d.resumeToken) if d.err == nil { d.changeState(queueingRetryable) continue } + delay, shouldRetry := retryer.Retry(d.err) if !shouldRetry { d.changeState(aborted) @@ -529,6 +544,13 @@ func (d *resumableStreamDecoder) next() bool { } trace.TracePrintf(d.ctx, nil, "Backing off stream read for %s", delay) if err := gax.Sleep(d.ctx, delay); err == nil { + // record the attempt completion + mt.currOp.currAttempt.setStatus(status.Code(d.err).String()) + recordAttemptCompletion(mt) + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } // Be explicit about state transition, although the // state doesn't actually change. State transition // will be triggered only by RPC activity, regardless of @@ -549,7 +571,7 @@ func (d *resumableStreamDecoder) next() bool { // Only the case that receiving queue is empty could cause // peekLast to return error and in such case, we should try to // receive from stream. - d.tryRecv(retryer) + d.tryRecv(mt, retryer) continue } if d.isNewResumeToken(last.ResumeToken) { @@ -578,7 +600,7 @@ func (d *resumableStreamDecoder) next() bool { } // Needs to receive more from gRPC stream till a new resume token // is observed. - d.tryRecv(retryer) + d.tryRecv(mt, retryer) continue case aborted: // Discard all pending items because none of them should be yield @@ -604,7 +626,7 @@ func (d *resumableStreamDecoder) next() bool { } // tryRecv attempts to receive a PartialResultSet from gRPC stream. -func (d *resumableStreamDecoder) tryRecv(retryer gax.Retryer) { +func (d *resumableStreamDecoder) tryRecv(mt *builtinMetricsTracer, retryer gax.Retryer) { var res *sppb.PartialResultSet res, d.err = d.stream.Recv() if d.err == nil { @@ -615,12 +637,16 @@ func (d *resumableStreamDecoder) tryRecv(retryer gax.Retryer) { d.changeState(d.state) return } + if d.err == io.EOF { d.err = nil d.changeState(finished) return } + if d.replaceSessionFunc != nil && isSessionNotFoundError(d.err) && d.resumeToken == nil { + mt.currOp.currAttempt.setStatus(status.Code(d.err).String()) + recordAttemptCompletion(mt) // A 'Session not found' error occurred before we received a resume // token and a replaceSessionFunc function is defined. Try to restart // the stream on a new session. @@ -629,7 +655,13 @@ func (d *resumableStreamDecoder) tryRecv(retryer gax.Retryer) { d.changeState(aborted) return } + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } } else { + mt.currOp.currAttempt.setStatus(status.Code(d.err).String()) + recordAttemptCompletion(mt) delay, shouldRetry := retryer.Retry(d.err) if !shouldRetry || d.state != queueingRetryable { d.changeState(aborted) @@ -640,6 +672,10 @@ func (d *resumableStreamDecoder) tryRecv(retryer gax.Retryer) { d.changeState(aborted) return } + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } } // Clear error and retry the stream. d.err = nil diff --git a/spanner/read_test.go b/spanner/read_test.go index e15ced784870..99ab2362a1ee 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -28,6 +28,7 @@ import ( sppb "cloud.google.com/go/spanner/apiv1/spannerpb" . "cloud.google.com/go/spanner/internal/testutil" "github.com/googleapis/gax-go/v2" + "go.opentelemetry.io/otel/metric/noop" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -805,10 +806,10 @@ func TestRsdNonblockingStates(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + mt := c.metricsTracerFactory.createBuiltinMetricsTracer(ctx) r := newResumableStreamDecoder( ctx, nil, - c.metricsTracerFactory, test.rpc, nil, ) @@ -882,8 +883,12 @@ func TestRsdNonblockingStates(t *testing.T) { } return } + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } // Receive next decoded item. - if r.next() { + if r.next(&mt) { rs = append(rs, r.get()) } } @@ -1099,10 +1104,10 @@ func TestRsdBlockingStates(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) defer cancel() + mt := c.metricsTracerFactory.createBuiltinMetricsTracer(ctx) r := newResumableStreamDecoder( ctx, nil, - c.metricsTracerFactory, test.rpc, nil, ) @@ -1146,8 +1151,12 @@ func TestRsdBlockingStates(t *testing.T) { var rs []*sppb.PartialResultSet rowsFetched := make(chan int) go func() { + mt.currOp.incrementAttemptCount() + mt.currOp.currAttempt = &attemptTracer{ + startTime: time.Now(), + } for { - if !r.next() { + if !r.next(&mt) { // Note that r.Next also exits on context cancel/timeout. close(rowsFetched) return @@ -1261,10 +1270,10 @@ func TestQueueBytes(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + mt := c.metricsTracerFactory.createBuiltinMetricsTracer(ctx) decoder := newResumableStreamDecoder( ctx, nil, - c.metricsTracerFactory, func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, @@ -1286,24 +1295,24 @@ func TestQueueBytes(t *testing.T) { ResumeToken: rt1, }) - decoder.next() - decoder.next() - decoder.next() + decoder.next(&mt) + decoder.next(&mt) + decoder.next(&mt) if got, want := decoder.bytesBetweenResumeTokens, int32(2*sizeOfPRS); got != want { t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want) } - decoder.next() + decoder.next(&mt) if decoder.bytesBetweenResumeTokens != 0 { t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens) } - decoder.next() + decoder.next(&mt) if got, want := decoder.bytesBetweenResumeTokens, int32(sizeOfPRS); got != want { t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", got, want) } - decoder.next() + decoder.next(&mt) if decoder.bytesBetweenResumeTokens != 0 { t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", decoder.bytesBetweenResumeTokens) } @@ -1769,8 +1778,12 @@ func TestIteratorStopEarly(t *testing.T) { } func TestIteratorWithError(t *testing.T) { + metricsTracerFactory, err := newBuiltinMetricsTracerFactory(context.Background(), "projects/my-project/instances/my-instance/databases/my-database", noop.NewMeterProvider()) + if err != nil { + t.Fatalf("failed to create metrics tracer factory: %v", err) + } injected := errors.New("Failed iterator") - iter := RowIterator{err: injected} + iter := RowIterator{meterTracerFactory: metricsTracerFactory, err: injected} defer iter.Stop() if _, err := iter.Next(); err != injected { t.Fatalf("Expected error: %v, got %v", injected, err) diff --git a/spanner/transaction.go b/spanner/transaction.go index 86fedd7a4d8b..9e3e107d1065 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -242,16 +242,22 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key ) kset, err := keys.keySetProto() if err != nil { - return &RowIterator{err: err} + return &RowIterator{ + meterTracerFactory: t.sp.sc.metricsTracerFactory, + err: err} } if sh, ts, err = t.acquire(ctx); err != nil { - return &RowIterator{err: err} + return &RowIterator{ + meterTracerFactory: t.sp.sc.metricsTracerFactory, + err: err} } // Cloud Spanner will return "Session not found" on bad sessions. client := sh.getClient() if client == nil { // Might happen if transaction is closed in the middle of a API call. - return &RowIterator{err: errSessionClosed(sh)} + return &RowIterator{ + meterTracerFactory: t.sp.sc.metricsTracerFactory, + err: errSessionClosed(sh)} } index := t.ro.Index limit := t.ro.Limit @@ -573,7 +579,10 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que defer func() { trace.EndSpan(ctx, ri.err) }() req, sh, err := t.prepareExecuteSQL(ctx, statement, options) if err != nil { - return &RowIterator{err: err} + return &RowIterator{ + meterTracerFactory: t.sp.sc.metricsTracerFactory, + err: err, + } } var setTransactionID func(transactionID) if _, ok := req.Transaction.GetSelector().(*sppb.TransactionSelector_Begin); ok {