Skip to content

Commit

Permalink
grpcutil: new+fast FromIncomingContext variant
Browse files Browse the repository at this point in the history
If we want to properly leverage gRPC metadata, we need to find a way
around the fact that the default `metatada.FromIncomingContext` is
slow and expensive. This patch fixes it.

```
BenchmarkFromIncomingContext/stdlib-32          10987958               327.5 ns/op           432 B/op          3 allocs/op
BenchmarkFromIncomingContext/fast-32            698772889                5.152 ns/op           0 B/op          0 allocs/op
```

Release note: None
  • Loading branch information
knz committed Feb 2, 2023
1 parent 3f73cfb commit 4384a92
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 11 deletions.
1 change: 1 addition & 0 deletions pkg/blobs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ go_library(
"//pkg/rpc",
"//pkg/rpc/nodedialer",
"//pkg/util/fileutil",
"//pkg/util/grpcutil",
"//pkg/util/ioctx",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_errors//oserror",
Expand Down
13 changes: 6 additions & 7 deletions pkg/blobs/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ import (
"io"

"github.com/cockroachdb/cockroach/pkg/blobs/blobspb"
"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
"github.com/cockroachdb/cockroach/pkg/util/ioctx"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/errors/oserror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -63,20 +63,19 @@ func (s *Service) GetStream(req *blobspb.GetRequest, stream blobspb.Blob_GetStre

// PutStream implements the gRPC service.
func (s *Service) PutStream(stream blobspb.Blob_PutStreamServer) error {
md, ok := metadata.FromIncomingContext(stream.Context())
filename, ok := grpcutil.FastFirstValueFromIncomingContext(stream.Context(), "filename")
if !ok {
return errors.New("could not fetch metadata")
return errors.New("could not fetch metadata or no filename in metadata")
}
filename := md.Get("filename")
if len(filename) < 1 || filename[0] == "" {
return errors.New("no filename in metadata")
if filename == "" {
return errors.New("invalid filename in metadata")
}
reader := newPutStreamReader(stream)
defer reader.Close(stream.Context())
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()

w, err := s.localStorage.Writer(ctx, filename[0])
w, err := s.localStorage.Writer(ctx, filename)
if err != nil {
cancel()
return err
Expand Down
3 changes: 2 additions & 1 deletion pkg/server/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/ui"
"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
Expand Down Expand Up @@ -307,7 +308,7 @@ func (s *authenticationServer) createSessionFor(
func (s *authenticationServer) UserLogout(
ctx context.Context, req *serverpb.UserLogoutRequest,
) (*serverpb.UserLogoutResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
md, ok := grpcutil.FastFromIncomingContext(ctx)
if !ok {
return nil, apiInternalError(ctx, fmt.Errorf("couldn't get incoming context"))
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ type metricMarshaler interface {
}

func propagateGatewayMetadata(ctx context.Context) context.Context {
if md, ok := metadata.FromIncomingContext(ctx); ok {
if md, ok := grpcutil.FastFromIncomingContext(ctx); ok {
return metadata.NewOutgoingContext(ctx, md)
}
return ctx
Expand Down Expand Up @@ -3549,7 +3549,7 @@ func marshalJSONResponse(value interface{}) (*serverpb.JSONResponse, error) {
}

func userFromContext(ctx context.Context) (res username.SQLUsername, err error) {
md, ok := metadata.FromIncomingContext(ctx)
md, ok := grpcutil.FastFromIncomingContext(ctx)
if !ok {
return username.RootUserName(), nil
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/util/grpcutil/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "grpcutil",
srcs = [
"fast_metadata.go",
"grpc_err_redaction.go",
"grpc_log.go",
"grpc_util.go",
Expand All @@ -24,6 +25,7 @@ go_library(
"@com_github_gogo_status//:status",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//grpclog",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
],
)
Expand All @@ -32,6 +34,7 @@ go_test(
name = "grpcutil_test",
size = "small",
srcs = [
"fast_metadata_test.go",
"grpc_err_redaction_test.go",
"grpc_log_test.go",
"grpc_util_test.go",
Expand All @@ -53,6 +56,7 @@ go_test(
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//health/grpc_health_v1",
"@org_golang_google_grpc//metadata",
],
)

Expand Down
104 changes: 104 additions & 0 deletions pkg/util/grpcutil/fast_metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package grpcutil

import (
"context"
"strings"
"time"

"google.golang.org/grpc/metadata"
)

// FastFromIncomingContext is a specialization of
// metadata.FromIncomingContext() which extracts the metadata.MD from
// the context, if any, by reference. Main differences:
//
// - This variant does not guarantee that all the MD keys are
// lowercase. This happens to be true when the MD is populated by
// gRPC itself on an incoming RPC call, but it may not be true for
// MD populated elsewhere.
// - The caller promises to not modify the returned MD -- the gRPC
// APIs assume that the map in the context remains constant.
func FastFromIncomingContext(ctx context.Context) (metadata.MD, bool) {
md, ok := ctx.Value(grpcIncomingKeyObj).(metadata.MD)
return md, ok
}

// FastFirstValueFromIncomingContext is a specialization of
// metadata.ValueFromIncomingContext() which extracts the first string
// from the given metadata key, if it exists. No extra objects are
// allocated. The key is assumed to contain only ASCII characters.
func FastFirstValueFromIncomingContext(ctx context.Context, key string) (string, bool) {
md, ok := ctx.Value(grpcIncomingKeyObj).(metadata.MD)
if !ok {
return "", false
}
if v, ok := md[key]; ok {
if len(v) > 0 {
return v[0], true
}
return "", false
}
for k, v := range md {
// The letter casing may not have been set properly when MD was
// attached to the context. So we need to normalize it here.
//
// We add len(k) == len(key) to avoid the overhead of
// strings.ToLower when the keys of different length, because then
// they are guaranteed to not match anyway. This is the
// optimization that requires the key to be all ASCII, as
// generally ToLower() on non-ascii unicode can change the length
// of the string.
if len(k) == len(key) && strings.ToLower(k) == key {
if len(v) > 0 {
return v[0], true
}
return "", false
}
}
return "", false
}

// grpcIncomingKeyObj is a copy of a value with the Go type
// `metadata.incomingKey{}` (from the grpc metadata package). We
// cannot construct an object of that type directly, but we can
// "steal" it by forcing the metadata package to give it to us:
// `metadata.FromIncomingContext` gives an instance of this object as
// parameter to the `Value` method of the context you give it as
// argument. We use a custom implementation of that to "steal" the
// argument of type `incomingKey{}` given to us that way.
var grpcIncomingKeyObj = func() interface{} {
var f fakeContext
_, _ = metadata.FromIncomingContext(&f)
if f.recordedKey == nil {
panic("ValueFromIncomingContext did not request a key")
}
return f.recordedKey
}()

type fakeContext struct {
recordedKey interface{}
}

var _ context.Context = (*fakeContext)(nil)

// Value implements the context.Context interface and is our helper
// that "steals" an instance of the private type `incomingKey` in the
// grpc metadata package.
func (f *fakeContext) Value(keyObj interface{}) interface{} {
f.recordedKey = keyObj
return nil
}

func (*fakeContext) Deadline() (time.Time, bool) { panic("unused") }
func (*fakeContext) Done() <-chan struct{} { panic("unused") }
func (*fakeContext) Err() error { panic("unused") }
35 changes: 35 additions & 0 deletions pkg/util/grpcutil/fast_metadata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package grpcutil

import (
"context"
"testing"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/metadata"
)

func TestFastFromIncomingContext(t *testing.T) {
defer leaktest.AfterTest(t)()

md := metadata.MD{"hello": []string{"world", "universe"}}

ctx := metadata.NewIncomingContext(context.Background(), md)
md2, ok := FastFromIncomingContext(ctx)
require.True(t, ok)
require.Equal(t, md2, md)

v, ok := FastFirstValueFromIncomingContext(ctx, "hello")
require.True(t, ok)
require.Equal(t, v, "world")
}
2 changes: 1 addition & 1 deletion pkg/util/tracing/grpcinterceptor/grpc_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
func ExtractSpanMetaFromGRPCCtx(
ctx context.Context, tracer *tracing.Tracer,
) (tracing.SpanMeta, error) {
md, ok := metadata.FromIncomingContext(ctx)
md, ok := grpcutil.FastFromIncomingContext(ctx)
if !ok {
return tracing.SpanMeta{}, nil
}
Expand Down

0 comments on commit 4384a92

Please sign in to comment.