diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index f33d3cad5ec5..e2faa26b6fc5 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//pkg/security", "//pkg/security/certnames", "//pkg/security/username", + "//pkg/settings", "//pkg/settings/cluster", "//pkg/util/buildutil", "//pkg/util/contextutil", @@ -42,6 +43,7 @@ go_library( "//pkg/util/grpcutil", "//pkg/util/hlc", "//pkg/util/log", + "//pkg/util/log/logcrash", "//pkg/util/log/severity", "//pkg/util/metric", "//pkg/util/netutil", diff --git a/pkg/rpc/auth.go b/pkg/rpc/auth.go index d158e588bb9d..1045c5e9a9dc 100644 --- a/pkg/rpc/auth.go +++ b/pkg/rpc/auth.go @@ -13,12 +13,15 @@ package rpc import ( "context" "crypto/x509" + "fmt" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security" "github.com/cockroachdb/cockroach/pkg/security/username" + "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/log/logcrash" "github.com/cockroachdb/errors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -41,6 +44,7 @@ func authErrorf(format string, a ...interface{}) error { // validates that client TLS certificate provided by the incoming connection // contains a sufficiently privileged user. type kvAuth struct { + sv *settings.Values tenant tenantAuthorizer } @@ -68,13 +72,9 @@ func (a kvAuth) unaryInterceptor( return nil, err } - // Enhance the context if the peer is a tenant server. - switch ar := authnRes.(type) { - case authnSuccessPeerIsTenantServer: - ctx = contextWithClientTenant(ctx, roachpb.TenantID(ar)) - default: - ctx = contextWithoutClientTenant(ctx) - } + // Enhance the context to ensure the API handler only sees a client tenant ID + // via roachpb.ClientTenantFromContext when relevant. + ctx = contextForRequest(ctx, authnRes) // Handle authorization according to the selected authz method. switch ar := authz.(type) { @@ -103,13 +103,9 @@ func (a kvAuth) streamInterceptor( return err } - // Enhance the context if the peer is a tenant server. - switch ar := authnRes.(type) { - case authnSuccessPeerIsTenantServer: - ctx = contextWithClientTenant(ctx, roachpb.TenantID(ar)) - default: - ctx = contextWithoutClientTenant(ctx) - } + // Enhance the context to ensure the API handler only sees a client tenant ID + // via roachpb.ClientTenantFromContext when relevant. + ctx = contextForRequest(ctx, authnRes) // Handle authorization according to the selected authz method. switch ar := authz.(type) { @@ -188,7 +184,10 @@ func (authnSuccessPeerIsTenantServer) authnResult() {} func (authnSuccessPeerIsPrivileged) authnResult() {} // authenticate verifies the credentials of the client and performs -// some consistency check with the information provided. +// some consistency check with the information provided. The caller +// should discard the original context.Context and use the new one; +// the function also consumes and strips some fields from the incoming +// gRPC metadata.MD to avoid confusion if/when the RPC gets forwarded. func (a kvAuth) authenticate(ctx context.Context) (authnResult, error) { var ar authnResult if clientTenantID, localRequest := grpcutil.IsLocalRequestContext(ctx); localRequest { @@ -219,14 +218,26 @@ func (a kvAuth) authenticate(ctx context.Context) (authnResult, error) { return ar, nil } -// Deal with local requests done through the -// internalClientAdapter. There's no TLS for these calls, so the -// regular authentication code path doesn't apply. The clientTenantID -// should be the result of a call to grpcutil.IsLocalRequestContext. +// Deal with local requests done through the internalClientAdapter. +// There's no TLS for these calls, so the regular authentication code +// path doesn't apply. The clientTenantID should be the result of a +// call to grpcutil.IsLocalRequestContext. func (a kvAuth) authenticateLocalRequest( ctx context.Context, clientTenantID roachpb.TenantID, ) (authnResult, error) { - if !clientTenantID.IsSet() || clientTenantID.IsSystem() { + // Sanity check: verify that we do not also have gRPC network credentials + // in the context. This would indicate that metadata was improperly propagated. + maybeTid, err := tenantIDFromRPCMetadata(ctx) + if err != nil || maybeTid.IsSet() { + logcrash.ReportOrPanic(ctx, a.sv, "programming error: network credentials in internal adapter request (%v, %v)", maybeTid, err) + return nil, authErrorf("programming error") + } + + if !clientTenantID.IsSet() { + return authnSuccessPeerIsPrivileged{}, nil + } + + if clientTenantID.IsSystem() { return authnSuccessPeerIsPrivileged{}, nil } @@ -242,6 +253,11 @@ func (a kvAuth) authenticateNetworkRequest(ctx context.Context) (authnResult, er return nil, err } + tenantIDFromMetadata, err := tenantIDFromRPCMetadata(ctx) + if err != nil { + return nil, authErrorf("client provided invalid tenant ID: %v", err) + } + // Did the client peer use a tenant client cert? if security.IsTenantCertificate(clientCert) { // If the peer is using a client tenant cert, in any case we @@ -250,7 +266,14 @@ func (a kvAuth) authenticateNetworkRequest(ctx context.Context) (authnResult, er if err != nil { return nil, err } - + // If the peer is using a TenantCertificate and also + // provided a tenant ID via gRPC metadata, they must + // match. + if tenantIDFromMetadata.IsSet() && tenantIDFromMetadata != tlsID { + return nil, authErrorf( + "client wants to authenticate as tenant %v, but is using TLS cert for tenant %v", + tenantIDFromMetadata, tlsID) + } return authnSuccessPeerIsTenantServer(tlsID), nil } @@ -273,6 +296,9 @@ func (a kvAuth) authenticateNetworkRequest(ctx context.Context) (authnResult, er return nil, err } + if tenantIDFromMetadata.IsSet() { + return authnSuccessPeerIsTenantServer(tenantIDFromMetadata), nil + } return authnSuccessPeerIsPrivileged{}, nil } @@ -350,3 +376,96 @@ func checkRootOrNodeInScope( "need root or node client cert to perform RPCs on this server (this is tenant %v; cert is valid for %s)", serverTenantID, security.FormatUserScopes(certUserScope)) } + +// contextForRequest sets up the context.Context for use by +// the API handler. It covers two cases: +// +// - the request is coming from a secondary tenant. +// Then it uses roachpb.ContextWithTenantClient() to +// ensure that the API handler will find the tenant ID +// with roachpb.TenantClientFromContext(). +// - the request is coming from the system tenant. +// then it clears the tenant client information +// to ensure that the API handler will _not_ find +// a tenant ID with roachpb.TenantClientFromContext(). +// +// This latter case is important e.g. in the following scenario: +// +// SQL (a) -(network gRPC)-> KV (b) -(internal client adapter)-> KV (c) +// +// The authn in the call from (a) to (b) has added a tenant ID in the +// Go context for the handler at (b). This context.Context "pierces" +// the stack of calls in the internal client adapter, and thus the +// tenant ID is still present when the call is received at (c). +// However, we don't want the API handler at (c) to see it any more. +// So we need to remove it. +func contextForRequest(ctx context.Context, authnRes authnResult) context.Context { + switch ar := authnRes.(type) { + case authnSuccessPeerIsTenantServer: + // The simple context key will be used in various places via + // roachpb.ClientTenantFromContext(). This also adds a logging + // tag. + ctx = contextWithClientTenant(ctx, roachpb.TenantID(ar)) + default: + // The caller is not a tenant server, but it may have been in the + // process of handling an API call for a tenant server and so it + // may have a client tenant ID in its context already. To ensure + // none will be found, we need to clear it explicitly. + ctx = contextWithoutClientTenant(ctx) + } + return ctx +} + +// tenantClientCred is responsible for passing the tenant ID as +// medatada header to called RPCs. This makes it possible to pass the +// tenant ID even when using a different TLS cert than the "tenant +// client cert". +type tenantClientCred struct { + md map[string]string +} + +// clientTIDMetadataHeaderKey is the gRPC metadata key that indicates +// which tenant ID the client is intending to connect as (originating +// tenant identity). +// +// This is used instead of the cert CN field when connecting with a +// TLS client cert that is not marked as special "tenant client cert" +// via the "Tenants" string in the OU field. +// +// This metadata item is not meant to be used beyond authentication; +// to access the client tenant ID inside RPC handlers or other code, +// use roachpb.ClientTenantFromContext() instead. +const clientTIDMetadataHeaderKey = "client-tid" + +// newTenantClientCreds constructs a credentials.PerRPCCredentials +// which injects the client tenant ID as extra gRPC metadata in each +// RPC. +func newTenantClientCreds(tid roachpb.TenantID) credentials.PerRPCCredentials { + return &tenantClientCred{ + md: map[string]string{ + clientTIDMetadataHeaderKey: fmt.Sprint(tid), + }, + } +} + +// tenantIDFromRPCMetadata checks if there is a tenant ID in +// the incoming gRPC metadata. +func tenantIDFromRPCMetadata(ctx context.Context) (roachpb.TenantID, error) { + val, ok := grpcutil.FastFirstValueFromIncomingContext(ctx, clientTIDMetadataHeaderKey) + if !ok { + return roachpb.TenantID{}, nil + } + return tenantIDFromString(val, "gRPC metadata") +} + +// GetRequestMetadata implements the (grpc) +// credentials.PerRPCCredentials interface. +func (tcc *tenantClientCred) GetRequestMetadata( + ctx context.Context, uri ...string, +) (map[string]string, error) { + return tcc.md, nil +} + +// RequireTransportSecurity implements the (grpc) +// credentials.PerRPCCredentials interface. +func (tcc *tenantClientCred) RequireTransportSecurity() bool { return false } diff --git a/pkg/rpc/auth_test.go b/pkg/rpc/auth_test.go index 047dacfccccc..20aff433fd35 100644 --- a/pkg/rpc/auth_test.go +++ b/pkg/rpc/auth_test.go @@ -97,12 +97,13 @@ func TestAuthenticateTenant(t *testing.T) { stid := roachpb.SystemTenantID tenTen := roachpb.MustMakeTenantID(10) for _, tc := range []struct { - systemID roachpb.TenantID - ous []string - commonName string - expTenID roachpb.TenantID - expErr string - tenantScope uint64 + systemID roachpb.TenantID + ous []string + commonName string + expTenID roachpb.TenantID + expErr string + tenantScope uint64 + clientTenantInMD string }{ {systemID: stid, ous: correctOU, commonName: "10", expTenID: tenTen}, {systemID: stid, ous: correctOU, commonName: roachpb.MinTenantID.String(), expTenID: roachpb.MinTenantID}, @@ -127,8 +128,44 @@ func TestAuthenticateTenant(t *testing.T) { {systemID: tenTen, ous: correctOU, commonName: "1", expErr: `invalid tenant ID 1 in Common Name \(CN\)`}, {systemID: tenTen, ous: nil, commonName: "root"}, {systemID: tenTen, ous: nil, commonName: "node"}, + + // Passing a client ID in metadata instead of relying only on the TLS cert. + {clientTenantInMD: "invalid", expErr: `could not parse tenant ID from gRPC metadata`}, + {clientTenantInMD: "1", expErr: `invalid tenant ID 1 in gRPC metadata`}, + {clientTenantInMD: "-1", expErr: `could not parse tenant ID from gRPC metadata`}, + + // tenant ID in MD matches that in client cert. + // Server is KV node: expect tenant authorization. + {clientTenantInMD: "10", + systemID: stid, ous: correctOU, commonName: "10", expTenID: tenTen}, + // tenant ID in MD doesn't match that in client cert. + {clientTenantInMD: "10", + systemID: stid, ous: correctOU, commonName: "123", + expErr: `client wants to authenticate as tenant 10, but is using TLS cert for tenant 123`}, + // tenant ID present in MD, but not in client cert. However, + // client cert is valid. Use MD tenant ID. + // Server is KV node: expect tenant authorization. + {clientTenantInMD: "10", + systemID: stid, ous: nil, commonName: "root", expTenID: tenTen}, + // tenant ID present in MD, but not in client cert. However, + // client cert is valid. Use MD tenant ID. + // Server is KV node: expect tenant authorization. + {clientTenantInMD: "10", + systemID: stid, ous: nil, commonName: "node", expTenID: tenTen}, + // tenant ID present in MD, but not in client cert. However, + // client cert is valid. Use MD tenant ID. + // Server is secondary tenant: do not do additional tenant authorization. + {clientTenantInMD: "10", + systemID: tenTen, ous: nil, commonName: "root", expTenID: roachpb.TenantID{}}, + {clientTenantInMD: "10", + systemID: tenTen, ous: nil, commonName: "node", expTenID: roachpb.TenantID{}}, + // tenant ID present in MD, but not in client cert. Use MD tenant ID. + // Server tenant ID does not match client tenant ID. + {clientTenantInMD: "123", + systemID: tenTen, ous: nil, commonName: "root", + expErr: `client tenant identity \(123\) does not match server`}, } { - t.Run(fmt.Sprintf("from %v to %v", tc.commonName, tc.systemID), func(t *testing.T) { + t.Run(fmt.Sprintf("from %v to %v (md %q)", tc.commonName, tc.systemID, tc.clientTenantInMD), func(t *testing.T) { cert := &x509.Certificate{ Subject: pkix.Name{ CommonName: tc.commonName, @@ -150,6 +187,11 @@ func TestAuthenticateTenant(t *testing.T) { p := peer.Peer{AuthInfo: tlsInfo} ctx := peer.NewContext(context.Background(), &p) + if tc.clientTenantInMD != "" { + md := metadata.MD{"client-tid": []string{tc.clientTenantInMD}} + ctx = metadata.NewIncomingContext(ctx, md) + } + tenID, err := rpc.TestingAuthenticateTenant(ctx, tc.systemID) if tc.expErr == "" { diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index f4b641ca251b..c180ce947211 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -244,6 +244,7 @@ func NewServerEx( if !rpcCtx.Config.Insecure { a := kvAuth{ + sv: &rpcCtx.Settings.SV, tenant: tenantAuthorizer{ tenantID: rpcCtx.tenID, }, @@ -421,6 +422,9 @@ type Context struct { // The loopbackDialFn fits under that common case by transporting // the gRPC protocol over an in-memory pipe. loopbackDialFn func(context.Context) (net.Conn, error) + + // clientCreds is used to pass additional headers to called RPCs. + clientCreds credentials.PerRPCCredentials } // SetLoopbackDialer configures the loopback dialer function. @@ -615,6 +619,10 @@ func NewContext(ctx context.Context, opts ContextOptions) *Context { logClosingConnEvery: log.Every(time.Second), } + if !opts.TenantID.IsSystem() { + rpcCtx.clientCreds = newTenantClientCreds(opts.TenantID) + } + if opts.Knobs.NoLoopbackDialer { // The test has decided it doesn't need/want a loopback dialer. // Ensure we still have a working dial function in that case. @@ -801,6 +809,18 @@ func makeInternalClientAdapter( // the outer RPC. ctx = grpcutil.NewLocalRequestContext(ctx, tenantIDFromContext(ctx)) + // Clear any leftover gRPC incoming metadata, if this call + // is originating from a RPC handler function called as + // a result of a tenant call. This is this case: + // + // tenant -(rpc)-> tenant -(rpc)-> KV + // ^ YOU ARE HERE + // + // at this point, the left side RPC has left some incoming + // metadata in the context, but we need to get rid of it + // before we let the call go through KV. + ctx = grpcutil.ClearIncomingContext(ctx) + // If this is a tenant calling to the local KV server, we make things look // closer to a remote call from the tracing point of view. if tenant { @@ -1693,6 +1713,10 @@ func (rpcCtx *Context) dialOptsCommon( grpc.MaxCallSendMsgSize(math.MaxInt32), )} + if rpcCtx.clientCreds != nil { + dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(rpcCtx.clientCreds)) + } + // We throw this one in for good measure, but it only disables the retries // for RPCs that were already pending (which are opt in anyway, and we don't // opt in). It doesn't disable what gRPC calls "transparent retries" (RPC diff --git a/pkg/util/grpcutil/fast_metadata.go b/pkg/util/grpcutil/fast_metadata.go index 0b85281aa95f..270d9598865a 100644 --- a/pkg/util/grpcutil/fast_metadata.go +++ b/pkg/util/grpcutil/fast_metadata.go @@ -33,6 +33,16 @@ func FastFromIncomingContext(ctx context.Context) (metadata.MD, bool) { return md, ok } +// ClearIncomingContext "removes" the gRPC incoming metadata +// if there's any already. No-op if there was no metadata +// to start with. +func ClearIncomingContext(ctx context.Context) context.Context { + if ctx.Value(grpcIncomingKeyObj) != nil { + return metadata.NewIncomingContext(ctx, nil) + } + return ctx +} + // FastFirstValueFromIncomingContext is a specialization of // metadata.ValueFromIncomingContext() which extracts the first string // from the given metadata key, if it exists. No extra objects are @@ -68,6 +78,57 @@ func FastFirstValueFromIncomingContext(ctx context.Context, key string) (string, return "", false } +// FastGetAndDeleteValueFromIncomingContext extracts the first string +// from the given metadata key, if it exists. If it does, the metadata +// key is removed from the context. +func FastGetAndDeleteValueFromIncomingContext( + ctx context.Context, key string, +) (found bool, val string, newCtx context.Context) { + md, ok := ctx.Value(grpcIncomingKeyObj).(metadata.MD) + if !ok { + return false, "", ctx + } + if v, ok := md[key]; ok { + if len(v) > 0 { + found = true + val = v[0] + } + newMd := make(metadata.MD, len(md)-1) + for k, v := range md { + if k == key { + continue + } + newMd[k] = v + } + return found, val, metadata.NewIncomingContext(ctx, newMd) + } + for k, v := range md { + // The letter caseing may not have been set properly when MD was attached to + // the context. + // See the comment in FastValueFromIncomingContext above relating + // to the length comparison. + if len(k) != len(key) { + continue + } + lowK := strings.ToLower(k) + if lowK == key { + if len(v) > 0 { + found = true + val = v[0] + } + newMd := make(metadata.MD, len(md)-1) + for otherK, otherV := range md { + if otherK == lowK { + continue + } + newMd[otherK] = otherV + } + return found, val, metadata.NewIncomingContext(ctx, newMd) + } + } + return false, "", ctx +} + // grpcIncomingKeyObj is a copy of a value with the Go type // `metadata.incomingKey{}` (from the grpc metadata package). We // cannot construct an object of that type directly, but we can diff --git a/pkg/util/grpcutil/fast_metadata_test.go b/pkg/util/grpcutil/fast_metadata_test.go index a10232583e7d..1e6bb1e19285 100644 --- a/pkg/util/grpcutil/fast_metadata_test.go +++ b/pkg/util/grpcutil/fast_metadata_test.go @@ -33,3 +33,101 @@ func TestFastFromIncomingContext(t *testing.T) { require.True(t, ok) require.Equal(t, v, "world") } + +func TestFastGetAndRemoveValue(t *testing.T) { + md := metadata.MD{"hello": []string{"world", "universe"}} + ctx := metadata.NewIncomingContext(context.Background(), md) + + found, _, newCtx := FastGetAndDeleteValueFromIncomingContext(ctx, "woo") + require.False(t, found) + require.Equal(t, ctx, newCtx) + + found, val, newCtx := FastGetAndDeleteValueFromIncomingContext(ctx, "hello") + require.True(t, found) + require.Equal(t, "world", val) + + found, _, newCtx2 := FastGetAndDeleteValueFromIncomingContext(newCtx, "hello") + require.False(t, found) + require.Equal(t, newCtx, newCtx2) +} + +func BenchmarkFromIncomingContext(b *testing.B) { + md := metadata.MD{"hello": []string{"world", "universe"}} + ctx := metadata.NewIncomingContext(context.Background(), md) + b.Run("stdlib", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = metadata.FromIncomingContext(ctx) + } + }) + b.Run("fast", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = FastFromIncomingContext(ctx) + } + }) +} + +func slowRemoveValueIfExists1( + ctx context.Context, key string, +) (found bool, val string, newCtx context.Context) { + vals := metadata.ValueFromIncomingContext(ctx, key) + if len(vals) > 0 { + md, _ := metadata.FromIncomingContext(ctx) + delete(md, key) + return true, vals[0], metadata.NewIncomingContext(ctx, md) + } + return false, "", ctx +} + +func slowRemoveValueIfExists2( + ctx context.Context, key string, +) (found bool, val string, newCtx context.Context) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return false, "", ctx + } + if v, ok := md[key]; ok { + if len(v) > 0 { + found = true + val = v[0] + } + delete(md, key) + return found, val, metadata.NewIncomingContext(ctx, md) + } + return false, "", ctx +} + +func BenchmarkGetAndRemoveValue(b *testing.B) { + md := metadata.MD{"hello": []string{"world", "universe"}} + ctx := metadata.NewIncomingContext(context.Background(), md) + + b.Run("stdlib1/exists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = slowRemoveValueIfExists1(ctx, "hello") + } + }) + b.Run("stdlib2/exists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = slowRemoveValueIfExists2(ctx, "hello") + } + }) + b.Run("fast/exists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = FastGetAndDeleteValueFromIncomingContext(ctx, "hello") + } + }) + b.Run("stdlib1/notexists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = slowRemoveValueIfExists1(ctx, "foo") + } + }) + b.Run("stdlib2/notexists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = slowRemoveValueIfExists2(ctx, "foo") + } + }) + b.Run("fast/notexists", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = FastGetAndDeleteValueFromIncomingContext(ctx, "foo") + } + }) +}