diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed.go b/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed.go index 3097d82c1983..6274d229ad0f 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed.go @@ -466,14 +466,14 @@ func (ds *DistSender) singleRangeFeed( } args.Replica = transport.NextReplica() - clientCtx, client, err := transport.NextInternalClient(ctx) + client, err := transport.NextInternalClient(ctx) if err != nil { log.VErrEventf(ctx, 2, "RPC error: %s", err) continue } log.VEventf(ctx, 3, "attempting to create a RangeFeed over replica %s", args.Replica) - stream, err := client.RangeFeed(clientCtx, &args) + stream, err := client.RangeFeed(ctx, &args) if err != nil { log.VErrEventf(ctx, 2, "RPC error: %s", err) if grpcutil.IsAuthError(err) { diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed_test.go b/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed_test.go index c9a0b8847afc..f3fefa0c403c 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed_test.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_rangefeed_test.go @@ -95,7 +95,7 @@ func TestDistSenderRangeFeedRetryOnTransportErrors(t *testing.T) { transport.EXPECT().IsExhausted().Return(false) transport.EXPECT().NextReplica().Return(repl) transport.EXPECT().NextInternalClient(gomock.Any()).Return( - ctx, nil, grpcstatus.Error(spec.errorCode, "")) + nil, grpcstatus.Error(spec.errorCode, "")) } transport.EXPECT().IsExhausted().Return(true) transport.EXPECT().Release() @@ -123,7 +123,7 @@ func TestDistSenderRangeFeedRetryOnTransportErrors(t *testing.T) { client.EXPECT().RangeFeed(gomock.Any(), gomock.Any()).Return(stream, nil) transport.EXPECT().IsExhausted().Return(false) transport.EXPECT().NextReplica().Return(desc.InternalReplicas[0]) - transport.EXPECT().NextInternalClient(gomock.Any()).Return(ctx, client, nil) + transport.EXPECT().NextInternalClient(gomock.Any()).Return(client, nil) transport.EXPECT().Release() } diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_test.go b/pkg/kv/kvclient/kvcoord/dist_sender_test.go index da2a67291ce3..bb97bd1764be 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_test.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_test.go @@ -169,7 +169,7 @@ func (l *simpleTransportAdapter) SendNext( func (l *simpleTransportAdapter) NextInternalClient( ctx context.Context, -) (context.Context, roachpb.InternalClient, error) { +) (rpc.RestrictedInternalClient, error) { panic("unimplemented") } diff --git a/pkg/kv/kvclient/kvcoord/mocks_generated_test.go b/pkg/kv/kvclient/kvcoord/mocks_generated_test.go index 597b69e49f0f..02323419cc74 100644 --- a/pkg/kv/kvclient/kvcoord/mocks_generated_test.go +++ b/pkg/kv/kvclient/kvcoord/mocks_generated_test.go @@ -9,6 +9,7 @@ import ( reflect "reflect" roachpb "github.com/cockroachdb/cockroach/pkg/roachpb" + rpc "github.com/cockroachdb/cockroach/pkg/rpc" gomock "github.com/golang/mock/gomock" ) @@ -62,13 +63,12 @@ func (mr *MockTransportMockRecorder) MoveToFront(arg0 interface{}) *gomock.Call } // NextInternalClient mocks base method. -func (m *MockTransport) NextInternalClient(arg0 context.Context) (context.Context, roachpb.InternalClient, error) { +func (m *MockTransport) NextInternalClient(arg0 context.Context) (rpc.RestrictedInternalClient, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "NextInternalClient", arg0) - ret0, _ := ret[0].(context.Context) - ret1, _ := ret[1].(roachpb.InternalClient) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(rpc.RestrictedInternalClient) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NextInternalClient indicates an expected call of NextInternalClient. diff --git a/pkg/kv/kvclient/kvcoord/send_test.go b/pkg/kv/kvclient/kvcoord/send_test.go index 064b8e3a0c68..87d392e399cc 100644 --- a/pkg/kv/kvclient/kvcoord/send_test.go +++ b/pkg/kv/kvclient/kvcoord/send_test.go @@ -166,7 +166,7 @@ func (f *firstNErrorTransport) SendNext( func (f *firstNErrorTransport) NextInternalClient( ctx context.Context, -) (context.Context, roachpb.InternalClient, error) { +) (rpc.RestrictedInternalClient, error) { panic("unimplemented") } diff --git a/pkg/kv/kvclient/kvcoord/transport.go b/pkg/kv/kvclient/kvcoord/transport.go index 7bc422d64320..b82ac01bfb53 100644 --- a/pkg/kv/kvclient/kvcoord/transport.go +++ b/pkg/kv/kvclient/kvcoord/transport.go @@ -67,10 +67,8 @@ type Transport interface { SendNext(context.Context, roachpb.BatchRequest) (*roachpb.BatchResponse, error) // NextInternalClient returns the InternalClient to use for making RPC - // calls. Returns a context.Context which should be used when making RPC - // calls on the returned server (This context is annotated to mark this - // request as in-process and bypass ctx.Peer checks). - NextInternalClient(context.Context) (context.Context, roachpb.InternalClient, error) + // calls. + NextInternalClient(context.Context) (rpc.RestrictedInternalClient, error) // NextReplica returns the replica descriptor of the replica to be tried in // the next call to SendNext. MoveToFront will cause the return value to @@ -182,7 +180,7 @@ func (gt *grpcTransport) SendNext( ctx context.Context, ba roachpb.BatchRequest, ) (*roachpb.BatchResponse, error) { r := gt.replicas[gt.nextReplicaIdx] - ctx, iface, err := gt.NextInternalClient(ctx) + iface, err := gt.NextInternalClient(ctx) if err != nil { return nil, err } @@ -193,7 +191,10 @@ func (gt *grpcTransport) SendNext( // NB: nodeID is unused, but accessible in stack traces. func (gt *grpcTransport) sendBatch( - ctx context.Context, nodeID roachpb.NodeID, iface roachpb.InternalClient, ba roachpb.BatchRequest, + ctx context.Context, + nodeID roachpb.NodeID, + iface rpc.RestrictedInternalClient, + ba roachpb.BatchRequest, ) (*roachpb.BatchResponse, error) { // Bail out early if the context is already canceled. (GRPC will // detect this pretty quickly, but the first check of the context @@ -232,7 +233,7 @@ func (gt *grpcTransport) sendBatch( // RPCs. func (gt *grpcTransport) NextInternalClient( ctx context.Context, -) (context.Context, roachpb.InternalClient, error) { +) (rpc.RestrictedInternalClient, error) { r := gt.replicas[gt.nextReplicaIdx] gt.nextReplicaIdx++ return gt.nodeDialer.DialInternalClient(ctx, r.NodeID, gt.class) @@ -360,7 +361,7 @@ func (s *senderTransport) SendNext( func (s *senderTransport) NextInternalClient( ctx context.Context, -) (context.Context, roachpb.InternalClient, error) { +) (rpc.RestrictedInternalClient, error) { panic("unimplemented") } diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index 62c58f20501c..cc48f8b64ac4 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "heartbeat.go", "keepalive.go", "metrics.go", + "restricted_internal_client.go", "snappy.go", "tls.go", ], diff --git a/pkg/rpc/context.go b/pkg/rpc/context.go index 64c007fe016e..8a9348d6b7d7 100644 --- a/pkg/rpc/context.go +++ b/pkg/rpc/context.go @@ -353,7 +353,7 @@ type Context struct { rpcCompression bool - localInternalClient roachpb.InternalClient + localInternalClient RestrictedInternalClient conns syncmap.Map @@ -597,7 +597,7 @@ func (rpcCtx *Context) Metrics() *Metrics { // https://github.com/cockroachdb/cockroach/pull/73309 func (rpcCtx *Context) GetLocalInternalClientForAddr( target string, nodeID roachpb.NodeID, -) roachpb.InternalClient { +) RestrictedInternalClient { if target == rpcCtx.Config.AdvertiseAddr && nodeID == rpcCtx.NodeID.Get() { return rpcCtx.localInternalClient } @@ -608,6 +608,8 @@ type internalClientAdapter struct { server roachpb.InternalServer } +var _ RestrictedInternalClient = internalClientAdapter{} + // Batch implements the roachpb.InternalClient interface. func (a internalClientAdapter) Batch( ctx context.Context, ba *roachpb.BatchRequest, _ ...grpc.CallOption, @@ -615,56 +617,34 @@ func (a internalClientAdapter) Batch( // Mark this as originating locally, which is useful for the decision about // memory allocation tracking. ba.AdmissionHeader.SourceLocation = roachpb.AdmissionHeader_LOCAL - return a.server.Batch(ctx, ba) -} - -// RangeLookup implements the roachpb.InternalClient interface. -func (a internalClientAdapter) RangeLookup( - ctx context.Context, rl *roachpb.RangeLookupRequest, _ ...grpc.CallOption, -) (*roachpb.RangeLookupResponse, error) { - return a.server.RangeLookup(ctx, rl) -} - -// Join implements the roachpb.InternalClient interface. -func (a internalClientAdapter) Join( - ctx context.Context, req *roachpb.JoinNodeRequest, _ ...grpc.CallOption, -) (*roachpb.JoinNodeResponse, error) { - return a.server.Join(ctx, req) -} - -// ResetQuorum is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) ResetQuorum( - ctx context.Context, req *roachpb.ResetQuorumRequest, _ ...grpc.CallOption, -) (*roachpb.ResetQuorumResponse, error) { - return a.server.ResetQuorum(ctx, req) + // Create a new context from the existing one with the "local request" field set. + // This tells the handler that this is an in-process request, bypassing ctx.Peer checks. + return a.server.Batch(grpcutil.NewLocalRequestContext(ctx), ba) } -// TokenBucket is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) TokenBucket( - ctx context.Context, in *roachpb.TokenBucketRequest, opts ...grpc.CallOption, -) (*roachpb.TokenBucketResponse, error) { - return a.server.TokenBucket(ctx, in) -} - -// GetSpanConfigs is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) GetSpanConfigs( - ctx context.Context, req *roachpb.GetSpanConfigsRequest, _ ...grpc.CallOption, -) (*roachpb.GetSpanConfigsResponse, error) { - return a.server.GetSpanConfigs(ctx, req) -} +// RangeFeed implements the roachpb.InternalClient interface. +func (a internalClientAdapter) RangeFeed( + ctx context.Context, args *roachpb.RangeFeedRequest, _ ...grpc.CallOption, +) (roachpb.Internal_RangeFeedClient, error) { + ctx, cancel := context.WithCancel(ctx) + ctx, sp := tracing.ChildSpan(ctx, "/cockroach.roachpb.Internal/RangeFeed") + rfAdapter := rangeFeedClientAdapter{ + respStreamClientAdapter: makeRespStreamClientAdapter(grpcutil.NewLocalRequestContext(ctx)), + } -// GetAllSystemSpanConfigsThatApply is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) GetAllSystemSpanConfigsThatApply( - ctx context.Context, req *roachpb.GetAllSystemSpanConfigsThatApplyRequest, _ ...grpc.CallOption, -) (*roachpb.GetAllSystemSpanConfigsThatApplyResponse, error) { - return a.server.GetAllSystemSpanConfigsThatApply(ctx, req) -} + // Mark this as originating locally. + args.AdmissionHeader.SourceLocation = roachpb.AdmissionHeader_LOCAL + go func() { + defer cancel() + defer sp.Finish() + err := a.server.RangeFeed(args, rfAdapter) + if err == nil { + err = io.EOF + } + rfAdapter.errC <- err + }() -// UpdateSpanConfigs is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) UpdateSpanConfigs( - ctx context.Context, req *roachpb.UpdateSpanConfigsRequest, _ ...grpc.CallOption, -) (*roachpb.UpdateSpanConfigsResponse, error) { - return a.server.UpdateSpanConfigs(ctx, req) + return rfAdapter, nil } type respStreamClientAdapter struct { @@ -744,121 +724,8 @@ func (a rangeFeedClientAdapter) Send(e *roachpb.RangeFeedEvent) error { var _ roachpb.Internal_RangeFeedClient = rangeFeedClientAdapter{} var _ roachpb.Internal_RangeFeedServer = rangeFeedClientAdapter{} -// RangeFeed implements the roachpb.InternalClient interface. -func (a internalClientAdapter) RangeFeed( - ctx context.Context, args *roachpb.RangeFeedRequest, _ ...grpc.CallOption, -) (roachpb.Internal_RangeFeedClient, error) { - ctx, cancel := context.WithCancel(ctx) - ctx, sp := tracing.ChildSpan(ctx, "/cockroach.roachpb.Internal/RangeFeed") - rfAdapter := rangeFeedClientAdapter{ - respStreamClientAdapter: makeRespStreamClientAdapter(ctx), - } - - // Mark this as originating locally. - args.AdmissionHeader.SourceLocation = roachpb.AdmissionHeader_LOCAL - go func() { - defer cancel() - defer sp.Finish() - err := a.server.RangeFeed(args, rfAdapter) - if err == nil { - err = io.EOF - } - rfAdapter.errC <- err - }() - - return rfAdapter, nil -} - -type gossipSubscriptionClientAdapter struct { - respStreamClientAdapter -} - -// roachpb.Internal_GossipSubscriptionServer methods. -func (a gossipSubscriptionClientAdapter) Recv() (*roachpb.GossipSubscriptionEvent, error) { - e, err := a.recvInternal() - if err != nil { - return nil, err - } - return e.(*roachpb.GossipSubscriptionEvent), nil -} - -// roachpb.Internal_GossipSubscriptionServer methods. -func (a gossipSubscriptionClientAdapter) Send(e *roachpb.GossipSubscriptionEvent) error { - return a.sendInternal(e) -} - -var _ roachpb.Internal_GossipSubscriptionClient = gossipSubscriptionClientAdapter{} -var _ roachpb.Internal_GossipSubscriptionServer = gossipSubscriptionClientAdapter{} - -// GossipSubscription is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) GossipSubscription( - ctx context.Context, args *roachpb.GossipSubscriptionRequest, _ ...grpc.CallOption, -) (roachpb.Internal_GossipSubscriptionClient, error) { - ctx, cancel := context.WithCancel(ctx) - ctx, sp := tracing.ChildSpan(ctx, "/cockroach.roachpb.Internal/GossipSubscription") - gsAdapter := gossipSubscriptionClientAdapter{ - respStreamClientAdapter: makeRespStreamClientAdapter(ctx), - } - - go func() { - defer cancel() - defer sp.Finish() - err := a.server.GossipSubscription(args, gsAdapter) - if err == nil { - err = io.EOF - } - gsAdapter.errC <- err - }() - - return gsAdapter, nil -} - -type tenantSettingsClientAdapter struct { - respStreamClientAdapter -} - -// roachpb.Internal_TenantSettingsServer methods. -func (a tenantSettingsClientAdapter) Recv() (*roachpb.TenantSettingsEvent, error) { - e, err := a.recvInternal() - if err != nil { - return nil, err - } - return e.(*roachpb.TenantSettingsEvent), nil -} - -// roachpb.Internal_TenantSettingsServer methods. -func (a tenantSettingsClientAdapter) Send(e *roachpb.TenantSettingsEvent) error { - return a.sendInternal(e) -} - -var _ roachpb.Internal_TenantSettingsClient = tenantSettingsClientAdapter{} -var _ roachpb.Internal_TenantSettingsServer = tenantSettingsClientAdapter{} - -// TenantSettings is part of the roachpb.InternalClient interface. -func (a internalClientAdapter) TenantSettings( - ctx context.Context, args *roachpb.TenantSettingsRequest, _ ...grpc.CallOption, -) (roachpb.Internal_TenantSettingsClient, error) { - ctx, cancel := context.WithCancel(ctx) - gsAdapter := tenantSettingsClientAdapter{ - respStreamClientAdapter: makeRespStreamClientAdapter(ctx), - } - - go func() { - defer cancel() - err := a.server.TenantSettings(args, gsAdapter) - if err == nil { - err = io.EOF - } - gsAdapter.errC <- err - }() - - return gsAdapter, nil -} - -var _ roachpb.InternalClient = internalClientAdapter{} - // IsLocal returns true if the given InternalClient is local. -func IsLocal(iface roachpb.InternalClient) bool { +func IsLocal(iface RestrictedInternalClient) bool { _, ok := iface.(internalClientAdapter) return ok // internalClientAdapter is used for local connections. } diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index 7b50ce91b365..202e4cd240d7 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -140,18 +140,16 @@ func (n *Dialer) DialNoBreaker( // DialInternalClient is a specialization of DialClass for callers that // want a roachpb.InternalClient. This supports an optimization to bypass the -// network for the local node. Returns a context.Context which should be used -// when making RPC calls on the returned server. (This context is annotated to -// mark this request as in-process and bypass ctx.Peer checks). +// network for the local node. func (n *Dialer) DialInternalClient( ctx context.Context, nodeID roachpb.NodeID, class rpc.ConnectionClass, -) (context.Context, roachpb.InternalClient, error) { +) (rpc.RestrictedInternalClient, error) { if n == nil || n.resolver == nil { - return nil, nil, errors.New("no node dialer configured") + return nil, errors.New("no node dialer configured") } addr, err := n.resolver(nodeID) if err != nil { - return nil, nil, err + return nil, err } { @@ -159,20 +157,15 @@ func (n *Dialer) DialInternalClient( localClient := n.rpcContext.GetLocalInternalClientForAddr(addr.String(), nodeID) if localClient != nil && !n.testingKnobs.TestingNoLocalClientOptimization { log.VEvent(ctx, 2, kvbase.RoutingRequestLocallyMsg) - - // Create a new context from the existing one with the "local request" field set. - // This tells the handler that this is an in-process request, bypassing ctx.Peer checks. - localCtx := grpcutil.NewLocalRequestContext(ctx) - - return localCtx, localClient, nil + return localClient, nil } } log.VEventf(ctx, 2, "sending request to %s", addr) conn, err := n.dial(ctx, nodeID, addr, n.getBreaker(nodeID, class), true /* checkBreaker */, class) if err != nil { - return nil, nil, err + return nil, err } - return ctx, TracingInternalClient{InternalClient: roachpb.NewInternalClient(conn)}, err + return TracingInternalClient{InternalClient: roachpb.NewInternalClient(conn)}, err } // dial performs the dialing of the remote connection. If breaker is nil, diff --git a/pkg/rpc/restricted_internal_client.go b/pkg/rpc/restricted_internal_client.go new file mode 100644 index 000000000000..1f7b480c3eae --- /dev/null +++ b/pkg/rpc/restricted_internal_client.go @@ -0,0 +1,27 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package rpc + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "google.golang.org/grpc" +) + +// RestrictedInternalClient represents the part of the roachpb.InternalClient +// interface used by the DistSender. Besides the auto-generated gRPC client, +// this interface is also implemented by rpc.internalClientAdapter which +// bypasses gRPC to call into the local Node. +type RestrictedInternalClient interface { + Batch(ctx context.Context, in *roachpb.BatchRequest, opts ...grpc.CallOption) (*roachpb.BatchResponse, error) + RangeFeed(ctx context.Context, in *roachpb.RangeFeedRequest, opts ...grpc.CallOption) (roachpb.Internal_RangeFeedClient, error) +}