From f20164ebb1d5021f2294b8683f1d7dd21e283881 Mon Sep 17 00:00:00 2001 From: Nathan VanBenschoten Date: Tue, 3 Dec 2024 18:28:15 -0500 Subject: [PATCH 1/3] rpc: reuse gRPC streams across unary BatchRequest RPCs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #136572. This commit introduces pooling of gRPC streams that are used to send requests and receive corresponding responses in a manner that mimics unary RPC invocation. Pooling these streams allows for reuse of gRPC resources across calls, as opposed to native unary RPCs, which create a new stream and throw it away for each request (see grpc.invoke). The new pooling mechanism is used for the Internal/Batch RPC method, which is the dominant RPC method used to communicate between the KV client and KV server. A new Internal/BatchStream RPC method is introduced to allow a client to send and receive BatchRequest/BatchResponse pairs over a long-lived, pooled stream. A pool of these streams is then maintained alongside each gRPC connection. The pool grows and shrinks dynamically based on demand. The change demonstrates a large performance improvement in both microbenchmarks and full system benchmarks, which reveals just how expensive the gRPC stream setup on each unary RPC is. Microbenchmarks: ``` name old time/op new time/op delta Sysbench/KV/1node_remote/oltp_point_select-10 45.9µs ± 1% 28.8µs ± 2% -37.31% (p=0.000 n=9+8) Sysbench/KV/1node_remote/oltp_read_only-10 958µs ± 6% 709µs ± 1% -26.00% (p=0.000 n=9+9) Sysbench/SQL/1node_remote/oltp_read_only-10 3.65ms ± 6% 2.81ms ± 8% -23.06% (p=0.000 n=8+9) Sysbench/KV/1node_remote/oltp_read_write-10 1.77ms ± 5% 1.38ms ± 1% -22.09% (p=0.000 n=10+8) Sysbench/KV/1node_remote/oltp_write_only-10 688µs ± 4% 557µs ± 1% -19.11% (p=0.000 n=9+9) Sysbench/SQL/1node_remote/oltp_point_select-10 181µs ± 8% 159µs ± 2% -12.10% (p=0.000 n=8+9) Sysbench/SQL/1node_remote/oltp_write_only-10 2.16ms ± 4% 1.92ms ± 3% -11.08% (p=0.000 n=9+9) Sysbench/SQL/1node_remote/oltp_read_write-10 5.89ms ± 2% 5.36ms ± 1% -8.89% (p=0.000 n=9+9) name old alloc/op new alloc/op delta Sysbench/KV/1node_remote/oltp_point_select-10 16.3kB ± 0% 6.4kB ± 0% -60.70% (p=0.000 n=8+10) Sysbench/KV/1node_remote/oltp_write_only-10 359kB ± 1% 256kB ± 1% -28.92% (p=0.000 n=10+10) Sysbench/SQL/1node_remote/oltp_write_only-10 748kB ± 0% 548kB ± 1% -26.78% (p=0.000 n=8+10) Sysbench/SQL/1node_remote/oltp_point_select-10 40.9kB ± 0% 30.8kB ± 0% -24.74% (p=0.000 n=9+10) Sysbench/KV/1node_remote/oltp_read_write-10 1.11MB ± 1% 0.88MB ± 1% -21.17% (p=0.000 n=9+10) Sysbench/SQL/1node_remote/oltp_read_write-10 2.00MB ± 0% 1.65MB ± 0% -17.60% (p=0.000 n=9+10) Sysbench/KV/1node_remote/oltp_read_only-10 790kB ± 0% 655kB ± 0% -17.11% (p=0.000 n=9+9) Sysbench/SQL/1node_remote/oltp_read_only-10 1.33MB ± 0% 1.19MB ± 0% -10.97% (p=0.000 n=10+9) name old allocs/op new allocs/op delta Sysbench/KV/1node_remote/oltp_point_select-10 210 ± 0% 61 ± 0% -70.95% (p=0.000 n=10+10) Sysbench/KV/1node_remote/oltp_read_only-10 3.98k ± 0% 1.88k ± 0% -52.68% (p=0.019 n=6+8) Sysbench/KV/1node_remote/oltp_read_write-10 7.10k ± 0% 3.47k ± 0% -51.07% (p=0.000 n=10+9) Sysbench/KV/1node_remote/oltp_write_only-10 3.10k ± 0% 1.58k ± 0% -48.89% (p=0.000 n=10+9) Sysbench/SQL/1node_remote/oltp_write_only-10 6.73k ± 0% 3.82k ± 0% -43.30% (p=0.000 n=10+10) Sysbench/SQL/1node_remote/oltp_read_write-10 14.4k ± 0% 9.2k ± 0% -36.29% (p=0.000 n=9+10) Sysbench/SQL/1node_remote/oltp_point_select-10 429 ± 0% 277 ± 0% -35.46% (p=0.000 n=9+10) 
Sysbench/SQL/1node_remote/oltp_read_only-10 7.52k ± 0% 5.37k ± 0% -28.60% (p=0.000 n=10+10) ``` Roachtests: ``` name old queries/s new queries/s delta sysbench/oltp_read_write/nodes=3/cpu=8/conc=64 17.6k ± 7% 19.2k ± 2% +9.22% (p=0.008 n=5+5) name old avg_ms/op new avg_ms/op delta sysbench/oltp_read_write/nodes=3/cpu=8/conc=64 72.9 ± 7% 66.6 ± 2% -8.57% (p=0.008 n=5+5) name old p95_ms/op new p95_ms/op delta sysbench/oltp_read_write/nodes=3/cpu=8/conc=64 116 ± 8% 106 ± 3% -9.02% (p=0.016 n=5+5) ``` Manual tests: ``` Running in a similar configuration to sysbench/oltp_read_write/nodes=3/cpu=8/conc=64, but with benchmarking-related cluster settings applied (before and after) to reduce variance. -- Before Mean: 19771.03 Median: 19714.22 Standard Deviation: 282.96 Coefficient of variance: .0143 -- After Mean: 21908.23 Median: 21923.03 Standard Deviation: 200.88 Coefficient of variance: .0091 ``` Release note (performance improvement): gRPC streams are now pooled across unary intra-cluster RPCs, allowing for reuse of gRPC resources to reduce the cost of remote key-value layer access. This pooling can be disabled using the rpc.batch_stream_pool.enabled cluster setting. --- build/bazelutil/check.sh | 1 + .../settings/settings-for-tenants.txt | 2 +- docs/generated/settings/settings.html | 2 +- pkg/clusterversion/cockroach_versions.go | 5 + pkg/kv/kvclient/kvcoord/transport_test.go | 6 + pkg/kv/kvclient/kvtenant/connector_test.go | 4 + pkg/kv/kvpb/api.proto | 26 +- pkg/kv/kvpb/kvpbmock/mocks_generated.go | 20 + pkg/rpc/BUILD.bazel | 8 +- pkg/rpc/auth_tenant.go | 3 +- pkg/rpc/auth_test.go | 26 +- pkg/rpc/connection.go | 18 +- pkg/rpc/context_test.go | 4 + pkg/rpc/mocks_generated_test.go | 141 +++++- pkg/rpc/nodedialer/BUILD.bazel | 4 + pkg/rpc/nodedialer/nodedialer.go | 76 +++- pkg/rpc/nodedialer/nodedialer_test.go | 6 +- pkg/rpc/peer.go | 13 +- pkg/rpc/stream_pool.go | 330 ++++++++++++++ pkg/rpc/stream_pool_test.go | 416 ++++++++++++++++++ pkg/server/node.go | 26 ++ .../grpcinterceptor/grpc_interceptor.go | 4 + 22 files changed, 1107 insertions(+), 34 deletions(-) create mode 100644 pkg/rpc/stream_pool.go create mode 100644 pkg/rpc/stream_pool_test.go diff --git a/build/bazelutil/check.sh b/build/bazelutil/check.sh index ca6ff332e576..e2703d20ac35 100755 --- a/build/bazelutil/check.sh +++ b/build/bazelutil/check.sh @@ -18,6 +18,7 @@ GIT_GREP="git $CONFIGS grep" EXISTING_GO_GENERATE_COMMENTS=" pkg/config/field.go://go:generate stringer --type=Field --linecomment pkg/rpc/context.go://go:generate mockgen -destination=mocks_generated_test.go --package=. Dialbacker +pkg/rpc/stream_pool.go://go:generate mockgen -destination=mocks_generated_test.go --package=. BatchStreamClient pkg/roachprod/vm/aws/config.go://go:generate terraformgen -o terraform/main.tf pkg/roachprod/prometheus/prometheus.go://go:generate mockgen -package=prometheus -destination=mocks_generated_test.go . Cluster pkg/cmd/roachtest/clusterstats/collector.go://go:generate mockgen -package=clusterstats -destination mocks_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt index 4e996105888c..d9e3251ff718 100644 --- a/docs/generated/settings/settings-for-tenants.txt +++ b/docs/generated/settings/settings-for-tenants.txt @@ -401,4 +401,4 @@ trace.span_registry.enabled boolean false if set, ongoing traces can be seen at trace.zipkin.collector string the address of a Zipkin instance to receive traces, as <host>:<port>.
If no port is specified, 9411 will be used. application ui.database_locality_metadata.enabled boolean true if enabled shows extended locality data about databases and tables in DB Console which can be expensive to compute application ui.display_timezone enumeration etc/utc the timezone used to format timestamps in the ui [etc/utc = 0, america/new_york = 1] application -version version 1000024.3-upgrading-to-1000025.1-step-008 set the active cluster version in the format '<major>.<minor>' application +version version 1000024.3-upgrading-to-1000025.1-step-010 set the active cluster version in the format '<major>.<minor>' application diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html index f4f277f43c8b..3c7afe299303 100644 --- a/docs/generated/settings/settings.html +++ b/docs/generated/settings/settings.html @@ -360,6 +360,6 @@
trace.zipkin.collector
stringthe address of a Zipkin instance to receive traces, as <host>:<port>. If no port is specified, 9411 will be used.Serverless/Dedicated/Self-Hosted
ui.database_locality_metadata.enabled
booleantrueif enabled shows extended locality data about databases and tables in DB Console which can be expensive to computeServerless/Dedicated/Self-Hosted
ui.display_timezone
enumerationetc/utcthe timezone used to format timestamps in the ui [etc/utc = 0, america/new_york = 1]Serverless/Dedicated/Self-Hosted -
version
version1000024.3-upgrading-to-1000025.1-step-008set the active cluster version in the format '<major>.<minor>'Serverless/Dedicated/Self-Hosted +
version
version1000024.3-upgrading-to-1000025.1-step-010set the active cluster version in the format '<major>.<minor>'Serverless/Dedicated/Self-Hosted diff --git a/pkg/clusterversion/cockroach_versions.go b/pkg/clusterversion/cockroach_versions.go index 4c41adb0251a..cbb01bc889cc 100644 --- a/pkg/clusterversion/cockroach_versions.go +++ b/pkg/clusterversion/cockroach_versions.go @@ -198,6 +198,10 @@ const ( // range-ID local key, which is written below raft. V25_1_AddRangeForceFlushKey + // V25_1_BatchStreamRPC adds the BatchStream RPC, which allows for more + // efficient Batch unary RPCs. + V25_1_BatchStreamRPC + // ************************************************* // Step (1) Add new versions above this comment. // Do not add new versions to a patch release. @@ -240,6 +244,7 @@ var versionTable = [numKeys]roachpb.Version{ V25_1_AddJobsTables: {Major: 24, Minor: 3, Internal: 4}, V25_1_MoveRaftTruncatedState: {Major: 24, Minor: 3, Internal: 6}, V25_1_AddRangeForceFlushKey: {Major: 24, Minor: 3, Internal: 8}, + V25_1_BatchStreamRPC: {Major: 24, Minor: 3, Internal: 10}, // ************************************************* // Step (2): Add new versions above this comment. diff --git a/pkg/kv/kvclient/kvcoord/transport_test.go b/pkg/kv/kvclient/kvcoord/transport_test.go index ddb370236689..92f57021a680 100644 --- a/pkg/kv/kvclient/kvcoord/transport_test.go +++ b/pkg/kv/kvclient/kvcoord/transport_test.go @@ -255,6 +255,12 @@ func (m *mockInternalClient) Batch( return br, nil } +func (m *mockInternalClient) BatchStream( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_BatchStreamClient, error) { + return nil, fmt.Errorf("unsupported BatchStream call") +} + // RangeLookup implements the kvpb.InternalClient interface. func (m *mockInternalClient) RangeLookup( ctx context.Context, rl *kvpb.RangeLookupRequest, _ ...grpc.CallOption, diff --git a/pkg/kv/kvclient/kvtenant/connector_test.go b/pkg/kv/kvclient/kvtenant/connector_test.go index 22ed7e472ee4..a4264229846c 100644 --- a/pkg/kv/kvclient/kvtenant/connector_test.go +++ b/pkg/kv/kvclient/kvtenant/connector_test.go @@ -121,6 +121,10 @@ func (*mockServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRespon panic("unimplemented") } +func (m *mockServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("implement me") +} + func (m *mockServer) MuxRangeFeed(server kvpb.Internal_MuxRangeFeedServer) error { panic("implement me") } diff --git a/pkg/kv/kvpb/api.proto b/pkg/kv/kvpb/api.proto index eeaf95b84bba..befe0824b75a 100644 --- a/pkg/kv/kvpb/api.proto +++ b/pkg/kv/kvpb/api.proto @@ -3668,23 +3668,30 @@ message JoinNodeResponse { // Batch and RangeFeed service implemented by nodes for KV API requests. service Internal { - rpc Batch (BatchRequest) returns (BatchResponse) {} + rpc Batch (BatchRequest) returns (BatchResponse) {} + + // BatchStream is a streaming variant of Batch. There is a 1:1 correspondence + // between requests and responses. The method is used to facilitate pooling of + // gRPC streams to avoid the overhead of creating and discarding a new stream + // for each unary Batch RPC invocation. See rpc.BatchStreamPool. 
+ rpc BatchStream (stream BatchRequest) returns (stream BatchResponse) {} + rpc RangeLookup (RangeLookupRequest) returns (RangeLookupResponse) {} - rpc MuxRangeFeed (stream RangeFeedRequest) returns (stream MuxRangeFeedEvent) {} + rpc MuxRangeFeed (stream RangeFeedRequest) returns (stream MuxRangeFeedEvent) {} rpc GossipSubscription (GossipSubscriptionRequest) returns (stream GossipSubscriptionEvent) {} rpc ResetQuorum (ResetQuorumRequest) returns (ResetQuorumResponse) {} // TokenBucket is used by tenants to obtain Request Units and report // consumption. - rpc TokenBucket (TokenBucketRequest) returns (TokenBucketResponse) {} + rpc TokenBucket (TokenBucketRequest) returns (TokenBucketResponse) {} // Join a bootstrapped cluster. If the target node is itself not part of a // bootstrapped cluster, an appropriate error is returned. - rpc Join(JoinNodeRequest) returns (JoinNodeResponse) { } + rpc Join (JoinNodeRequest) returns (JoinNodeResponse) {} // GetSpanConfigs is used to fetch the span configurations over a given // keyspan. - rpc GetSpanConfigs (GetSpanConfigsRequest) returns (GetSpanConfigsResponse) { } + rpc GetSpanConfigs (GetSpanConfigsRequest) returns (GetSpanConfigsResponse) {} // GetAllSystemSpanConfigsThatApply is used to fetch all system span // configurations that apply over a tenant's ranges. @@ -3692,20 +3699,19 @@ service Internal { // UpdateSpanConfigs is used to update the span configurations over given // keyspans. - rpc UpdateSpanConfigs (UpdateSpanConfigsRequest) returns (UpdateSpanConfigsResponse) { } + rpc UpdateSpanConfigs (UpdateSpanConfigsRequest) returns (UpdateSpanConfigsResponse) {} // SpanConfigConformance is used to determine whether ranges backing the given // keyspans conform to span configs that apply over them. - rpc SpanConfigConformance (SpanConfigConformanceRequest) returns (SpanConfigConformanceResponse) { } + rpc SpanConfigConformance (SpanConfigConformanceRequest) returns (SpanConfigConformanceResponse) {} // TenantSettings is used by tenants to obtain and stay up to date with tenant // setting overrides. - rpc TenantSettings (TenantSettingsRequest) returns (stream TenantSettingsEvent) { } - + rpc TenantSettings (TenantSettingsRequest) returns (stream TenantSettingsEvent) {} // GetRangeDescriptors is used by tenants to get range descriptors for their // own ranges. - rpc GetRangeDescriptors (GetRangeDescriptorsRequest) returns (stream GetRangeDescriptorsResponse) { } + rpc GetRangeDescriptors (GetRangeDescriptorsRequest) returns (stream GetRangeDescriptorsResponse) {} } // GetRangeDescriptorsRequest is used to fetch range descriptors. diff --git a/pkg/kv/kvpb/kvpbmock/mocks_generated.go b/pkg/kv/kvpb/kvpbmock/mocks_generated.go index fae67f8dd2c1..6f85105a1cc2 100644 --- a/pkg/kv/kvpb/kvpbmock/mocks_generated.go +++ b/pkg/kv/kvpb/kvpbmock/mocks_generated.go @@ -58,6 +58,26 @@ func (mr *MockInternalClientMockRecorder) Batch(arg0, arg1 interface{}, arg2 ... return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Batch", reflect.TypeOf((*MockInternalClient)(nil).Batch), varargs...) } +// BatchStream mocks base method. +func (m *MockInternalClient) BatchStream(arg0 context.Context, arg1 ...grpc.CallOption) (kvpb.Internal_BatchStreamClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchStream", varargs...) 
+ ret0, _ := ret[0].(kvpb.Internal_BatchStreamClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchStream indicates an expected call of BatchStream. +func (mr *MockInternalClientMockRecorder) BatchStream(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchStream", reflect.TypeOf((*MockInternalClient)(nil).BatchStream), varargs...) +} + // GetAllSystemSpanConfigsThatApply mocks base method. func (m *MockInternalClient) GetAllSystemSpanConfigsThatApply(arg0 context.Context, arg1 *roachpb.GetAllSystemSpanConfigsThatApplyRequest, arg2 ...grpc.CallOption) (*roachpb.GetAllSystemSpanConfigsThatApplyResponse, error) { m.ctrl.T.Helper() diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index 9e4fc352c5a0..92e48762070d 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -25,6 +25,7 @@ go_library( "restricted_internal_client.go", "settings.go", "snappy.go", + "stream_pool.go", "tls.go", ], embed = [":rpc_go_proto"], @@ -90,7 +91,10 @@ go_library( gomock( name = "mock_rpc", out = "mocks_generated_test.go", - interfaces = ["Dialbacker"], + interfaces = [ + "BatchStreamClient", + "Dialbacker", + ], library = ":rpc", package = "rpc", self_package = "github.com/cockroachdb/cockroach/pkg/rpc", @@ -116,6 +120,7 @@ go_test( "metrics_test.go", "peer_test.go", "snappy_test.go", + "stream_pool_test.go", "tls_test.go", ":mock_rpc", # keep ], @@ -173,6 +178,7 @@ go_test( "@org_golang_google_grpc//metadata", "@org_golang_google_grpc//peer", "@org_golang_google_grpc//status", + "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/rpc/auth_tenant.go b/pkg/rpc/auth_tenant.go index c5233fc21b41..d923cf643fa4 100644 --- a/pkg/rpc/auth_tenant.go +++ b/pkg/rpc/auth_tenant.go @@ -55,7 +55,7 @@ func (a tenantAuthorizer) authorize( req interface{}, ) error { switch fullMethod { - case "/cockroach.roachpb.Internal/Batch": + case "/cockroach.roachpb.Internal/Batch", "/cockroach.roachpb.Internal/BatchStream": return a.authBatch(ctx, sv, tenID, req.(*kvpb.BatchRequest)) case "/cockroach.roachpb.Internal/RangeLookup": @@ -63,6 +63,7 @@ func (a tenantAuthorizer) authorize( case "/cockroach.roachpb.Internal/RangeFeed", "/cockroach.roachpb.Internal/MuxRangeFeed": return a.authRangeFeed(tenID, req.(*kvpb.RangeFeedRequest)) + case "/cockroach.roachpb.Internal/GossipSubscription": return a.authGossipSubscription(tenID, req.(*kvpb.GossipSubscriptionRequest)) diff --git a/pkg/rpc/auth_test.go b/pkg/rpc/auth_test.go index b3d90548e946..14c735a7163f 100644 --- a/pkg/rpc/auth_test.go +++ b/pkg/rpc/auth_test.go @@ -572,6 +572,30 @@ func TestTenantAuthRequest(t *testing.T) { expErr: noError, }, }, + "/cockroach.roachpb.Internal/BatchStream": { + { + req: &kvpb.BatchRequest{}, + expErr: `requested key span /Max not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + makeReq("a", "b"), + )}, + expErr: `requested key span {a-b} not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + makeReq(prefix(5, "a"), prefix(5, "b")), + )}, + expErr: `requested key span /Tenant/5{a-b} not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + makeReq(prefix(10, "a"), prefix(10, "b")), + )}, + expErr: noError, + }, + }, "/cockroach.roachpb.Internal/RangeLookup": { { req: &kvpb.RangeLookupRequest{}, @@ -1009,7 +1033,7 @@ 
func TestTenantAuthRequest(t *testing.T) { // cross-read capability and the request is a read, expect no error. if canCrossRead && strings.Contains(tc.expErr, "fully contained") { switch method { - case "/cockroach.roachpb.Internal/Batch": + case "/cockroach.roachpb.Internal/Batch", "/cockroach.roachpb.Internal/BatchStream": if tc.req.(*kvpb.BatchRequest).IsReadOnly() { tc.expErr = noError } diff --git a/pkg/rpc/connection.go b/pkg/rpc/connection.go index decfed37f61c..2adfd4560370 100644 --- a/pkg/rpc/connection.go +++ b/pkg/rpc/connection.go @@ -34,17 +34,26 @@ type Connection struct { // It always has to be signaled eventually, regardless of the stopper // draining, etc, since callers might be blocking on it. connFuture connFuture + // batchStreamPool holds a pool of BatchStreamClient streams established on + // the connection. The pool can be used to avoid the overhead of unary Batch + // RPCs. + // + // The pool is only initialized once the ClientConn is resolved. + batchStreamPool BatchStreamPool } // newConnectionToNodeID makes a Connection for the given node, class, and nontrivial Signal // that should be queried in Connect(). -func newConnectionToNodeID(k peerKey, breakerSignal func() circuit.Signal) *Connection { +func newConnectionToNodeID( + opts *ContextOptions, k peerKey, breakerSignal func() circuit.Signal, +) *Connection { c := &Connection{ breakerSignalFn: breakerSignal, k: k, connFuture: connFuture{ ready: make(chan struct{}), }, + batchStreamPool: makeStreamPool(opts.Stopper, newBatchStream), } return c } @@ -156,6 +165,13 @@ func (c *Connection) Signal() circuit.Signal { return c.breakerSignalFn() } +func (c *Connection) BatchStreamPool() *BatchStreamPool { + if !c.connFuture.Resolved() { + panic("BatchStreamPool called on unresolved connection") + } + return &c.batchStreamPool +} + type connFuture struct { ready chan struct{} cc *grpc.ClientConn diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 5c65bf6ca708..d66ffa55aa9d 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -306,6 +306,10 @@ func (*internalServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRe return nil, nil } +func (*internalServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("unimplemented") +} + func (*internalServer) RangeLookup( context.Context, *kvpb.RangeLookupRequest, ) (*kvpb.RangeLookupResponse, error) { diff --git a/pkg/rpc/mocks_generated_test.go b/pkg/rpc/mocks_generated_test.go index f0255bf145f4..ea3f1f7bdc2c 100644 --- a/pkg/rpc/mocks_generated_test.go +++ b/pkg/rpc/mocks_generated_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/cockroachdb/cockroach/pkg/rpc (interfaces: Dialbacker) +// Source: github.com/cockroachdb/cockroach/pkg/rpc (interfaces: BatchStreamClient,Dialbacker) // Package rpc is a generated GoMock package. package rpc @@ -8,12 +8,151 @@ import ( context "context" reflect "reflect" + kvpb "github.com/cockroachdb/cockroach/pkg/kv/kvpb" roachpb "github.com/cockroachdb/cockroach/pkg/roachpb" rpcpb "github.com/cockroachdb/cockroach/pkg/rpc/rpcpb" gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" ) +// MockBatchStreamClient is a mock of BatchStreamClient interface. +type MockBatchStreamClient struct { + ctrl *gomock.Controller + recorder *MockBatchStreamClientMockRecorder +} + +// MockBatchStreamClientMockRecorder is the mock recorder for MockBatchStreamClient. 
+type MockBatchStreamClientMockRecorder struct { + mock *MockBatchStreamClient +} + +// NewMockBatchStreamClient creates a new mock instance. +func NewMockBatchStreamClient(ctrl *gomock.Controller) *MockBatchStreamClient { + mock := &MockBatchStreamClient{ctrl: ctrl} + mock.recorder = &MockBatchStreamClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatchStreamClient) EXPECT() *MockBatchStreamClientMockRecorder { + return m.recorder +} + +// CloseSend mocks base method. +func (m *MockBatchStreamClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. +func (mr *MockBatchStreamClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockBatchStreamClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockBatchStreamClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockBatchStreamClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockBatchStreamClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockBatchStreamClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockBatchStreamClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockBatchStreamClient)(nil).Header)) +} + +// Recv mocks base method. +func (m *MockBatchStreamClient) Recv() (*kvpb.BatchResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*kvpb.BatchResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockBatchStreamClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockBatchStreamClient)(nil).Recv)) +} + +// RecvMsg mocks base method. +func (m *MockBatchStreamClient) RecvMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockBatchStreamClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).RecvMsg), arg0) +} + +// Send mocks base method. +func (m *MockBatchStreamClient) Send(arg0 *kvpb.BatchRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockBatchStreamClientMockRecorder) Send(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockBatchStreamClient)(nil).Send), arg0) +} + +// SendMsg mocks base method. 
+func (m *MockBatchStreamClient) SendMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockBatchStreamClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).SendMsg), arg0) +} + +// Trailer mocks base method. +func (m *MockBatchStreamClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockBatchStreamClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockBatchStreamClient)(nil).Trailer)) +} + // MockDialbacker is a mock of Dialbacker interface. type MockDialbacker struct { ctrl *gomock.Controller diff --git a/pkg/rpc/nodedialer/BUILD.bazel b/pkg/rpc/nodedialer/BUILD.bazel index 192eb71c9512..b405795e4d32 100644 --- a/pkg/rpc/nodedialer/BUILD.bazel +++ b/pkg/rpc/nodedialer/BUILD.bazel @@ -7,12 +7,16 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/base", + "//pkg/clusterversion", "//pkg/kv/kvbase", "//pkg/kv/kvpb", "//pkg/roachpb", "//pkg/rpc", + "//pkg/settings", + "//pkg/settings/cluster", "//pkg/util/circuit", "//pkg/util/log", + "//pkg/util/metamorphic", "//pkg/util/stop", "//pkg/util/tracing", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index 1715f5c7c3a5..ea444f668031 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -11,12 +11,16 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/clusterversion" "github.com/cockroachdb/cockroach/pkg/kv/kvbase" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/settings" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" circuit2 "github.com/cockroachdb/cockroach/pkg/util/circuit" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/metamorphic" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" @@ -96,7 +100,8 @@ func (n *Dialer) Dial( err = errors.Wrapf(err, "failed to resolve n%d", nodeID) return nil, err } - return n.dial(ctx, nodeID, addr, locality, true, class) + conn, _, err := n.dial(ctx, nodeID, addr, locality, true, class) + return conn, err } // DialNoBreaker is like Dial, but will not check the circuit breaker before @@ -112,7 +117,8 @@ func (n *Dialer) DialNoBreaker( if err != nil { return nil, err } - return n.dial(ctx, nodeID, addr, locality, false, class) + conn, _, err := n.dial(ctx, nodeID, addr, locality, false, class) + return conn, err } // DialInternalClient is a specialization of DialClass for callers that @@ -141,11 +147,12 @@ func (n *Dialer) DialInternalClient( return nil, errors.Wrap(err, "resolver error") } log.VEventf(ctx, 2, "sending request to %s", addr) - conn, err := n.dial(ctx, nodeID, addr, locality, true, class) + conn, pool, err := n.dial(ctx, nodeID, addr, locality, true, class) if err != nil { return nil, err } client := 
kvpb.NewInternalClient(conn) + client = maybeWrapInBatchStreamPoolClient(ctx, n.rpcContext.Settings, client, pool) client = maybeWrapInTracingClient(ctx, client) return client, nil } @@ -160,11 +167,11 @@ func (n *Dialer) dial( locality roachpb.Locality, checkBreaker bool, class rpc.ConnectionClass, -) (_ *grpc.ClientConn, err error) { +) (_ *grpc.ClientConn, _ *rpc.BatchStreamPool, err error) { const ctxWrapMsg = "dial" // Don't trip the breaker if we're already canceled. if ctxErr := ctx.Err(); ctxErr != nil { - return nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } rpcConn := n.rpcContext.GRPCDialNode(addr.String(), nodeID, locality, class) connect := rpcConn.Connect @@ -175,13 +182,13 @@ func (n *Dialer) dial( if err != nil { // If we were canceled during the dial, don't trip the breaker. if ctxErr := ctx.Err(); ctxErr != nil { - return nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } err = errors.Wrapf(err, "failed to connect to n%d at %v", nodeID, addr) - return nil, err + return nil, nil, err } - - return conn, nil + pool := rpcConn.BatchStreamPool() + return conn, pool, nil } // ConnHealth returns nil if we have an open connection of the request @@ -275,25 +282,66 @@ func (n *Dialer) Latency(nodeID roachpb.NodeID) (time.Duration, error) { return latency, nil } -// TracingInternalClient wraps an InternalClient and fills in trace information +var batchStreamPoolingEnabled = settings.RegisterBoolSetting( + settings.ApplicationLevel, + "rpc.batch_stream_pool.enabled", + "if true, use pooled gRPC streams to execute Batch RPCs", + metamorphic.ConstantWithTestBool("rpc.batch_stream_pool.enabled", true), +) + +// batchStreamPoolClient is a client that sends Batch RPCs using a pooled +// BatchStream RPC stream. Pooling these streams allows for reuse of gRPC +// resources, as opposed to native unary RPCs, which create a new stream and +// throw it away for each unary request (see grpc.invoke). +type batchStreamPoolClient struct { + kvpb.InternalClient + pool *rpc.BatchStreamPool +} + +func maybeWrapInBatchStreamPoolClient( + ctx context.Context, st *cluster.Settings, client kvpb.InternalClient, pool *rpc.BatchStreamPool, +) kvpb.InternalClient { + // NOTE: we use ActiveVersionOrEmpty(ctx).IsActive(...) instead of the more + // common IsActive(ctx, ...) to avoid a fatal error if an RPC is made before + // the cluster version is initialized. + if !st.Version.ActiveVersionOrEmpty(ctx).IsActive(clusterversion.V25_1_BatchStreamRPC) { + return client + } + if !batchStreamPoolingEnabled.Get(&st.SV) { + return client + } + return &batchStreamPoolClient{InternalClient: client, pool: pool} +} + +// Batch overrides the Batch RPC client method to use the BatchStreamPool. +func (bsc *batchStreamPoolClient) Batch( + ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, +) (*kvpb.BatchResponse, error) { + if len(opts) > 0 { + return nil, errors.AssertionFailedf("BatchStreamPoolClient.Batch does not support CallOptions") + } + return bsc.pool.Send(ctx, ba) +} + +// tracingInternalClient wraps an InternalClient and fills in trace information // on Batch RPCs. // -// Note that TracingInternalClient is not used to wrap the internalClientAdapter +// Note that tracingInternalClient is not used to wrap the internalClientAdapter // - local RPCs don't need this tracing functionality. 
-type TracingInternalClient struct { +type tracingInternalClient struct { kvpb.InternalClient } func maybeWrapInTracingClient(ctx context.Context, client kvpb.InternalClient) kvpb.InternalClient { sp := tracing.SpanFromContext(ctx) if sp != nil && !sp.IsNoop() { - client = &TracingInternalClient{InternalClient: client} + client = &tracingInternalClient{InternalClient: client} } return client } // Batch overrides the Batch RPC client method and fills in tracing information. -func (tic *TracingInternalClient) Batch( +func (tic *tracingInternalClient) Batch( ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, ) (*kvpb.BatchResponse, error) { sp := tracing.SpanFromContext(ctx) diff --git a/pkg/rpc/nodedialer/nodedialer_test.go b/pkg/rpc/nodedialer/nodedialer_test.go index aa8298bd8d01..1cc7cf453fa5 100644 --- a/pkg/rpc/nodedialer/nodedialer_test.go +++ b/pkg/rpc/nodedialer/nodedialer_test.go @@ -415,6 +415,10 @@ func (*internalServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRe return nil, nil } +func (*internalServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("unimplemented") +} + func (*internalServer) RangeLookup( context.Context, *kvpb.RangeLookupRequest, ) (*kvpb.RangeLookupResponse, error) { @@ -422,7 +426,7 @@ func (*internalServer) RangeLookup( } func (s *internalServer) MuxRangeFeed(server kvpb.Internal_MuxRangeFeedServer) error { - panic("implement me") + panic("unimplemented") } func (*internalServer) GossipSubscription( diff --git a/pkg/rpc/peer.go b/pkg/rpc/peer.go index 7b24db0e7e81..3cc0bb599168 100644 --- a/pkg/rpc/peer.go +++ b/pkg/rpc/peer.go @@ -260,7 +260,7 @@ func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer { }, }) p.b = b - c := newConnectionToNodeID(k, b.Signal) + c := newConnectionToNodeID(p.opts, k, b.Signal) p.mu.PeerSnap = PeerSnap{c: c} return p @@ -361,7 +361,7 @@ func (p *peer) run(ctx context.Context, report func(error), done func()) { func() { p.mu.Lock() defer p.mu.Unlock() - p.mu.c = newConnectionToNodeID(p.k, p.mu.c.breakerSignalFn) + p.mu.c = newConnectionToNodeID(p.opts, p.k, p.mu.c.breakerSignalFn) }() if p.snap().deleteAfter != 0 { @@ -582,6 +582,11 @@ func (p *peer) onInitialHeartbeatSucceeded( p.ConnectionHeartbeats.Inc(1) // ConnectionFailures is not updated here. + // Bind the connection's stream pool to the active gRPC connection. Do this + // ahead of signaling the connFuture, so that the stream pool is ready for use + // by the time the connFuture is resolved. + p.mu.c.batchStreamPool.Bind(ctx, cc) + // Close the channel last which is helpful for unit tests that // first waitOrDefault for a healthy conn to then check metrics. p.mu.c.connFuture.Resolve(cc, nil /* err */) @@ -703,6 +708,10 @@ func (p *peer) onHeartbeatFailed( err = &netutil.InitialHeartbeatFailedError{WrappedErr: err} ls.c.connFuture.Resolve(nil /* cc */, err) } + + // Close down the stream pool that was bound to this connection. + ls.c.batchStreamPool.Close() + // By convention, we stick to updating breaker before updating peer // to make it easier to write non-flaky tests. report(err) diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go new file mode 100644 index 000000000000..7be2e9165887 --- /dev/null +++ b/pkg/rpc/stream_pool.go @@ -0,0 +1,330 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package rpc + +import ( + "context" + "io" + "slices" + "time" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" + "google.golang.org/grpc" +) + +// streamClient is a type constraint that is satisfied by a bidirectional gRPC +// client stream. +type streamClient[Req, Resp any] interface { + Send(Req) error + Recv() (Resp, error) + grpc.ClientStream +} + +// streamConstructor creates a new gRPC stream client over the provided client +// connection, using the provided call options. +type streamConstructor[Req, Resp any] func( + context.Context, *grpc.ClientConn, ...grpc.CallOption, +) (streamClient[Req, Resp], error) + +type result[Resp any] struct { + resp Resp + err error +} + +// defaultPooledStreamIdleTimeout is the default duration after which a pooled +// stream is considered idle and is closed. The idle timeout is used to ensure +// that stream pools eventually shrink when the load decreases. +const defaultPooledStreamIdleTimeout = 10 * time.Second + +// pooledStream is a wrapper around a grpc.ClientStream that is managed by a +// streamPool. It is responsible for sending a single request and receiving a +// single response on the stream at a time, mimicking the behavior of a gRPC +// unary RPC. However, unlike a unary RPC, the client stream is not discarded +// after a single use. Instead, it is returned to the pool for reuse. +// +// Most of the complexity around this type (e.g. the worker goroutine) comes +// from the need to handle context cancellation while a request is in-flight. +// gRPC streams support context cancellation, but they use the context provided +// to the stream when it was created for its entire lifetime. Meanwhile, we want +// to be able to handle context cancellation on a per-request basis while we +// layer unary RPC semantics on top of a pooled, bidirectional stream. To +// accomplish this, we use a worker goroutine to perform the (blocking) RPC +// function calls (Send and Recv) and let callers in Send wait on the result of +// the RPC call while also listening to their own context for cancellation. If +// the caller's context is canceled, it cancels the stream's context, which in +// turn cancels the RPC call. +// +// A pooledStream is not safe for concurrent use. It is intended to be used by +// only a single caller at a time. Mutual exclusion is coordinated by removing a +// pooledStream from the pool while it is in use. +// +// A pooledStream must only be returned to the pool for reuse after a successful +// Send call. If the Send call fails, the pooledStream must not be reused. 
+type pooledStream[Req, Resp any] struct { + pool *streamPool[Req, Resp] + stream streamClient[Req, Resp] + streamCtx context.Context + streamCancel context.CancelFunc + + reqC chan Req + respC chan result[Resp] +} + +func newPooledStream[Req, Resp any]( + pool *streamPool[Req, Resp], + stream streamClient[Req, Resp], + streamCtx context.Context, + streamCancel context.CancelFunc, +) *pooledStream[Req, Resp] { + return &pooledStream[Req, Resp]{ + pool: pool, + stream: stream, + streamCtx: streamCtx, + streamCancel: streamCancel, + reqC: make(chan Req), + respC: make(chan result[Resp], 1), + } +} + +func (s *pooledStream[Req, Resp]) run(ctx context.Context) { + defer s.close() + for s.runOnce(ctx) { + } +} + +func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) { + select { + case req := <-s.reqC: + err := s.stream.Send(req) + if err != nil { + // From grpc.ClientStream.SendMsg: + // > On error, SendMsg aborts the stream. + s.respC <- result[Resp]{err: err} + return false + } + resp, err := s.stream.Recv() + if err != nil { + // From grpc.ClientStream.RecvMsg: + // > It returns io.EOF when the stream completes successfully. On any + // > other error, the stream is aborted and the error contains the RPC + // > status. + if errors.Is(err, io.EOF) { + log.Errorf(ctx, "stream unexpectedly closed by server: %+v", err) + } + s.respC <- result[Resp]{err: err} + return false + } + s.respC <- result[Resp]{resp: resp} + return true + + case <-time.After(s.pool.idleTimeout): + // Try to remove ourselves from the pool. If we don't find ourselves in the + // pool, someone just grabbed us from the pool and we should keep running. + // If we do find and remove ourselves, we can close the stream and stop + // running. This ensures that callers never encounter spurious stream + // closures due to idle timeouts. + return !s.pool.remove(s) + + case <-ctx.Done(): + return false + } +} + +func (s *pooledStream[Req, Resp]) close() { + // Make sure the stream's context is canceled to ensure that we clean up + // resources in idle timeout case. + // + // From grpc.ClientConn.NewStream: + // > To ensure resources are not leaked due to the stream returned, one of the + // > following actions must be performed: + // > ... + // > 2. Cancel the context provided. + // > ... + s.streamCancel() + // Try to remove ourselves from the pool, now that we're closed. If we don't + // find ourselves in the pool, someone has already grabbed us from the pool + // and will check whether we are closed before putting us back. + s.pool.remove(s) +} + +// Send sends a request on the pooled stream and returns the response in a unary +// RPC fashion. Context cancellation is respected. +func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { + var resp result[Resp] + select { + case s.reqC <- req: + // The request was passed to the stream's worker goroutine, which will + // invoke the RPC function calls (Send and Recv). Wait for a response. + select { + case resp = <-s.respC: + // Return the response. + case <-ctx.Done(): + // Cancel the stream and return the request's context error. + s.streamCancel() + resp.err = ctx.Err() + } + case <-s.streamCtx.Done(): + // The stream was closed before its worker goroutine could accept the + // request. Return the stream's context error. 
+ resp.err = s.streamCtx.Err() + } + + if resp.err != nil { + // On error, wait until we see the streamCtx.Done() signal, to ensure that + // the stream has been cleaned up and won't be placed back in the pool by + // putIfNotClosed. + <-s.streamCtx.Done() + } + return resp.resp, resp.err +} + +// streamPool is a pool of grpc.ClientStream objects (wrapped in pooledStream) +// that are used to send requests and receive corresponding responses in a +// manner that mimics unary RPC invocation. Pooling these streams allows for +// reuse of gRPC resources across calls, as opposed to native unary RPCs, which +// create a new stream and throw it away for each request (see grpc.invoke). +type streamPool[Req, Resp any] struct { + stopper *stop.Stopper + idleTimeout time.Duration + newStream streamConstructor[Req, Resp] + + // cc and ccCtx are set on bind, when the gRPC connection is established. + cc *grpc.ClientConn + // Derived from rpc.Context.MasterCtx, canceled on stopper quiesce. + ccCtx context.Context + + streams struct { + syncutil.Mutex + s []*pooledStream[Req, Resp] + } +} + +func makeStreamPool[Req, Resp any]( + stopper *stop.Stopper, newStream streamConstructor[Req, Resp], +) streamPool[Req, Resp] { + return streamPool[Req, Resp]{ + stopper: stopper, + idleTimeout: defaultPooledStreamIdleTimeout, + newStream: newStream, + } +} + +// Bind sets the gRPC connection and context for the streamPool. This must be +// called once before streamPool.Send. +func (p *streamPool[Req, Resp]) Bind(ctx context.Context, cc *grpc.ClientConn) { + p.cc = cc + p.ccCtx = ctx +} + +// Close closes all streams in the pool. +func (p *streamPool[Req, Resp]) Close() { + p.streams.Lock() + defer p.streams.Unlock() + for _, s := range p.streams.s { + s.streamCancel() + } + p.streams.s = nil +} + +func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] { + p.streams.Lock() + defer p.streams.Unlock() + if len(p.streams.s) == 0 { + return nil + } + // Pop from the tail to bias towards reusing the same streams repeatedly so + // that streams at the head of the slice are more likely to be closed due to + // idle timeouts. + s := p.streams.s[len(p.streams.s)-1] + p.streams.s[len(p.streams.s)-1] = nil + p.streams.s = p.streams.s[:len(p.streams.s)-1] + return s +} + +func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) { + p.streams.Lock() + defer p.streams.Unlock() + if s.streamCtx.Err() != nil { + // The stream is closed, don't put it in the pool. Note that this must be + // done under lock to avoid racing with pooledStream.close, which attempts + // to remove a closing stream from the pool. 
+ return + } + p.streams.s = append(p.streams.s, s) +} + +func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool { + p.streams.Lock() + defer p.streams.Unlock() + i := slices.Index(p.streams.s, s) + if i == -1 { + return false + } + copy(p.streams.s[i:], p.streams.s[i+1:]) + p.streams.s[len(p.streams.s)-1] = nil + p.streams.s = p.streams.s[:len(p.streams.s)-1] + return true +} + +func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], error) { + if p.cc == nil { + return nil, errors.AssertionFailedf("streamPool not bound to a grpc.ClientConn") + } + + ctx, cancel := context.WithCancel(p.ccCtx) + defer func() { + if cancel != nil { + cancel() + } + }() + + stream, err := p.newStream(ctx, p.cc) + if err != nil { + return nil, err + } + + s := newPooledStream(p, stream, ctx, cancel) + if err := p.stopper.RunAsyncTask(ctx, "pooled gRPC stream", s.run); err != nil { + return nil, err + } + cancel = nil + return s, nil +} + +// Send sends a request on a pooled stream and returns the response in a unary +// RPC fashion. If no stream is available in the pool, a new stream is created. +func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { + s := p.get() + if s == nil { + var err error + s, err = p.newPooledStream() + if err != nil { + var zero Resp + return zero, err + } + } + defer p.putIfNotClosed(s) + return s.Send(ctx, req) +} + +// BatchStreamPool is a streamPool specialized for BatchStreamClient streams. +type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse] + +// BatchStreamClient is a streamClient specialized for the BatchStream RPC. +// +//go:generate mockgen -destination=mocks_generated_test.go --package=. BatchStreamClient +type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] + +// newBatchStream constructs a BatchStreamClient from a grpc.ClientConn. +func newBatchStream( + ctx context.Context, cc *grpc.ClientConn, opts ...grpc.CallOption, +) (BatchStreamClient, error) { + return kvpb.NewInternalClient(cc).BatchStream(ctx, opts...) +} diff --git a/pkg/rpc/stream_pool_test.go b/pkg/rpc/stream_pool_test.go new file mode 100644 index 000000000000..214fdf28ed7a --- /dev/null +++ b/pkg/rpc/stream_pool_test.go @@ -0,0 +1,416 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package rpc + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +type mockBatchStreamConstructor struct { + stream BatchStreamClient + streamErr error + streamCount int + lastStreamCtx *context.Context +} + +func (m *mockBatchStreamConstructor) newStream( + ctx context.Context, conn *grpc.ClientConn, option ...grpc.CallOption, +) (BatchStreamClient, error) { + m.streamCount++ + if m.lastStreamCtx != nil { + if m.streamCount != 1 { + panic("unexpected stream creation with non-nil lastStreamCtx") + } + *m.lastStreamCtx = ctx + } + return m.stream, m.streamErr +} + +func TestStreamPool_Basic(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(2) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(2) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + defer p.Close() + + // Exercise the pool. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) + + // Exercise the pool again. Should re-use the same connection. + resp, err = p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) +} + +func TestStreamPool_Multi(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + const num = 3 + sendC := make(chan struct{}) + recvC := make(chan struct{}) + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). + DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). + Times(num) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-recvC + return &kvpb.BatchResponse{}, nil + }). + Times(num) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + defer p.Close() + + // Exercise the pool with concurrent requests, pausing between each one to + // wait for its request to have been sent. + var g errgroup.Group + for range num { + g.Go(func() error { + _, err := p.Send(ctx, &kvpb.BatchRequest{}) + return err + }) + <-sendC + } + + // Assert that all requests have been sent and are waiting for responses. The + // pool is empty at this point, as all streams are in use. + require.Equal(t, num, conn.streamCount) + require.Equal(t, num, stopper.NumTasks()) + require.Len(t, p.streams.s, 0) + + // Allow all requests to complete. + for range num { + recvC <- struct{}{} + } + require.NoError(t, g.Wait()) + + // All three streams should be returned to the pool. 
+ require.Len(t, p.streams.s, num) +} + +func TestStreamPool_SendBeforeBind(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Times(0) + stream.EXPECT().Recv().Times(0) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + + // Exercise the pool before it is bound to a gRPC connection. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.Regexp(t, err, "streamPool not bound to a grpc.ClientConn") + require.Equal(t, 0, conn.streamCount) + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_SendError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendErr := errors.New("test error") + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(sendErr).Times(1) + stream.EXPECT().Recv().Times(0) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, sendErr) + + // The stream should not be returned to the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_RecvError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + recvErr := errors.New("test error") + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(nil, recvErr).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, recvErr) + + // The stream should not be returned to the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_NewStreamError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + streamErr := errors.New("test error") + conn := &mockBatchStreamConstructor{streamErr: streamErr} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, streamErr) + + // The stream should not be placed in the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_Cancellation(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendC := make(chan struct{}) + var streamCtx context.Context + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). + DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). 
+ Times(1) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-streamCtx.Done() + return nil, streamCtx.Err() + }). + Times(1) + conn := &mockBatchStreamConstructor{stream: stream, lastStreamCtx: &streamCtx} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the request get stuck. + reqCtx, cancel := context.WithCancel(ctx) + type result struct { + resp *kvpb.BatchResponse + err error + } + resC := make(chan result) + go func() { + resp, err := p.Send(reqCtx, &kvpb.BatchRequest{}) + resC <- result{resp: resp, err: err} + }() + <-sendC + select { + case <-resC: + t.Fatal("unexpected result") + case <-time.After(10 * time.Millisecond): + } + + // Cancel the request and observe the result. + cancel() + res := <-resC + require.Nil(t, res.resp) + require.Error(t, res.err) + require.ErrorIs(t, res.err, context.Canceled) + + // The stream should not be returned to the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_Quiesce(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + stopperCtx, _ := stopper.WithCancelOnQuiesce(ctx) + p.Bind(stopperCtx, new(grpc.ClientConn)) + + // Exercise the pool to create a worker goroutine. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) + + // Stop the stopper, which closes the pool. + stopper.Stop(ctx) + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_QuiesceDuringSend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendC := make(chan struct{}) + var streamCtx context.Context + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). + DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). + Times(1) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-streamCtx.Done() + return nil, streamCtx.Err() + }). + Times(1) + conn := &mockBatchStreamConstructor{stream: stream, lastStreamCtx: &streamCtx} + p := makeStreamPool(stopper, conn.newStream) + stopperCtx, _ := stopper.WithCancelOnQuiesce(ctx) + p.Bind(stopperCtx, new(grpc.ClientConn)) + + // Exercise the pool and observe the request get stuck. + type result struct { + resp *kvpb.BatchResponse + err error + } + resC := make(chan result) + go func() { + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + resC <- result{resp: resp, err: err} + }() + <-sendC + select { + case <-resC: + t.Fatal("unexpected result") + case <-time.After(10 * time.Millisecond): + } + + // Stop the stopper, which cancels the request and closes the pool. 
+ stopper.Stop(ctx) + require.Len(t, p.streams.s, 0) + res := <-resC + require.Nil(t, res.resp) + require.Error(t, res.err) + require.ErrorIs(t, res.err, context.Canceled) +} + +func TestStreamPool_IdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + p.idleTimeout = 10 * time.Millisecond + + // Exercise the pool to create a worker goroutine. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + + // Eventually the worker should be stopped due to the idle timeout. + testutils.SucceedsSoon(t, func() error { + if stopper.NumTasks() != 0 { + return errors.New("worker not stopped") + } + return nil + }) + + // Once the worker is stopped, the pool should be empty. + require.Len(t, p.streams.s, 0) +} diff --git a/pkg/server/node.go b/pkg/server/node.go index bbe86bcd3244..5f5a969990a7 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "fmt" + "io" "math" "net" "sort" @@ -1872,6 +1873,31 @@ func (n *Node) Batch(ctx context.Context, args *kvpb.BatchRequest) (*kvpb.BatchR return br, nil } +// BatchStream implements the kvpb.InternalServer interface. +func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + ctx := stream.Context() + for { + args, err := stream.Recv() + if err != nil { + // From grpc.ServerStream.Recv: + // > It returns io.EOF when the client has performed a CloseSend. + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + br, err := n.Batch(ctx, args) + if err != nil { + return err + } + err = stream.Send(br) + if err != nil { + return err + } + } +} + // spanForRequest is the retval of setupSpanForIncomingRPC. It groups together a // few variables needed when finishing an RPC's span. // diff --git a/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go b/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go index 949b27c2e1b9..8fee33b68b6e 100644 --- a/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go +++ b/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go @@ -46,6 +46,9 @@ func setGRPCErrorTag(sp *tracing.Span, err error) { // BatchMethodName is the method name of Internal.Batch RPC. const BatchMethodName = "/cockroach.roachpb.Internal/Batch" +// BatchStreamMethodName is the method name of the Internal.BatchStream RPC. +const BatchStreamMethodName = "/cockroach.roachpb.Internal/BatchStream" + // sendKVBatchMethodName is the method name for adminServer.SendKVBatch. const sendKVBatchMethodName = "/cockroach.server.serverpb.Admin/SendKVBatch" @@ -61,6 +64,7 @@ const flowStreamMethodName = "/cockroach.sql.distsqlrun.DistSQL/FlowStream" // tracing because it's not worth it. 
func methodExcludedFromTracing(method string) bool { return method == BatchMethodName || + method == BatchStreamMethodName || method == sendKVBatchMethodName || method == SetupFlowMethodName || method == flowStreamMethodName From a8c78e55df6f162bbc7ad81ef721994ea1b11052 Mon Sep 17 00:00:00 2001 From: Nathan VanBenschoten Date: Mon, 9 Dec 2024 17:42:29 -0500 Subject: [PATCH 2/3] rpc: avoid heap allocations of InternalClient objects MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit rearranges the InternalClient implementations returned by Dialer.DialInternalClient to avoid multiple heap allocations. With gRPC stream pooling: ``` name old time/op new time/op delta Sysbench/KV/1node_remote/oltp_point_select-10 30.7µs ± 6% 30.3µs ± 4% ~ (p=0.436 n=9+9) Sysbench/KV/1node_remote/oltp_read_only-10 743µs ± 3% 736µs ± 1% ~ (p=0.721 n=8+8) Sysbench/KV/1node_remote/oltp_read_write-10 1.50ms ± 3% 1.51ms ± 6% ~ (p=0.730 n=9+9) name old alloc/op new alloc/op delta Sysbench/KV/1node_remote/oltp_point_select-10 6.61kB ± 0% 6.58kB ± 0% -0.49% (p=0.000 n=9+8) Sysbench/KV/1node_remote/oltp_read_only-10 656kB ± 0% 656kB ± 0% ~ (p=0.743 n=8+9) Sysbench/KV/1node_remote/oltp_read_write-10 883kB ± 0% 882kB ± 1% ~ (p=0.143 n=10+10) name old allocs/op new allocs/op delta Sysbench/KV/1node_remote/oltp_point_select-10 63.0 ± 0% 61.0 ± 0% -3.17% (p=0.000 n=10+10) Sysbench/KV/1node_remote/oltp_read_only-10 1.90k ± 0% 1.87k ± 0% -1.49% (p=0.000 n=10+8) Sysbench/KV/1node_remote/oltp_read_write-10 3.51k ± 0% 3.46k ± 0% -1.35% (p=0.000 n=9+8) ``` Without gRPC stream pooling: ``` name old time/op new time/op delta Sysbench/KV/1node_remote/oltp_read_write-10 1.88ms ± 3% 1.83ms ± 2% -2.16% (p=0.029 n=10+10) Sysbench/KV/1node_remote/oltp_point_select-10 46.9µs ± 1% 46.7µs ± 1% -0.49% (p=0.022 n=9+10) Sysbench/KV/1node_remote/oltp_read_only-10 976µs ± 2% 973µs ± 4% ~ (p=0.497 n=9+10) name old alloc/op new alloc/op delta Sysbench/KV/1node_remote/oltp_point_select-10 16.3kB ± 0% 16.2kB ± 0% -0.06% (p=0.000 n=8+7) Sysbench/KV/1node_remote/oltp_read_only-10 788kB ± 0% 788kB ± 0% ~ (p=0.353 n=10+10) Sysbench/KV/1node_remote/oltp_read_write-10 1.11MB ± 0% 1.11MB ± 0% ~ (p=0.436 n=10+10) name old allocs/op new allocs/op delta Sysbench/KV/1node_remote/oltp_point_select-10 209 ± 0% 208 ± 0% -0.48% (p=0.000 n=10+10) Sysbench/KV/1node_remote/oltp_read_write-10 7.07k ± 0% 7.04k ± 0% -0.36% (p=0.000 n=10+10) Sysbench/KV/1node_remote/oltp_read_only-10 3.95k ± 0% 3.94k ± 0% -0.35% (p=0.008 n=6+7) ``` Epic: None Release note: None --- pkg/rpc/nodedialer/nodedialer.go | 100 ++++++++++++++++++++++--------- pkg/rpc/stream_pool.go | 5 ++ 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index ea444f668031..4bebfd4f5dd4 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -151,8 +151,10 @@ func (n *Dialer) DialInternalClient( if err != nil { return nil, err } - client := kvpb.NewInternalClient(conn) - client = maybeWrapInBatchStreamPoolClient(ctx, n.rpcContext.Settings, client, pool) + client := newBaseInternalClient(conn) + if shouldUseBatchStreamPoolClient(ctx, n.rpcContext.Settings) { + client = newBatchStreamPoolClient(pool) + } client = maybeWrapInTracingClient(ctx, client) return client, nil } @@ -282,6 +284,35 @@ func (n *Dialer) Latency(nodeID roachpb.NodeID) (time.Duration, error) { return latency, nil } +// baseInternalClient is a wrapper around a grpc.ClientConn that 
implements the +// RestrictedInternalClient interface. By calling kvpb.NewInternalClient on each +// RPC invocation, that function can be inlined and the returned internalClient +// object (which itself is just a wrapper) never needs to be allocated on the +// heap. +type baseInternalClient grpc.ClientConn + +func newBaseInternalClient(conn *grpc.ClientConn) rpc.RestrictedInternalClient { + return (*baseInternalClient)(conn) +} + +func (c *baseInternalClient) asConn() *grpc.ClientConn { + return (*grpc.ClientConn)(c) +} + +// Batch implements the RestrictedInternalClient interface. +func (c *baseInternalClient) Batch( + ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, +) (*kvpb.BatchResponse, error) { + return kvpb.NewInternalClient(c.asConn()).Batch(ctx, ba, opts...) +} + +// MuxRangeFeed implements the RestrictedInternalClient interface. +func (c *baseInternalClient) MuxRangeFeed( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_MuxRangeFeedClient, error) { + return kvpb.NewInternalClient(c.asConn()).MuxRangeFeed(ctx, opts...) +} + var batchStreamPoolingEnabled = settings.RegisterBoolSetting( settings.ApplicationLevel, "rpc.batch_stream_pool.enabled", @@ -289,59 +320,72 @@ var batchStreamPoolingEnabled = settings.RegisterBoolSetting( metamorphic.ConstantWithTestBool("rpc.batch_stream_pool.enabled", true), ) -// batchStreamPoolClient is a client that sends Batch RPCs using a pooled -// BatchStream RPC stream. Pooling these streams allows for reuse of gRPC -// resources, as opposed to native unary RPCs, which create a new stream and -// throw it away for each unary request (see grpc.invoke). -type batchStreamPoolClient struct { - kvpb.InternalClient - pool *rpc.BatchStreamPool -} - -func maybeWrapInBatchStreamPoolClient( - ctx context.Context, st *cluster.Settings, client kvpb.InternalClient, pool *rpc.BatchStreamPool, -) kvpb.InternalClient { +func shouldUseBatchStreamPoolClient(ctx context.Context, st *cluster.Settings) bool { // NOTE: we use ActiveVersionOrEmpty(ctx).IsActive(...) instead of the more // common IsActive(ctx, ...) to avoid a fatal error if an RPC is made before // the cluster version is initialized. if !st.Version.ActiveVersionOrEmpty(ctx).IsActive(clusterversion.V25_1_BatchStreamRPC) { - return client + return false } if !batchStreamPoolingEnabled.Get(&st.SV) { - return client + return false } - return &batchStreamPoolClient{InternalClient: client, pool: pool} + return true +} + +// batchStreamPoolClient is a client that sends Batch RPCs using a pooled +// BatchStream RPC stream. Pooling these streams allows for reuse of gRPC +// resources, as opposed to native unary RPCs, which create a new stream and +// throw it away for each unary request (see grpc.invoke). +type batchStreamPoolClient rpc.BatchStreamPool + +func newBatchStreamPoolClient(pool *rpc.BatchStreamPool) rpc.RestrictedInternalClient { + return (*batchStreamPoolClient)(pool) } -// Batch overrides the Batch RPC client method to use the BatchStreamPool. -func (bsc *batchStreamPoolClient) Batch( +func (c *batchStreamPoolClient) asPool() *rpc.BatchStreamPool { + return (*rpc.BatchStreamPool)(c) +} + +// Batch implements the RestrictedInternalClient interface, using the pooled +// streams in the BatchStreamPool to issue the Batch RPC. 
+func (c *batchStreamPoolClient) Batch( ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, ) (*kvpb.BatchResponse, error) { if len(opts) > 0 { - return nil, errors.AssertionFailedf("BatchStreamPoolClient.Batch does not support CallOptions") + return nil, errors.AssertionFailedf("batchStreamPoolClient.Batch does not support CallOptions") } - return bsc.pool.Send(ctx, ba) + return c.asPool().Send(ctx, ba) +} + +// MuxRangeFeed implements the RestrictedInternalClient interface. +func (c *batchStreamPoolClient) MuxRangeFeed( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_MuxRangeFeedClient, error) { + return kvpb.NewInternalClient(c.asPool().Conn()).MuxRangeFeed(ctx, opts...) } -// tracingInternalClient wraps an InternalClient and fills in trace information -// on Batch RPCs. +// tracingInternalClient wraps a RestrictedInternalClient and fills in trace +// information on Batch RPCs. // // Note that tracingInternalClient is not used to wrap the internalClientAdapter // - local RPCs don't need this tracing functionality. type tracingInternalClient struct { - kvpb.InternalClient + rpc.RestrictedInternalClient } -func maybeWrapInTracingClient(ctx context.Context, client kvpb.InternalClient) kvpb.InternalClient { +func maybeWrapInTracingClient( + ctx context.Context, client rpc.RestrictedInternalClient, +) rpc.RestrictedInternalClient { sp := tracing.SpanFromContext(ctx) if sp != nil && !sp.IsNoop() { - client = &tracingInternalClient{InternalClient: client} + return &tracingInternalClient{RestrictedInternalClient: client} } return client } // Batch overrides the Batch RPC client method and fills in tracing information. -func (tic *tracingInternalClient) Batch( +func (c *tracingInternalClient) Batch( ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, ) (*kvpb.BatchResponse, error) { sp := tracing.SpanFromContext(ctx) @@ -349,5 +393,5 @@ func (tic *tracingInternalClient) Batch( ba = ba.ShallowCopy() ba.TraceInfo = sp.Meta().ToProto() } - return tic.InternalClient.Batch(ctx, ba, opts...) + return c.RestrictedInternalClient.Batch(ctx, ba, opts...) } diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go index 7be2e9165887..890111f166c0 100644 --- a/pkg/rpc/stream_pool.go +++ b/pkg/rpc/stream_pool.go @@ -223,6 +223,11 @@ func (p *streamPool[Req, Resp]) Bind(ctx context.Context, cc *grpc.ClientConn) { p.ccCtx = ctx } +// Conn returns the gRPC connection bound to the streamPool. +func (p *streamPool[Req, Resp]) Conn() *grpc.ClientConn { + return p.cc +} + // Close closes all streams in the pool. func (p *streamPool[Req, Resp]) Close() { p.streams.Lock() From db408d253b5044bda257b0b1cf88468426622e5e Mon Sep 17 00:00:00 2001 From: Tobias Grieger Date: Mon, 9 Dec 2024 11:34:41 +0100 Subject: [PATCH 3/3] rpc: make batch stream pool general over Conn This will help prototype drpc stream pooling. 
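
To make the generalization concrete, here is a minimal sketch of how the
now-generic pool could be instantiated for a hypothetical drpc transport.
The drpcClientConn type and dialDRPCBatchStream constructor below are
illustrative assumptions only, not part of this patch; streamClient,
streamPool, makeStreamPool, and kvpb come from the code in this series.

```
// Hypothetical sketch only; assumes it lives in package rpc alongside the
// generic streamPool.
package rpc

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/kv/kvpb"
	"github.com/cockroachdb/cockroach/pkg/util/stop"
	"github.com/cockroachdb/errors"
)

// drpcClientConn stands in for whatever connection type a drpc transport
// would expose. It only needs to be comparable so that the pool can detect
// whether Bind has been called.
type drpcClientConn struct{ /* transport state */ }

// dialDRPCBatchStream would open a long-lived BatchStream over the drpc
// connection. The body is elided; only the signature matters for the sketch.
func dialDRPCBatchStream(
	ctx context.Context, cc *drpcClientConn,
) (streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse], error) {
	return nil, errors.New("unimplemented: drpc BatchStream dialing")
}

// DRPCBatchStreamPool would be the drpc analogue of BatchStreamPool: the
// same generic streamPool, instantiated with a different Conn type argument.
type DRPCBatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, *drpcClientConn]

func exampleDRPCPoolUsage(
	ctx context.Context, stopper *stop.Stopper, cc *drpcClientConn,
) error {
	pool := makeStreamPool(stopper, dialDRPCBatchStream)
	pool.Bind(ctx, cc)
	_, err := pool.Send(ctx, &kvpb.BatchRequest{})
	return err
}
```

The gRPC-specialized BatchStreamPool in this patch remains the
*grpc.ClientConn instantiation of the same generic type.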
--- pkg/rpc/mocks_generated_test.go | 86 --------------------------------- pkg/rpc/stream_pool.go | 70 +++++++++++++-------------- pkg/rpc/stream_pool_test.go | 4 +- 3 files changed, 36 insertions(+), 124 deletions(-) diff --git a/pkg/rpc/mocks_generated_test.go b/pkg/rpc/mocks_generated_test.go index ea3f1f7bdc2c..efa2c112a466 100644 --- a/pkg/rpc/mocks_generated_test.go +++ b/pkg/rpc/mocks_generated_test.go @@ -13,7 +13,6 @@ import ( rpcpb "github.com/cockroachdb/cockroach/pkg/rpc/rpcpb" gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" - metadata "google.golang.org/grpc/metadata" ) // MockBatchStreamClient is a mock of BatchStreamClient interface. @@ -39,49 +38,6 @@ func (m *MockBatchStreamClient) EXPECT() *MockBatchStreamClientMockRecorder { return m.recorder } -// CloseSend mocks base method. -func (m *MockBatchStreamClient) CloseSend() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseSend") - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseSend indicates an expected call of CloseSend. -func (mr *MockBatchStreamClientMockRecorder) CloseSend() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockBatchStreamClient)(nil).CloseSend)) -} - -// Context mocks base method. -func (m *MockBatchStreamClient) Context() context.Context { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Context") - ret0, _ := ret[0].(context.Context) - return ret0 -} - -// Context indicates an expected call of Context. -func (mr *MockBatchStreamClientMockRecorder) Context() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockBatchStreamClient)(nil).Context)) -} - -// Header mocks base method. -func (m *MockBatchStreamClient) Header() (metadata.MD, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Header") - ret0, _ := ret[0].(metadata.MD) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Header indicates an expected call of Header. -func (mr *MockBatchStreamClientMockRecorder) Header() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockBatchStreamClient)(nil).Header)) -} - // Recv mocks base method. func (m *MockBatchStreamClient) Recv() (*kvpb.BatchResponse, error) { m.ctrl.T.Helper() @@ -97,20 +53,6 @@ func (mr *MockBatchStreamClientMockRecorder) Recv() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockBatchStreamClient)(nil).Recv)) } -// RecvMsg mocks base method. -func (m *MockBatchStreamClient) RecvMsg(arg0 interface{}) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RecvMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// RecvMsg indicates an expected call of RecvMsg. -func (mr *MockBatchStreamClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).RecvMsg), arg0) -} - // Send mocks base method. func (m *MockBatchStreamClient) Send(arg0 *kvpb.BatchRequest) error { m.ctrl.T.Helper() @@ -125,34 +67,6 @@ func (mr *MockBatchStreamClientMockRecorder) Send(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockBatchStreamClient)(nil).Send), arg0) } -// SendMsg mocks base method. 
-func (m *MockBatchStreamClient) SendMsg(arg0 interface{}) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMsg", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SendMsg indicates an expected call of SendMsg. -func (mr *MockBatchStreamClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockBatchStreamClient)(nil).SendMsg), arg0) -} - -// Trailer mocks base method. -func (m *MockBatchStreamClient) Trailer() metadata.MD { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Trailer") - ret0, _ := ret[0].(metadata.MD) - return ret0 -} - -// Trailer indicates an expected call of Trailer. -func (mr *MockBatchStreamClientMockRecorder) Trailer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockBatchStreamClient)(nil).Trailer)) -} - // MockDialbacker is a mock of Dialbacker interface. type MockDialbacker struct { ctrl *gomock.Controller diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go index 890111f166c0..20773d714f0a 100644 --- a/pkg/rpc/stream_pool.go +++ b/pkg/rpc/stream_pool.go @@ -24,13 +24,12 @@ import ( type streamClient[Req, Resp any] interface { Send(Req) error Recv() (Resp, error) - grpc.ClientStream } // streamConstructor creates a new gRPC stream client over the provided client // connection, using the provided call options. -type streamConstructor[Req, Resp any] func( - context.Context, *grpc.ClientConn, ...grpc.CallOption, +type streamConstructor[Req, Resp, Conn any] func( + context.Context, Conn, ) (streamClient[Req, Resp], error) type result[Resp any] struct { @@ -67,8 +66,8 @@ const defaultPooledStreamIdleTimeout = 10 * time.Second // // A pooledStream must only be returned to the pool for reuse after a successful // Send call. If the Send call fails, the pooledStream must not be reused. -type pooledStream[Req, Resp any] struct { - pool *streamPool[Req, Resp] +type pooledStream[Req, Resp any, Conn comparable] struct { + pool *streamPool[Req, Resp, Conn] stream streamClient[Req, Resp] streamCtx context.Context streamCancel context.CancelFunc @@ -77,13 +76,13 @@ type pooledStream[Req, Resp any] struct { respC chan result[Resp] } -func newPooledStream[Req, Resp any]( - pool *streamPool[Req, Resp], +func newPooledStream[Req, Resp any, Conn comparable]( + pool *streamPool[Req, Resp, Conn], stream streamClient[Req, Resp], streamCtx context.Context, streamCancel context.CancelFunc, -) *pooledStream[Req, Resp] { - return &pooledStream[Req, Resp]{ +) *pooledStream[Req, Resp, Conn] { + return &pooledStream[Req, Resp, Conn]{ pool: pool, stream: stream, streamCtx: streamCtx, @@ -93,13 +92,13 @@ func newPooledStream[Req, Resp any]( } } -func (s *pooledStream[Req, Resp]) run(ctx context.Context) { +func (s *pooledStream[Req, Resp, Conn]) run(ctx context.Context) { defer s.close() for s.runOnce(ctx) { } } -func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) { +func (s *pooledStream[Req, Resp, Conn]) runOnce(ctx context.Context) (loop bool) { select { case req := <-s.reqC: err := s.stream.Send(req) @@ -137,7 +136,7 @@ func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) { } } -func (s *pooledStream[Req, Resp]) close() { +func (s *pooledStream[Req, Resp, Conn]) close() { // Make sure the stream's context is canceled to ensure that we clean up // resources in idle timeout case. 
// @@ -156,7 +155,7 @@ func (s *pooledStream[Req, Resp]) close() { // Send sends a request on the pooled stream and returns the response in a unary // RPC fashion. Context cancellation is respected. -func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { +func (s *pooledStream[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { var resp result[Resp] select { case s.reqC <- req: @@ -190,26 +189,26 @@ func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, erro // manner that mimics unary RPC invocation. Pooling these streams allows for // reuse of gRPC resources across calls, as opposed to native unary RPCs, which // create a new stream and throw it away for each request (see grpc.invoke). -type streamPool[Req, Resp any] struct { +type streamPool[Req, Resp any, Conn comparable] struct { stopper *stop.Stopper idleTimeout time.Duration - newStream streamConstructor[Req, Resp] + newStream streamConstructor[Req, Resp, Conn] // cc and ccCtx are set on bind, when the gRPC connection is established. - cc *grpc.ClientConn + cc Conn // Derived from rpc.Context.MasterCtx, canceled on stopper quiesce. ccCtx context.Context streams struct { syncutil.Mutex - s []*pooledStream[Req, Resp] + s []*pooledStream[Req, Resp, Conn] } } -func makeStreamPool[Req, Resp any]( - stopper *stop.Stopper, newStream streamConstructor[Req, Resp], -) streamPool[Req, Resp] { - return streamPool[Req, Resp]{ +func makeStreamPool[Req, Resp any, Conn comparable]( + stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn], +) streamPool[Req, Resp, Conn] { + return streamPool[Req, Resp, Conn]{ stopper: stopper, idleTimeout: defaultPooledStreamIdleTimeout, newStream: newStream, @@ -218,18 +217,18 @@ func makeStreamPool[Req, Resp any]( // Bind sets the gRPC connection and context for the streamPool. This must be // called once before streamPool.Send. -func (p *streamPool[Req, Resp]) Bind(ctx context.Context, cc *grpc.ClientConn) { +func (p *streamPool[Req, Resp, Conn]) Bind(ctx context.Context, cc Conn) { p.cc = cc p.ccCtx = ctx } // Conn returns the gRPC connection bound to the streamPool. -func (p *streamPool[Req, Resp]) Conn() *grpc.ClientConn { +func (p *streamPool[Req, Resp, Conn]) Conn() Conn { return p.cc } // Close closes all streams in the pool. 
-func (p *streamPool[Req, Resp]) Close() { +func (p *streamPool[Req, Resp, Conn]) Close() { p.streams.Lock() defer p.streams.Unlock() for _, s := range p.streams.s { @@ -238,7 +237,7 @@ func (p *streamPool[Req, Resp]) Close() { p.streams.s = nil } -func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] { +func (p *streamPool[Req, Resp, Conn]) get() *pooledStream[Req, Resp, Conn] { p.streams.Lock() defer p.streams.Unlock() if len(p.streams.s) == 0 { @@ -253,7 +252,7 @@ func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] { return s } -func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) { +func (p *streamPool[Req, Resp, Conn]) putIfNotClosed(s *pooledStream[Req, Resp, Conn]) { p.streams.Lock() defer p.streams.Unlock() if s.streamCtx.Err() != nil { @@ -265,7 +264,7 @@ func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) { p.streams.s = append(p.streams.s, s) } -func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool { +func (p *streamPool[Req, Resp, Conn]) remove(s *pooledStream[Req, Resp, Conn]) bool { p.streams.Lock() defer p.streams.Unlock() i := slices.Index(p.streams.s, s) @@ -278,9 +277,10 @@ func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool { return true } -func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], error) { - if p.cc == nil { - return nil, errors.AssertionFailedf("streamPool not bound to a grpc.ClientConn") +func (p *streamPool[Req, Resp, Conn]) newPooledStream() (*pooledStream[Req, Resp, Conn], error) { + var zero Conn + if p.cc == zero { + return nil, errors.AssertionFailedf("streamPool not bound to a client conn") } ctx, cancel := context.WithCancel(p.ccCtx) @@ -305,7 +305,7 @@ func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], err // Send sends a request on a pooled stream and returns the response in a unary // RPC fashion. If no stream is available in the pool, a new stream is created. -func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { +func (p *streamPool[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { s := p.get() if s == nil { var err error @@ -320,7 +320,7 @@ func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) } // BatchStreamPool is a streamPool specialized for BatchStreamClient streams. -type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse] +type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, *grpc.ClientConn] // BatchStreamClient is a streamClient specialized for the BatchStream RPC. // @@ -328,8 +328,6 @@ type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse] type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] // newBatchStream constructs a BatchStreamClient from a grpc.ClientConn. -func newBatchStream( - ctx context.Context, cc *grpc.ClientConn, opts ...grpc.CallOption, -) (BatchStreamClient, error) { - return kvpb.NewInternalClient(cc).BatchStream(ctx, opts...) 
+func newBatchStream(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) { + return kvpb.NewInternalClient(cc).BatchStream(ctx) } diff --git a/pkg/rpc/stream_pool_test.go b/pkg/rpc/stream_pool_test.go index 214fdf28ed7a..192bda18b355 100644 --- a/pkg/rpc/stream_pool_test.go +++ b/pkg/rpc/stream_pool_test.go @@ -29,7 +29,7 @@ type mockBatchStreamConstructor struct { } func (m *mockBatchStreamConstructor) newStream( - ctx context.Context, conn *grpc.ClientConn, option ...grpc.CallOption, + ctx context.Context, conn *grpc.ClientConn, ) (BatchStreamClient, error) { m.streamCount++ if m.lastStreamCtx != nil { @@ -153,7 +153,7 @@ func TestStreamPool_SendBeforeBind(t *testing.T) { resp, err := p.Send(ctx, &kvpb.BatchRequest{}) require.Nil(t, resp) require.Error(t, err) - require.Regexp(t, err, "streamPool not bound to a grpc.ClientConn") + require.Regexp(t, err, "streamPool not bound to a client conn") require.Equal(t, 0, conn.streamCount) require.Len(t, p.streams.s, 0) }