From 7f4ffeb8e316bfc6e42b3b6ac065e7affc0cb5de Mon Sep 17 00:00:00 2001 From: Tobias Grieger Date: Mon, 9 Dec 2024 11:34:41 +0100 Subject: [PATCH] rpc: make batch stream pool general over Conn This will help prototype drpc stream pooling. --- pkg/rpc/stream_pool.go | 70 ++++++++++++++++++------------------- pkg/rpc/stream_pool_test.go | 4 +-- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go index 890111f166c0..20773d714f0a 100644 --- a/pkg/rpc/stream_pool.go +++ b/pkg/rpc/stream_pool.go @@ -24,13 +24,12 @@ import ( type streamClient[Req, Resp any] interface { Send(Req) error Recv() (Resp, error) - grpc.ClientStream } // streamConstructor creates a new gRPC stream client over the provided client // connection, using the provided call options. -type streamConstructor[Req, Resp any] func( - context.Context, *grpc.ClientConn, ...grpc.CallOption, +type streamConstructor[Req, Resp, Conn any] func( + context.Context, Conn, ) (streamClient[Req, Resp], error) type result[Resp any] struct { @@ -67,8 +66,8 @@ const defaultPooledStreamIdleTimeout = 10 * time.Second // // A pooledStream must only be returned to the pool for reuse after a successful // Send call. If the Send call fails, the pooledStream must not be reused. -type pooledStream[Req, Resp any] struct { - pool *streamPool[Req, Resp] +type pooledStream[Req, Resp any, Conn comparable] struct { + pool *streamPool[Req, Resp, Conn] stream streamClient[Req, Resp] streamCtx context.Context streamCancel context.CancelFunc @@ -77,13 +76,13 @@ type pooledStream[Req, Resp any] struct { respC chan result[Resp] } -func newPooledStream[Req, Resp any]( - pool *streamPool[Req, Resp], +func newPooledStream[Req, Resp any, Conn comparable]( + pool *streamPool[Req, Resp, Conn], stream streamClient[Req, Resp], streamCtx context.Context, streamCancel context.CancelFunc, -) *pooledStream[Req, Resp] { - return &pooledStream[Req, Resp]{ +) *pooledStream[Req, Resp, Conn] { + return &pooledStream[Req, Resp, Conn]{ pool: pool, stream: stream, streamCtx: streamCtx, @@ -93,13 +92,13 @@ func newPooledStream[Req, Resp any]( } } -func (s *pooledStream[Req, Resp]) run(ctx context.Context) { +func (s *pooledStream[Req, Resp, Conn]) run(ctx context.Context) { defer s.close() for s.runOnce(ctx) { } } -func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) { +func (s *pooledStream[Req, Resp, Conn]) runOnce(ctx context.Context) (loop bool) { select { case req := <-s.reqC: err := s.stream.Send(req) @@ -137,7 +136,7 @@ func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) { } } -func (s *pooledStream[Req, Resp]) close() { +func (s *pooledStream[Req, Resp, Conn]) close() { // Make sure the stream's context is canceled to ensure that we clean up // resources in idle timeout case. // @@ -156,7 +155,7 @@ func (s *pooledStream[Req, Resp]) close() { // Send sends a request on the pooled stream and returns the response in a unary // RPC fashion. Context cancellation is respected. -func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { +func (s *pooledStream[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { var resp result[Resp] select { case s.reqC <- req: @@ -190,26 +189,26 @@ func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, erro // manner that mimics unary RPC invocation. Pooling these streams allows for // reuse of gRPC resources across calls, as opposed to native unary RPCs, which // create a new stream and throw it away for each request (see grpc.invoke). -type streamPool[Req, Resp any] struct { +type streamPool[Req, Resp any, Conn comparable] struct { stopper *stop.Stopper idleTimeout time.Duration - newStream streamConstructor[Req, Resp] + newStream streamConstructor[Req, Resp, Conn] // cc and ccCtx are set on bind, when the gRPC connection is established. - cc *grpc.ClientConn + cc Conn // Derived from rpc.Context.MasterCtx, canceled on stopper quiesce. ccCtx context.Context streams struct { syncutil.Mutex - s []*pooledStream[Req, Resp] + s []*pooledStream[Req, Resp, Conn] } } -func makeStreamPool[Req, Resp any]( - stopper *stop.Stopper, newStream streamConstructor[Req, Resp], -) streamPool[Req, Resp] { - return streamPool[Req, Resp]{ +func makeStreamPool[Req, Resp any, Conn comparable]( + stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn], +) streamPool[Req, Resp, Conn] { + return streamPool[Req, Resp, Conn]{ stopper: stopper, idleTimeout: defaultPooledStreamIdleTimeout, newStream: newStream, @@ -218,18 +217,18 @@ func makeStreamPool[Req, Resp any]( // Bind sets the gRPC connection and context for the streamPool. This must be // called once before streamPool.Send. -func (p *streamPool[Req, Resp]) Bind(ctx context.Context, cc *grpc.ClientConn) { +func (p *streamPool[Req, Resp, Conn]) Bind(ctx context.Context, cc Conn) { p.cc = cc p.ccCtx = ctx } // Conn returns the gRPC connection bound to the streamPool. -func (p *streamPool[Req, Resp]) Conn() *grpc.ClientConn { +func (p *streamPool[Req, Resp, Conn]) Conn() Conn { return p.cc } // Close closes all streams in the pool. -func (p *streamPool[Req, Resp]) Close() { +func (p *streamPool[Req, Resp, Conn]) Close() { p.streams.Lock() defer p.streams.Unlock() for _, s := range p.streams.s { @@ -238,7 +237,7 @@ func (p *streamPool[Req, Resp]) Close() { p.streams.s = nil } -func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] { +func (p *streamPool[Req, Resp, Conn]) get() *pooledStream[Req, Resp, Conn] { p.streams.Lock() defer p.streams.Unlock() if len(p.streams.s) == 0 { @@ -253,7 +252,7 @@ func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] { return s } -func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) { +func (p *streamPool[Req, Resp, Conn]) putIfNotClosed(s *pooledStream[Req, Resp, Conn]) { p.streams.Lock() defer p.streams.Unlock() if s.streamCtx.Err() != nil { @@ -265,7 +264,7 @@ func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) { p.streams.s = append(p.streams.s, s) } -func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool { +func (p *streamPool[Req, Resp, Conn]) remove(s *pooledStream[Req, Resp, Conn]) bool { p.streams.Lock() defer p.streams.Unlock() i := slices.Index(p.streams.s, s) @@ -278,9 +277,10 @@ func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool { return true } -func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], error) { - if p.cc == nil { - return nil, errors.AssertionFailedf("streamPool not bound to a grpc.ClientConn") +func (p *streamPool[Req, Resp, Conn]) newPooledStream() (*pooledStream[Req, Resp, Conn], error) { + var zero Conn + if p.cc == zero { + return nil, errors.AssertionFailedf("streamPool not bound to a client conn") } ctx, cancel := context.WithCancel(p.ccCtx) @@ -305,7 +305,7 @@ func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], err // Send sends a request on a pooled stream and returns the response in a unary // RPC fashion. If no stream is available in the pool, a new stream is created. -func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) { +func (p *streamPool[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { s := p.get() if s == nil { var err error @@ -320,7 +320,7 @@ func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) } // BatchStreamPool is a streamPool specialized for BatchStreamClient streams. -type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse] +type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, *grpc.ClientConn] // BatchStreamClient is a streamClient specialized for the BatchStream RPC. // @@ -328,8 +328,6 @@ type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse] type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] // newBatchStream constructs a BatchStreamClient from a grpc.ClientConn. -func newBatchStream( - ctx context.Context, cc *grpc.ClientConn, opts ...grpc.CallOption, -) (BatchStreamClient, error) { - return kvpb.NewInternalClient(cc).BatchStream(ctx, opts...) +func newBatchStream(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) { + return kvpb.NewInternalClient(cc).BatchStream(ctx) } diff --git a/pkg/rpc/stream_pool_test.go b/pkg/rpc/stream_pool_test.go index 214fdf28ed7a..192bda18b355 100644 --- a/pkg/rpc/stream_pool_test.go +++ b/pkg/rpc/stream_pool_test.go @@ -29,7 +29,7 @@ type mockBatchStreamConstructor struct { } func (m *mockBatchStreamConstructor) newStream( - ctx context.Context, conn *grpc.ClientConn, option ...grpc.CallOption, + ctx context.Context, conn *grpc.ClientConn, ) (BatchStreamClient, error) { m.streamCount++ if m.lastStreamCtx != nil { @@ -153,7 +153,7 @@ func TestStreamPool_SendBeforeBind(t *testing.T) { resp, err := p.Send(ctx, &kvpb.BatchRequest{}) require.Nil(t, resp) require.Error(t, err) - require.Regexp(t, err, "streamPool not bound to a grpc.ClientConn") + require.Regexp(t, err, "streamPool not bound to a client conn") require.Equal(t, 0, conn.streamCount) require.Len(t, p.streams.s, 0) }