From c5040f9d5bb185dda254e89d9898abae37357c40 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:21:32 -0500 Subject: [PATCH 01/13] Add Schema field to Spec for introspection New field Schema of type any on Spec objects. For proto based schemas the type will be of protoreflect.MethodDescriptor. This allows for easy introspection to interceptors. --- client.go | 2 + client_ext_test.go | 121 ++++++++++++++++++ cmd/protoc-gen-connect-go/main.go | 45 +++---- connect.go | 8 +- handler.go | 2 + .../v1/collidev1connect/collide.connect.go | 8 +- .../v1/importv1connect/import.connect.go | 2 +- .../ping/v1/pingv1connect/ping.connect.go | 28 ++-- option.go | 22 ++++ 9 files changed, 194 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 38bb541b..2abcaa5e 100644 --- a/client.go +++ b/client.go @@ -189,6 +189,7 @@ type clientConfig struct { URL *url.URL Protocol protocol Procedure string + Schema any CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool @@ -251,6 +252,7 @@ func (c *clientConfig) newSpec(t StreamType) Spec { return Spec{ StreamType: t, Procedure: c.Procedure, + Schema: c.Schema, IsClient: true, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/client_ext_test.go b/client_ext_test.go index ce799958..e5f6995a 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -17,6 +17,7 @@ package connect_test import ( "context" "errors" + "fmt" "net/http" "strings" "testing" @@ -26,6 +27,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "google.golang.org/protobuf/reflect/protoreflect" ) func TestNewClient_InitFailure(t *testing.T) { @@ -186,6 +188,81 @@ func TestGetNotModified(t *testing.T) { assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } +func TestSpecSchema(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + connect.WithInterceptors(&assertSchemaInterceptor{t}), + )) + server := memhttptest.NewServer(t, mux) + testcases := []struct { + name string + opts []connect.ClientOption + }{{ + name: connect.ProtocolConnect, + }, { + name: connect.ProtocolGRPC, + opts: []connect.ClientOption{ + connect.WithGRPC(), + }, + }, { + name: connect.ProtocolGRPCWeb, + opts: []connect.ClientOption{ + connect.WithGRPC(), + }, + }} + for _, testcase := range testcases { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + ctx := context.Background() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithClientOptions(testcase.opts...), + connect.WithInterceptors(&assertSchemaInterceptor{t}), + ) + t.Parallel() + t.Run("unary", func(t *testing.T) { + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.Nil(t, err) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) + }) + t.Run("client_stream", func(t *testing.T) { + clientStream := client.Sum(ctx) + t.Cleanup(func() { + _, closeErr := clientStream.CloseAndReceive() + assert.Nil(t, closeErr) + }) + assert.NotZero(t, clientStream.Spec().Schema) + err := clientStream.Send(&pingv1.SumRequest{}) + assert.Nil(t, err) + }) + t.Run("server_stream", func(t *testing.T) { + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + t.Cleanup(func() { + assert.Nil(t, serverStream.Close()) + }) + assert.Nil(t, err) + }) + t.Run("bidi_stream", func(t *testing.T) { + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotZero(t, bidiStream.Spec().Schema) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + }) + }) + } +} + type notModifiedPingServer struct { pingv1connect.UnimplementedPingServiceHandler @@ -233,3 +310,47 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl return next(ctx, conn) } } + +type assertSchemaInterceptor struct { + tb testing.TB +} + +func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if !assert.NotNil(a.tb, req.Spec().Schema) { + return nil, fmt.Errorf("nil spec") + } + methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, req.Spec().Procedure) + return next(ctx, req) + } +} + +func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + conn := next(ctx, spec) + if !assert.NotNil(a.tb, spec.Schema) { + return conn + } + methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name()) + assert.Equal(a.tb, procedure, spec.Procedure) + return conn + } +} + +func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + if !assert.NotNil(a.tb, conn.Spec().Schema) { + return fmt.Errorf("nil spec") + } + methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, conn.Spec().Procedure) + return next(ctx, conn) + } +} diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index bde19f51..4fda5b2a 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -102,17 +102,6 @@ func main() { ) } -func needsWithIdempotency(file *protogen.File) bool { - for _, service := range file.Services { - for _, method := range service.Methods { - if methodIdempotency(method) != connect.IdempotencyUnknown { - return true - } - } - } - return false -} - func generate(plugin *protogen.Plugin, file *protogen.File) { if len(file.Services) == 0 { return @@ -136,7 +125,7 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) for _, service := range file.Services { - generateService(generatedFile, service) + generateService(generatedFile, file, service) } } @@ -180,11 +169,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) { "is not defined, this code was generated with a version of connect newer than the one ", "compiled into your binary. You can fix the problem by either regenerating this code ", "with an older version of connect or updating the connect version compiled into your binary.") - if needsWithIdempotency(file) { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_7_0")) - } else { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_1_0")) - } + g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_13_0")) g.P() } @@ -225,12 +210,12 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } -func generateService(g *protogen.GeneratedFile, service *protogen.Service) { +func generateService(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { names := newNames(service) generateClientInterface(g, service, names) - generateClientImplementation(g, service, names) + generateClientImplementation(g, file, service, names) generateServerInterface(g, service, names) - generateServerConstructor(g, service, names) + generateServerConstructor(g, file, service, names) generateUnimplementedServerImplementation(g, service, names) } @@ -255,7 +240,7 @@ func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { clientOption := connectPackage.Ident("ClientOption") // Client constructor. @@ -283,17 +268,19 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S ) g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) + g.P(connectPackage.Ident("WithSchema"), "(", + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`, + `.Methods().ByName("`, method.Desc.Name(), `")),`) idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") g.P("),") } g.P("}") @@ -390,7 +377,7 @@ func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { wrapComments(g, names.ServerConstructor, " builds an HTTP handler from the service implementation.", " It returns the path on which to mount the handler and the handler itself.") g.P("//") @@ -419,16 +406,18 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv } g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") + g.P(connectPackage.Ident("WithSchema"), "(", + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`, + `.Methods().ByName("`, method.Desc.Name(), `")),`) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") g.P(")") } g.P(`return "/`, service.Desc.FullName(), `/", `, httpPackage.Ident("HandlerFunc"), `(func(w `, httpPackage.Ident("ResponseWriter"), `, r *`, httpPackage.Ident("Request"), `){`) diff --git a/connect.go b/connect.go index c7c41d38..622852fb 100644 --- a/connect.go +++ b/connect.go @@ -38,9 +38,10 @@ const Version = "1.13.0-dev" // These constants are used in compile-time handshakes with connect's generated // code. const ( - IsAtLeastVersion0_0_1 = true - IsAtLeastVersion0_1_0 = true - IsAtLeastVersion1_7_0 = true + IsAtLeastVersion0_0_1 = true + IsAtLeastVersion0_1_0 = true + IsAtLeastVersion1_7_0 = true + IsAtLeastVersion1_13_0 = true ) // StreamType describes whether the client, server, neither, or both is @@ -314,6 +315,7 @@ type HTTPClient interface { // fully-qualified Procedure corresponding to each RPC in your schema. type Spec struct { StreamType StreamType + Schema any // for protobuf RPCs, a protoreflect.MethodDescriptor Procedure string // for example, "/acme.foo.v1.FooService/Bar" IsClient bool // otherwise we're in a handler IdempotencyLevel IdempotencyLevel diff --git a/handler.go b/handler.go index 43bfe973..b207ac66 100644 --- a/handler.go +++ b/handler.go @@ -246,6 +246,7 @@ type handlerConfig struct { CompressMinBytes int Interceptor Interceptor Procedure string + Schema any HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool @@ -279,6 +280,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler func (c *handlerConfig) newSpec() Spec { return Spec{ Procedure: c.Procedure, + Schema: c.Schema, StreamType: c.StreamType, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 20223410..4b21ff23 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -32,7 +32,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // CollideServiceName is the fully-qualified name of the CollideService service. @@ -69,7 +69,8 @@ func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - opts..., + connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithClientOptions(opts...), ), } } @@ -98,7 +99,8 @@ func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.Handler collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - opts..., + connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index d56d3efa..cb3f2803 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -30,7 +30,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // ImportServiceName is the fully-qualified name of the ImportService service. diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 036539a8..84d01aa6 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -37,7 +37,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion1_7_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // PingServiceName is the fully-qualified name of the PingService service. @@ -91,28 +91,33 @@ func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts .. ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithClientOptions(opts...), ), } } @@ -174,28 +179,33 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/option.go b/option.go index 1eb8d7c3..c5d15f05 100644 --- a/option.go +++ b/option.go @@ -184,6 +184,16 @@ type Option interface { HandlerOption } +// WithSchema provides a parsed representation of the schema for an RPC to a +// client or handler. The supplied schema is exposed as [Spec.Schema]. This +// option is typically added by generated code. +// +// For services using protobuf schemas, the supplied schema should be a +// [protoreflect.MethodDescriptor]. +func WithSchema(schema any) Option { + return &schemaOption{Schema: schema} +} + // WithCodec registers a serialization method with a client or handler. // Handlers may have multiple codecs registered, and use whichever the client // chooses. Clients may only have a single codec. @@ -328,6 +338,18 @@ func WithOptions(options ...Option) Option { return &optionsOption{options} } +type schemaOption struct { + Schema any +} + +func (o *schemaOption) applyToClient(config *clientConfig) { + config.Schema = o.Schema +} + +func (o *schemaOption) applyToHandler(config *handlerConfig) { + config.Schema = o.Schema +} + type clientOptionsOption struct { options []ClientOption } From 817b5142fe78adf1fab0b5d9dca80e78e5cff930 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:46:14 -0500 Subject: [PATCH 02/13] Use serviceDescriptor var --- cmd/protoc-gen-connect-go/main.go | 18 +++++++++------ .../v1/collidev1connect/collide.connect.go | 6 +++-- .../v1/importv1connect/import.connect.go | 2 -- .../ping/v1/pingv1connect/ping.connect.go | 22 ++++++++++--------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 4fda5b2a..fe4d0f6c 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -258,7 +258,11 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File } g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"), ", baseURL string, opts ...", clientOption, ") ", names.Client, " {") - g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + if len(service.Methods) > 0 { + g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + } g.P("return &", names.ClientImpl, "{") for _, method := range service.Methods { g.P(unexport(method.GoName), ": ", @@ -269,9 +273,7 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) g.P(connectPackage.Ident("WithSchema"), "(", - g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`, - `.Methods().ByName("`, method.Desc.Name(), `")),`) + `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: @@ -390,6 +392,10 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s handlerOption := connectPackage.Ident("HandlerOption") g.P("func ", names.ServerConstructor, "(svc ", names.Server, ", opts ...", handlerOption, ") (string, ", httpPackage.Ident("Handler"), ") {") + if len(service.Methods) > 0 { + g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + } for _, method := range service.Methods { isStreamingServer := method.Desc.IsStreamingServer() isStreamingClient := method.Desc.IsStreamingClient() @@ -407,9 +413,7 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") g.P(connectPackage.Ident("WithSchema"), "(", - g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`, - `.Methods().ByName("`, method.Desc.Name(), `")),`) + `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 4b21ff23..9bee99e4 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -65,11 +65,12 @@ type CollideServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) CollideServiceClient { baseURL = strings.TrimRight(baseURL, "/") + serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") return &collideServiceClient{ _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), connect.WithClientOptions(opts...), ), } @@ -96,10 +97,11 @@ type CollideServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index cb3f2803..043ff35f 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -22,7 +22,6 @@ import ( connect "connectrpc.com/connect" _ "connectrpc.com/connect/internal/gen/connect/import/v1" http "net/http" - strings "strings" ) // This is a compile-time assertion to ensure that this generated file and the connect package are @@ -49,7 +48,6 @@ type ImportServiceClient interface { // The URL supplied here should be the base URL for the Connect or gRPC server (for example, // http://api.acme.com or https://acme.com/grpc). func NewImportServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ImportServiceClient { - baseURL = strings.TrimRight(baseURL, "/") return &importServiceClient{} } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 84d01aa6..e60a15d1 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -87,36 +87,37 @@ type PingServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) PingServiceClient { baseURL = strings.TrimRight(baseURL, "/") + serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") return &pingServiceClient{ ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), connect.WithClientOptions(opts...), ), } @@ -176,35 +177,36 @@ type PingServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { From 741312f944d76f960ab301a1cfd6375af00e3ff7 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:50:58 -0500 Subject: [PATCH 03/13] Remove test invariants --- client_ext_test.go | 91 ++++++++++++++-------------------------------- 1 file changed, 27 insertions(+), 64 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index e5f6995a..d04f20b3 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -196,71 +196,34 @@ func TestSpecSchema(t *testing.T) { connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) - testcases := []struct { - name string - opts []connect.ClientOption - }{{ - name: connect.ProtocolConnect, - }, { - name: connect.ProtocolGRPC, - opts: []connect.ClientOption{ - connect.WithGRPC(), - }, - }, { - name: connect.ProtocolGRPCWeb, - opts: []connect.ClientOption{ - connect.WithGRPC(), - }, - }} - for _, testcase := range testcases { - testcase := testcase - t.Run(testcase.name, func(t *testing.T) { - ctx := context.Background() - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - connect.WithClientOptions(testcase.opts...), - connect.WithInterceptors(&assertSchemaInterceptor{t}), - ) - t.Parallel() - t.Run("unary", func(t *testing.T) { - unaryReq := connect.NewRequest[pingv1.PingRequest](nil) - _, err := client.Ping(ctx, unaryReq) - assert.Nil(t, err) - text := strings.Repeat(".", 256) - r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) - assert.Nil(t, err) - assert.Equal(t, r.Msg.Text, text) - }) - t.Run("client_stream", func(t *testing.T) { - clientStream := client.Sum(ctx) - t.Cleanup(func() { - _, closeErr := clientStream.CloseAndReceive() - assert.Nil(t, closeErr) - }) - assert.NotZero(t, clientStream.Spec().Schema) - err := clientStream.Send(&pingv1.SumRequest{}) - assert.Nil(t, err) - }) - t.Run("server_stream", func(t *testing.T) { - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) - t.Cleanup(func() { - assert.Nil(t, serverStream.Close()) - }) - assert.Nil(t, err) - }) - t.Run("bidi_stream", func(t *testing.T) { - bidiStream := client.CumSum(ctx) - t.Cleanup(func() { - assert.Nil(t, bidiStream.CloseRequest()) - assert.Nil(t, bidiStream.CloseResponse()) - }) - assert.NotZero(t, bidiStream.Spec().Schema) - err := bidiStream.Send(&pingv1.CumSumRequest{}) - assert.Nil(t, err) - }) + ctx := context.Background() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithInterceptors(&assertSchemaInterceptor{t}), + ) + t.Run("unary", func(t *testing.T) { + t.Parallel() + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.NotNil(t, unaryReq.Spec().Schema) + assert.Nil(t, err) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) }) - } + assert.NotZero(t, bidiStream.Spec().Schema) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + }) } type notModifiedPingServer struct { From 9f34b83380b22a0a9308c449db7161534c100838 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:53:26 -0500 Subject: [PATCH 04/13] Cleanup schema assert test --- client_ext_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index d04f20b3..16225491 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -281,7 +281,7 @@ type assertSchemaInterceptor struct { func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { if !assert.NotNil(a.tb, req.Spec().Schema) { - return nil, fmt.Errorf("nil spec") + return next(ctx, req) } methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) assert.True(a.tb, ok) @@ -308,7 +308,7 @@ func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClie func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { if !assert.NotNil(a.tb, conn.Spec().Schema) { - return fmt.Errorf("nil spec") + return next(ctx, conn) } methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) assert.True(a.tb, ok) From 59ee0ef0a3abd130af8741b54e8afa2ef7394a21 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 7 Nov 2023 12:40:14 -0500 Subject: [PATCH 05/13] Support dynamic messages with Initializer options An Initializer helps to construct dynamic messages on Receive. This lets clients and servers use dynamic messages. A default initializer for dynamicpb.Message is provided. Other IDLs can provide custom Initializers using the WithInitializer option. --- client.go | 19 ++++- client_ext_test.go | 112 ++++++++++++++++++++++++++ client_stream.go | 20 +++-- client_stream_test.go | 11 ++- connect.go | 7 +- connect_ext_test.go | 1 - handler.go | 38 +++++---- handler_ext_test.go | 176 +++++++++++++++++++++++++++++++++++++++++ handler_stream.go | 17 +++- handler_stream_test.go | 10 ++- option.go | 28 +++++++ protocol.go | 24 ++++++ 12 files changed, 430 insertions(+), 33 deletions(-) diff --git a/client.go b/client.go index 2abcaa5e..41906b7c 100644 --- a/client.go +++ b/client.go @@ -92,7 +92,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien _ = conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](conn) + response, err := receiveUnaryResponse[Res](conn, config) if err != nil { _ = conn.CloseResponse() return nil, err @@ -135,7 +135,10 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo if c.err != nil { return &ClientStreamForClient[Req, Res]{err: c.err} } - return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)} + return &ClientStreamForClient[Req, Res]{ + conn: c.newConn(ctx, StreamTypeClient, nil), + config: c.config, + } } // CallServerStream calls a server streaming procedure. @@ -160,7 +163,10 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if err := conn.CloseRequest(); err != nil { return nil, err } - return &ServerStreamForClient[Res]{conn: conn}, nil + return &ServerStreamForClient[Res]{ + conn: conn, + config: c.config, + }, nil } // CallBidiStream calls a bidirectional streaming procedure. @@ -168,7 +174,10 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli if c.err != nil { return &BidiStreamForClient[Req, Res]{err: c.err} } - return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)} + return &BidiStreamForClient[Req, Res]{ + conn: c.newConn(ctx, StreamTypeBidi, nil), + config: c.config, + } } func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { @@ -190,6 +199,7 @@ type clientConfig struct { Protocol protocol Procedure string Schema any + Initializer InitializerFunc CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool @@ -217,6 +227,7 @@ func newClientConfig(rawURL string, options []ClientOption) (*clientConfig, *Err Procedure: protoPath, CompressionPools: make(map[string]*compressionPool), BufferPool: newBufferPool(), + Initializer: defaultInitializer, } withProtoBinaryCodec().applyToClient(&config) withGzip().applyToClient(&config) diff --git a/client_ext_test.go b/client_ext_test.go index 16225491..17fcfddb 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -28,6 +28,8 @@ import ( "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" ) func TestNewClient_InitFailure(t *testing.T) { @@ -226,6 +228,116 @@ func TestSpecSchema(t *testing.T) { }) } +func TestDynamicClient(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) + server := memhttptest.NewServer(t, mux) + ctx := context.Background() + t.Run("unary", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + server.Client(), + server.URL()+"/connect.ping.v1.PingService/Ping", + connect.WithSchema(methodDesc), + connect.WithIdempotency(connect.IdempotencyNoSideEffects), + ) + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + res, err := client.CallUnary(ctx, connect.NewRequest(msg)) + assert.Nil(t, err) + got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, 42) + }) + t.Run("clientStream", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + server.Client(), + server.URL()+"/connect.ping.v1.PingService/Sum", + connect.WithSchema(methodDesc), + ) + stream := client.CallClientStream(ctx) + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.Send(msg)) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int() + assert.Equal(t, got, 42*2) + }) + t.Run("serverStream", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + server.Client(), + server.URL()+"/connect.ping.v1.PingService/CountUp", + connect.WithSchema(methodDesc), + ) + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(2), + ) + req := connect.NewRequest(msg) + stream, err := client.CallServerStream(ctx, req) + if !assert.Nil(t, err) { + return + } + for i := 1; stream.Receive(); i++ { + out := stream.Msg() + got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, int64(i)) + } + assert.Nil(t, stream.Close()) + }) + t.Run("bidi", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + server.Client(), + server.URL()+"/connect.ping.v1.PingService/CumSum", + connect.WithSchema(methodDesc), + ) + stream := client.CallBidiStream(ctx) + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + assert.Nil(t, stream.Send(msg)) + assert.Nil(t, stream.CloseRequest()) + out, err := stream.Receive() + if assert.Nil(t, err) { + return + } + got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, 42) + }) +} + type notModifiedPingServer struct { pingv1connect.UnimplementedPingServiceHandler diff --git a/client_stream.go b/client_stream.go index 60aff987..47a95f59 100644 --- a/client_stream.go +++ b/client_stream.go @@ -25,7 +25,8 @@ import ( // It's returned from [Client].CallClientStream, but doesn't currently have an // exported constructor function. type ClientStreamForClient[Req, Res any] struct { - conn StreamingClientConn + conn StreamingClientConn + config *clientConfig // Error from client construction. If non-nil, return for all calls. err error } @@ -78,7 +79,7 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err _ = c.conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](c.conn) + response, err := receiveUnaryResponse[Res](c.conn, c.config) if err != nil { _ = c.conn.CloseResponse() return nil, err @@ -97,8 +98,9 @@ func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) { // It's returned from [Client].CallServerStream, but doesn't currently have an // exported constructor function. type ServerStreamForClient[Res any] struct { - conn StreamingClientConn - msg *Res + conn StreamingClientConn + config *clientConfig + msg *Res // Error from client construction. If non-nil, return for all calls. constructErr error // Error from conn.Receive(). @@ -115,6 +117,10 @@ func (s *ServerStreamForClient[Res]) Receive() bool { return false } s.msg = new(Res) + if err := s.config.Initializer(s.conn.Spec(), s.msg); err != nil { + s.receiveErr = err + return false + } s.receiveErr = s.conn.Receive(s.msg) return s.receiveErr == nil } @@ -175,7 +181,8 @@ func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) { // It's returned from [Client].CallBidiStream, but doesn't currently have an // exported constructor function. type BidiStreamForClient[Req, Res any] struct { - conn StreamingClientConn + conn StreamingClientConn + config *clientConfig // Error from client construction. If non-nil, return for all calls. err error } @@ -234,6 +241,9 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { return nil, b.err } var msg Res + if err := b.config.Initializer(b.conn.Spec(), &msg); err != nil { + return nil, err + } if err := b.conn.Receive(&msg); err != nil { return nil, err } diff --git a/client_stream_test.go b/client_stream_test.go index 431599f6..557f597e 100644 --- a/client_stream_test.go +++ b/client_stream_test.go @@ -55,7 +55,12 @@ func TestServerStreamForClient_NoPanics(t *testing.T) { func TestServerStreamForClient(t *testing.T) { t.Parallel() - stream := &ServerStreamForClient[pingv1.PingResponse]{conn: &nopStreamingClientConn{}} + config, cerr := newClientConfig("http://localhost:1234", nil) + assert.Nil(t, cerr) + stream := &ServerStreamForClient[pingv1.PingResponse]{ + conn: &nopStreamingClientConn{}, + config: config, + } // Ensure that each call to Receive allocates a new message. This helps // vtprotobuf, which doesn't automatically zero messages before unmarshaling // (see https://connectrpc.com/connect/issues/345), and it's also @@ -104,3 +109,7 @@ type nopStreamingClientConn struct { func (c *nopStreamingClientConn) Receive(msg any) error { return nil } + +func (c *nopStreamingClientConn) Spec() Spec { + return Spec{} +} diff --git a/connect.go b/connect.go index 622852fb..62b0a62f 100644 --- a/connect.go +++ b/connect.go @@ -358,15 +358,18 @@ type handlerConnCloser interface { // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple // messages. -func receiveUnaryResponse[T any](conn StreamingClientConn) (*Response[T], error) { +func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) (*Response[T], error) { var msg T + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return nil, err + } if err := conn.Receive(&msg); err != nil { return nil, err } // In a well-formed stream, the response message may be followed by a block // of in-stream trailers or HTTP trailers. To ensure that we receive the // trailers, try to read another message from the stream. - if err := conn.Receive(new(T)); err == nil { + if err := conn.Receive(nil); err == nil { return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) } else if err != nil && !errors.Is(err, io.EOF) { return nil, NewError(CodeUnknown, err) diff --git a/connect_ext_test.go b/connect_ext_test.go index af541259..c5d9c5a6 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2360,7 +2360,6 @@ func (p *pluggablePingServer) CumSum( func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() - if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { assert.ErrorIs(tb, err, io.EOF) assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) diff --git a/handler.go b/handler.go index b207ac66..03333a81 100644 --- a/handler.go +++ b/handler.go @@ -64,6 +64,9 @@ func NewUnaryHandler[Req, Res any]( // Given a stream, how should we call the unary function? implementation := func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return err + } if err := conn.Receive(&msg); err != nil { return err } @@ -103,11 +106,14 @@ func NewClientStreamHandler[Req, Res any]( implementation func(context.Context, *ClientStream[Req]) (*Response[Res], error), options ...HandlerOption, ) *Handler { + config := newHandlerConfig(procedure, StreamTypeClient, options) return newStreamHandler( - procedure, - StreamTypeClient, + config, func(ctx context.Context, conn StreamingHandlerConn) error { - stream := &ClientStream[Req]{conn: conn} + stream := &ClientStream[Req]{ + conn: conn, + config: config, + } res, err := implementation(ctx, stream) if err != nil { return err @@ -121,7 +127,6 @@ func NewClientStreamHandler[Req, Res any]( mergeHeaders(conn.ResponseTrailer(), res.trailer) return conn.Send(res.Msg) }, - options..., ) } @@ -131,11 +136,14 @@ func NewServerStreamHandler[Req, Res any]( implementation func(context.Context, *Request[Req], *ServerStream[Res]) error, options ...HandlerOption, ) *Handler { + config := newHandlerConfig(procedure, StreamTypeServer, options) return newStreamHandler( - procedure, - StreamTypeServer, + config, func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return err + } if err := conn.Receive(&msg); err != nil { return err } @@ -151,7 +159,6 @@ func NewServerStreamHandler[Req, Res any]( &ServerStream[Res]{conn: conn}, ) }, - options..., ) } @@ -161,16 +168,18 @@ func NewBidiStreamHandler[Req, Res any]( implementation func(context.Context, *BidiStream[Req, Res]) error, options ...HandlerOption, ) *Handler { + config := newHandlerConfig(procedure, StreamTypeBidi, options) return newStreamHandler( - procedure, - StreamTypeBidi, + config, func(ctx context.Context, conn StreamingHandlerConn) error { return implementation( ctx, - &BidiStream[Req, Res]{conn: conn}, + &BidiStream[Req, Res]{ + conn: conn, + config: config, + }, ) }, - options..., ) } @@ -247,6 +256,7 @@ type handlerConfig struct { Interceptor Interceptor Procedure string Schema any + Initializer InitializerFunc HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool @@ -267,6 +277,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler HandleGRPCWeb: true, BufferPool: newBufferPool(), StreamType: streamType, + Initializer: defaultInitializer, } withProtoBinaryCodec().applyToHandler(&config) withProtoJSONCodecs().applyToHandler(&config) @@ -317,12 +328,9 @@ func (c *handlerConfig) newProtocolHandlers() []protocolHandler { } func newStreamHandler( - procedure string, - streamType StreamType, + config *handlerConfig, implementation StreamingHandlerFunc, - options ...HandlerOption, ) *Handler { - config := newHandlerConfig(procedure, streamType, options) if ic := config.Interceptor; ic != nil { implementation = ic.WrapStreamingHandler(implementation) } diff --git a/handler_ext_test.go b/handler_ext_test.go index 25cde595..4638e4b9 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -19,6 +19,7 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "io" "net/http" "strings" @@ -30,6 +31,9 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" ) func TestHandler_ServeHTTP(t *testing.T) { @@ -250,6 +254,178 @@ func TestHandlerMaliciousPrefix(t *testing.T) { wg.Wait() } +func TestDynamicHandler(t *testing.T) { + t.Parallel() + t.Run("unary", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicPing := func(_ context.Context, req *connect.Request[dynamicpb.Message]) (*connect.Response[dynamicpb.Message], error) { + got := req.Msg.Get(methodDesc.Input().Fields().ByName("number")).Int() + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("number"), + protoreflect.ValueOfInt64(got), + ) + return connect.NewResponse(msg), nil + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/Ping", + connect.NewUnaryHandler( + "/connect.ping.v1.PingService/Ping", + dynamicPing, + connect.WithSchema(methodDesc), + connect.WithIdempotency(connect.IdempotencyNoSideEffects), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + rsp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{ + Number: 42, + })) + if !assert.Nil(t, err) { + return + } + got := rsp.Msg.Number + assert.Equal(t, got, 42) + }) + t.Run("clientStream", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicSum := func(_ context.Context, stream *connect.ClientStream[dynamicpb.Message]) (*connect.Response[dynamicpb.Message], error) { + var sum int64 + for stream.Receive() { + got := stream.Msg().Get( + methodDesc.Input().Fields().ByName("number"), + ).Int() + sum += got + } + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("sum"), + protoreflect.ValueOfInt64(sum), + ) + return connect.NewResponse(msg), nil + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/Sum", + connect.NewClientStreamHandler( + "/connect.ping.v1.PingService/Sum", + dynamicSum, + connect.WithSchema(methodDesc), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + stream := client.Sum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 42})) + rsp, err := stream.CloseAndReceive() + if !assert.Nil(t, err) { + return + } + assert.Equal(t, rsp.Msg.Sum, 42*2) + }) + t.Run("serverStream", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicCountUp := func(_ context.Context, req *connect.Request[dynamicpb.Message], stream *connect.ServerStream[dynamicpb.Message]) error { + number := req.Msg.Get(methodDesc.Input().Fields().ByName("number")).Int() + for i := int64(1); i <= number; i++ { + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("number"), + protoreflect.ValueOfInt64(i), + ) + if err := stream.Send(msg); err != nil { + return err + } + } + return nil + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/CountUp", + connect.NewServerStreamHandler( + "/connect.ping.v1.PingService/CountUp", + dynamicCountUp, + connect.WithSchema(methodDesc), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{ + Number: 2, + })) + if !assert.Nil(t, err) { + return + } + var sum int64 + for stream.Receive() { + sum += stream.Msg().Number + } + assert.Nil(t, stream.Err()) + assert.Equal(t, sum, 3) // 1 + 2 + }) + t.Run("bidi", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicCumSum := func( + _ context.Context, + stream *connect.BidiStream[dynamicpb.Message, dynamicpb.Message], + ) error { + var sum int64 + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + got := msg.Get(methodDesc.Input().Fields().ByName("number")).Int() + sum += got + out := dynamicpb.NewMessage(methodDesc.Output()) + out.Set( + methodDesc.Output().Fields().ByName("sum"), + protoreflect.ValueOfInt64(sum), + ) + if err := stream.Send(out); err != nil { + return err + } + } + } + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/CumSum", + connect.NewBidiStreamHandler( + "/connect.ping.v1.PingService/CumSum", + dynamicCumSum, + connect.WithSchema(methodDesc), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + stream := client.CumSum(context.Background()) + assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 1})) + msg, err := stream.Receive() + if !assert.Nil(t, err) { + return + } + assert.Equal(t, msg.Sum, int64(1)) + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) + }) +} + type successPingServer struct { pingv1connect.UnimplementedPingServiceHandler } diff --git a/handler_stream.go b/handler_stream.go index 05eae981..314d52d2 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -25,9 +25,10 @@ import ( // It's constructed as part of [Handler] invocation, but doesn't currently have // an exported constructor. type ClientStream[Req any] struct { - conn StreamingHandlerConn - msg *Req - err error + conn StreamingHandlerConn + config *handlerConfig + msg *Req + err error } // Spec returns the specification for the RPC. @@ -55,6 +56,10 @@ func (c *ClientStream[Req]) Receive() bool { return false } c.msg = new(Req) + if err := c.config.Initializer(c.Spec(), c.msg); err != nil { + c.err = err + return false + } c.err = c.conn.Receive(c.msg) return c.err == nil } @@ -127,7 +132,8 @@ func (s *ServerStream[Res]) Conn() StreamingHandlerConn { // It's constructed as part of [Handler] invocation, but doesn't currently have // an exported constructor. type BidiStream[Req, Res any] struct { - conn StreamingHandlerConn + conn StreamingHandlerConn + config *handlerConfig } // Spec returns the specification for the RPC. @@ -149,6 +155,9 @@ func (b *BidiStream[Req, Res]) RequestHeader() http.Header { // return an error that wraps [io.EOF]. func (b *BidiStream[Req, Res]) Receive() (*Req, error) { var req Req + if err := b.config.Initializer(b.Spec(), &req); err != nil { + return nil, err + } if err := b.conn.Receive(&req); err != nil { return nil, err } diff --git a/handler_stream_test.go b/handler_stream_test.go index 41b7f019..25b0abf9 100644 --- a/handler_stream_test.go +++ b/handler_stream_test.go @@ -27,7 +27,11 @@ func TestClientStreamIterator(t *testing.T) { // The server's view of a client streaming RPC is an iterator. For safety, // and to match grpc-go's behavior, we should allocate a new message for each // iteration. - stream := &ClientStream[pingv1.PingRequest]{conn: &nopStreamingHandlerConn{}} + config := newHandlerConfig("/connect.ping.v1.PingService/Ping", StreamTypeUnary, nil) + stream := &ClientStream[pingv1.PingRequest]{ + conn: &nopStreamingHandlerConn{}, + config: config, + } assert.True(t, stream.Receive()) first := fmt.Sprintf("%p", stream.Msg()) assert.True(t, stream.Receive()) @@ -42,3 +46,7 @@ type nopStreamingHandlerConn struct { func (nopStreamingHandlerConn) Receive(msg any) error { return nil } + +func (nopStreamingHandlerConn) Spec() Spec { + return Spec{} +} diff --git a/option.go b/option.go index c5d15f05..6572a758 100644 --- a/option.go +++ b/option.go @@ -194,6 +194,22 @@ func WithSchema(schema any) Option { return &schemaOption{Schema: schema} } +// InitializerFunc is a function that initializes a message. It may be used to +// dynamically construct messages. It is called on client and handler Receive to +// construct the message to be unmarshaled into. +type InitializerFunc func(spec Spec, msg any) error + +// WithInitializer provides a function that initializes a message. It may be +// used to dynamically construct messages. +// +// By default, an initializer is provided to support [dynamicpb.Message] +// messages. This initializer sets the descriptor from the Schema field of the +// [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] to use +// dynamicpb.Message. +func WithInitializer(initializer InitializerFunc) Option { + return &initializerOption{Initializer: initializer} +} + // WithCodec registers a serialization method with a client or handler. // Handlers may have multiple codecs registered, and use whichever the client // chooses. Clients may only have a single codec. @@ -350,6 +366,18 @@ func (o *schemaOption) applyToHandler(config *handlerConfig) { config.Schema = o.Schema } +type initializerOption struct { + Initializer InitializerFunc +} + +func (o *initializerOption) applyToClient(config *clientConfig) { + config.Initializer = o.Initializer +} + +func (o *initializerOption) applyToHandler(config *handlerConfig) { + config.Initializer = o.Initializer +} + type clientOptionsOption struct { options []ClientOption } diff --git a/protocol.go b/protocol.go index cc6455cd..112a4929 100644 --- a/protocol.go +++ b/protocol.go @@ -24,6 +24,9 @@ import ( "net/url" "sort" "strings" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) // The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by @@ -394,3 +397,24 @@ func canonicalizeContentTypeSlow(contentType string) string { } return mime.FormatMediaType(base, params) } + +// defaultInitializer is the default initializer for dynamic messages. It +// initializes the message to the type specified in the Spec. +func defaultInitializer(spec Spec, msg any) error { + dynamic, ok := msg.(*dynamicpb.Message) + if !ok { + return nil + } + desc, ok := spec.Schema.(protoreflect.MethodDescriptor) + if !ok { + return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic) + } + // If the message is a client message, initialize it to the output type of + // the RPC. Otherwise, initialize it to the input type. + if spec.IsClient { + *dynamic = *dynamicpb.NewMessage(desc.Output()) + } else { + *dynamic = *dynamicpb.NewMessage(desc.Input()) + } + return nil +} From 93b1292b04ddc1a3138e65ca96f6918433a7b01c Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 8 Nov 2023 11:47:03 -0500 Subject: [PATCH 06/13] Split into Request and Response initializer opts --- client_ext_test.go | 37 ++++++++++++++++++++++++++++++++ handler_ext_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ option.go | 39 +++++++++++++++++++++++----------- protocol.go | 9 ++++---- 4 files changed, 120 insertions(+), 16 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 17fcfddb..e75a93af 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -336,6 +336,43 @@ func TestDynamicClient(t *testing.T) { got := out.Get(methodDesc.Output().Fields().ByName("number")).Int() assert.Equal(t, got, 42) }) + t.Run("option", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + optionCalled := false + client := connect.NewClient[dynamicpb.Message, dynamicpb.Message]( + server.Client(), + server.URL()+"/connect.ping.v1.PingService/Ping", + connect.WithSchema(methodDesc), + connect.WithIdempotency(connect.IdempotencyNoSideEffects), + connect.WithResponseInitializer( + func(spec connect.Spec, msg any) error { + assert.NotNil(t, spec) + assert.NotNil(t, msg) + dynamic, ok := msg.(*dynamicpb.Message) + if !assert.True(t, ok) { + return fmt.Errorf("unexpected message type: %T", msg) + } + *dynamic = *dynamicpb.NewMessage(methodDesc.Output()) + optionCalled = true + return nil + }, + ), + ) + msg := dynamicpb.NewMessage(methodDesc.Input()) + msg.Set( + methodDesc.Input().Fields().ByName("number"), + protoreflect.ValueOfInt64(42), + ) + res, err := client.CallUnary(ctx, connect.NewRequest(msg)) + assert.Nil(t, err) + got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int() + assert.Equal(t, got, 42) + assert.True(t, optionCalled) + }) } type notModifiedPingServer struct { diff --git a/handler_ext_test.go b/handler_ext_test.go index 4638e4b9..bf89415d 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -424,6 +425,56 @@ func TestDynamicHandler(t *testing.T) { assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) }) + t.Run("option", func(t *testing.T) { + t.Parallel() + desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") + assert.Nil(t, err) + methodDesc, ok := desc.(protoreflect.MethodDescriptor) + assert.True(t, ok) + dynamicPing := func(_ context.Context, req *connect.Request[dynamicpb.Message]) (*connect.Response[dynamicpb.Message], error) { + got := req.Msg.Get(methodDesc.Input().Fields().ByName("number")).Int() + msg := dynamicpb.NewMessage(methodDesc.Output()) + msg.Set( + methodDesc.Output().Fields().ByName("number"), + protoreflect.ValueOfInt64(got), + ) + return connect.NewResponse(msg), nil + } + optionCalled := false + mux := http.NewServeMux() + mux.Handle("/connect.ping.v1.PingService/Ping", + connect.NewUnaryHandler( + "/connect.ping.v1.PingService/Ping", + dynamicPing, + connect.WithSchema(methodDesc), + connect.WithIdempotency(connect.IdempotencyNoSideEffects), + connect.WithRequestInitializer( + func(spec connect.Spec, msg any) error { + assert.NotNil(t, spec) + assert.NotNil(t, msg) + dynamic, ok := msg.(*dynamicpb.Message) + if !assert.True(t, ok) { + return fmt.Errorf("unexpected message type: %T", msg) + } + *dynamic = *dynamicpb.NewMessage(methodDesc.Input()) + optionCalled = true + return nil + }, + ), + ), + ) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + rsp, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{ + Number: 42, + })) + if !assert.Nil(t, err) { + return + } + got := rsp.Msg.Number + assert.Equal(t, got, 42) + assert.True(t, optionCalled) + }) } type successPingServer struct { diff --git a/option.go b/option.go index 6572a758..fb2888cd 100644 --- a/option.go +++ b/option.go @@ -195,19 +195,30 @@ func WithSchema(schema any) Option { } // InitializerFunc is a function that initializes a message. It may be used to -// dynamically construct messages. It is called on client and handler Receive to -// construct the message to be unmarshaled into. +// dynamically construct messages. It is called on client and handler receives +// to construct the message to be unmarshaled into. type InitializerFunc func(spec Spec, msg any) error -// WithInitializer provides a function that initializes a message. It may be -// used to dynamically construct messages. +// WithRequestInitializer provides a function that initializes a new message. +// It may be used to dynamically construct request messages. // // By default, an initializer is provided to support [dynamicpb.Message] -// messages. This initializer sets the descriptor from the Schema field of the -// [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] to use -// dynamicpb.Message. -func WithInitializer(initializer InitializerFunc) Option { - return &initializerOption{Initializer: initializer} +// messages. This initializer sets the input descriptor from the Schema field +// of the [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] +// to use dynamicpb.Message. +func WithRequestInitializer(initializer InitializerFunc) HandlerOption { + return &requestInitializerOption{Initializer: initializer} +} + +// WithResponseInitializer provides a function that initializes a new message. +// It may be used to dynamically construct response messages. +// +// By default, an initializer is provided to support [dynamicpb.Message] +// messages. This initializer sets the output descriptor from the Schema field +// of the [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] +// to use dynamicpb.Message. +func WithResponseInitializer(initializer InitializerFunc) ClientOption { + return &responseInitializerOption{Initializer: initializer} } // WithCodec registers a serialization method with a client or handler. @@ -366,15 +377,19 @@ func (o *schemaOption) applyToHandler(config *handlerConfig) { config.Schema = o.Schema } -type initializerOption struct { +type requestInitializerOption struct { Initializer InitializerFunc } -func (o *initializerOption) applyToClient(config *clientConfig) { +func (o *requestInitializerOption) applyToHandler(config *handlerConfig) { config.Initializer = o.Initializer } -func (o *initializerOption) applyToHandler(config *handlerConfig) { +type responseInitializerOption struct { + Initializer InitializerFunc +} + +func (o *responseInitializerOption) applyToClient(config *clientConfig) { config.Initializer = o.Initializer } diff --git a/protocol.go b/protocol.go index 112a4929..a4d6426c 100644 --- a/protocol.go +++ b/protocol.go @@ -398,8 +398,9 @@ func canonicalizeContentTypeSlow(contentType string) string { return mime.FormatMediaType(base, params) } -// defaultInitializer is the default initializer for dynamic messages. It -// initializes the message to the type specified in the Spec. +// defaultInitializer is the default initializer that adds support for +// dynamicpb.Message. If of message type the Schema is cast to access the +// method descriptor and initialized with the message descriptor. func defaultInitializer(spec Spec, msg any) error { dynamic, ok := msg.(*dynamicpb.Message) if !ok { @@ -409,8 +410,8 @@ func defaultInitializer(spec Spec, msg any) error { if !ok { return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic) } - // If the message is a client message, initialize it to the output type of - // the RPC. Otherwise, initialize it to the input type. + // If the message is a client message, initialize it to the output type + // of the method. Otherwise, initialize it to the input type. if spec.IsClient { *dynamic = *dynamicpb.NewMessage(desc.Output()) } else { From eca979320f9610451af13e277dab3406dcc62ff5 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 9 Nov 2023 10:12:50 -0500 Subject: [PATCH 07/13] Remove default initializer --- client.go | 1 - client_ext_test.go | 20 ++++++++++++++++++++ client_stream.go | 14 +++++++++----- connect.go | 6 ++++-- handler.go | 13 ++++++++----- handler_ext_test.go | 20 ++++++++++++++++++++ handler_stream.go | 14 +++++++++----- protocol.go | 25 ------------------------- 8 files changed, 70 insertions(+), 43 deletions(-) diff --git a/client.go b/client.go index 41906b7c..c658e86b 100644 --- a/client.go +++ b/client.go @@ -227,7 +227,6 @@ func newClientConfig(rawURL string, options []ClientOption) (*clientConfig, *Err Procedure: protoPath, CompressionPools: make(map[string]*compressionPool), BufferPool: newBufferPool(), - Initializer: defaultInitializer, } withProtoBinaryCodec().applyToClient(&config) withGzip().applyToClient(&config) diff --git a/client_ext_test.go b/client_ext_test.go index e50a6556..4db88fe8 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -234,6 +234,22 @@ func TestDynamicClient(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) server := memhttptest.NewServer(t, mux) ctx := context.Background() + initializer := func(spec connect.Spec, msg any) error { + dynamic, ok := msg.(*dynamicpb.Message) + if !ok { + return nil + } + desc, ok := spec.Schema.(protoreflect.MethodDescriptor) + if !ok { + return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic) + } + if spec.IsClient { + *dynamic = *dynamicpb.NewMessage(desc.Output()) + } else { + *dynamic = *dynamicpb.NewMessage(desc.Input()) + } + return nil + } t.Run("unary", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") @@ -245,6 +261,7 @@ func TestDynamicClient(t *testing.T) { server.URL()+"/connect.ping.v1.PingService/Ping", connect.WithSchema(methodDesc), connect.WithIdempotency(connect.IdempotencyNoSideEffects), + connect.WithResponseInitializer(initializer), ) msg := dynamicpb.NewMessage(methodDesc.Input()) msg.Set( @@ -266,6 +283,7 @@ func TestDynamicClient(t *testing.T) { server.Client(), server.URL()+"/connect.ping.v1.PingService/Sum", connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), ) stream := client.CallClientStream(ctx) msg := dynamicpb.NewMessage(methodDesc.Input()) @@ -292,6 +310,7 @@ func TestDynamicClient(t *testing.T) { server.Client(), server.URL()+"/connect.ping.v1.PingService/CountUp", connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), ) msg := dynamicpb.NewMessage(methodDesc.Input()) msg.Set( @@ -320,6 +339,7 @@ func TestDynamicClient(t *testing.T) { server.Client(), server.URL()+"/connect.ping.v1.PingService/CumSum", connect.WithSchema(methodDesc), + connect.WithResponseInitializer(initializer), ) stream := client.CallBidiStream(ctx) msg := dynamicpb.NewMessage(methodDesc.Input()) diff --git a/client_stream.go b/client_stream.go index 47a95f59..9536215f 100644 --- a/client_stream.go +++ b/client_stream.go @@ -117,9 +117,11 @@ func (s *ServerStreamForClient[Res]) Receive() bool { return false } s.msg = new(Res) - if err := s.config.Initializer(s.conn.Spec(), s.msg); err != nil { - s.receiveErr = err - return false + if s.config.Initializer != nil { + if err := s.config.Initializer(s.conn.Spec(), s.msg); err != nil { + s.receiveErr = err + return false + } } s.receiveErr = s.conn.Receive(s.msg) return s.receiveErr == nil @@ -241,8 +243,10 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { return nil, b.err } var msg Res - if err := b.config.Initializer(b.conn.Spec(), &msg); err != nil { - return nil, err + if b.config.Initializer != nil { + if err := b.config.Initializer(b.conn.Spec(), &msg); err != nil { + return nil, err + } } if err := b.conn.Receive(&msg); err != nil { return nil, err diff --git a/connect.go b/connect.go index 62b0a62f..b1c48bba 100644 --- a/connect.go +++ b/connect.go @@ -360,8 +360,10 @@ type handlerConnCloser interface { // messages. func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) (*Response[T], error) { var msg T - if err := config.Initializer(conn.Spec(), &msg); err != nil { - return nil, err + if config.Initializer != nil { + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return nil, err + } } if err := conn.Receive(&msg); err != nil { return nil, err diff --git a/handler.go b/handler.go index 03333a81..91f0bcf9 100644 --- a/handler.go +++ b/handler.go @@ -64,8 +64,10 @@ func NewUnaryHandler[Req, Res any]( // Given a stream, how should we call the unary function? implementation := func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req - if err := config.Initializer(conn.Spec(), &msg); err != nil { - return err + if config.Initializer != nil { + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return err + } } if err := conn.Receive(&msg); err != nil { return err @@ -141,8 +143,10 @@ func NewServerStreamHandler[Req, Res any]( config, func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req - if err := config.Initializer(conn.Spec(), &msg); err != nil { - return err + if config.Initializer != nil { + if err := config.Initializer(conn.Spec(), &msg); err != nil { + return err + } } if err := conn.Receive(&msg); err != nil { return err @@ -277,7 +281,6 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler HandleGRPCWeb: true, BufferPool: newBufferPool(), StreamType: streamType, - Initializer: defaultInitializer, } withProtoBinaryCodec().applyToHandler(&config) withProtoJSONCodecs().applyToHandler(&config) diff --git a/handler_ext_test.go b/handler_ext_test.go index bf89415d..7cab9727 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -257,6 +257,22 @@ func TestHandlerMaliciousPrefix(t *testing.T) { func TestDynamicHandler(t *testing.T) { t.Parallel() + initializer := func(spec connect.Spec, msg any) error { + dynamic, ok := msg.(*dynamicpb.Message) + if !ok { + return nil + } + desc, ok := spec.Schema.(protoreflect.MethodDescriptor) + if !ok { + return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic) + } + if spec.IsClient { + *dynamic = *dynamicpb.NewMessage(desc.Output()) + } else { + *dynamic = *dynamicpb.NewMessage(desc.Input()) + } + return nil + } t.Run("unary", func(t *testing.T) { t.Parallel() desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping") @@ -279,6 +295,7 @@ func TestDynamicHandler(t *testing.T) { dynamicPing, connect.WithSchema(methodDesc), connect.WithIdempotency(connect.IdempotencyNoSideEffects), + connect.WithRequestInitializer(initializer), ), ) server := memhttptest.NewServer(t, mux) @@ -319,6 +336,7 @@ func TestDynamicHandler(t *testing.T) { "/connect.ping.v1.PingService/Sum", dynamicSum, connect.WithSchema(methodDesc), + connect.WithRequestInitializer(initializer), ), ) server := memhttptest.NewServer(t, mux) @@ -358,6 +376,7 @@ func TestDynamicHandler(t *testing.T) { "/connect.ping.v1.PingService/CountUp", dynamicCountUp, connect.WithSchema(methodDesc), + connect.WithRequestInitializer(initializer), ), ) server := memhttptest.NewServer(t, mux) @@ -411,6 +430,7 @@ func TestDynamicHandler(t *testing.T) { "/connect.ping.v1.PingService/CumSum", dynamicCumSum, connect.WithSchema(methodDesc), + connect.WithRequestInitializer(initializer), ), ) server := memhttptest.NewServer(t, mux) diff --git a/handler_stream.go b/handler_stream.go index 314d52d2..0755baa8 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -56,9 +56,11 @@ func (c *ClientStream[Req]) Receive() bool { return false } c.msg = new(Req) - if err := c.config.Initializer(c.Spec(), c.msg); err != nil { - c.err = err - return false + if c.config.Initializer != nil { + if err := c.config.Initializer(c.Spec(), c.msg); err != nil { + c.err = err + return false + } } c.err = c.conn.Receive(c.msg) return c.err == nil @@ -155,8 +157,10 @@ func (b *BidiStream[Req, Res]) RequestHeader() http.Header { // return an error that wraps [io.EOF]. func (b *BidiStream[Req, Res]) Receive() (*Req, error) { var req Req - if err := b.config.Initializer(b.Spec(), &req); err != nil { - return nil, err + if b.config.Initializer != nil { + if err := b.config.Initializer(b.Spec(), &req); err != nil { + return nil, err + } } if err := b.conn.Receive(&req); err != nil { return nil, err diff --git a/protocol.go b/protocol.go index a4d6426c..cc6455cd 100644 --- a/protocol.go +++ b/protocol.go @@ -24,9 +24,6 @@ import ( "net/url" "sort" "strings" - - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/dynamicpb" ) // The names of the Connect, gRPC, and gRPC-Web protocols (as exposed by @@ -397,25 +394,3 @@ func canonicalizeContentTypeSlow(contentType string) string { } return mime.FormatMediaType(base, params) } - -// defaultInitializer is the default initializer that adds support for -// dynamicpb.Message. If of message type the Schema is cast to access the -// method descriptor and initialized with the message descriptor. -func defaultInitializer(spec Spec, msg any) error { - dynamic, ok := msg.(*dynamicpb.Message) - if !ok { - return nil - } - desc, ok := spec.Schema.(protoreflect.MethodDescriptor) - if !ok { - return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic) - } - // If the message is a client message, initialize it to the output type - // of the method. Otherwise, initialize it to the input type. - if spec.IsClient { - *dynamic = *dynamicpb.NewMessage(desc.Output()) - } else { - *dynamic = *dynamicpb.NewMessage(desc.Input()) - } - return nil -} From 2fef10869ebc0545cfac1328e5e789d1891357c2 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 9 Nov 2023 10:30:29 -0500 Subject: [PATCH 08/13] Cleanup docs --- option.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/option.go b/option.go index fb2888cd..d96284bf 100644 --- a/option.go +++ b/option.go @@ -197,26 +197,20 @@ func WithSchema(schema any) Option { // InitializerFunc is a function that initializes a message. It may be used to // dynamically construct messages. It is called on client and handler receives // to construct the message to be unmarshaled into. -type InitializerFunc func(spec Spec, msg any) error +// +// The message will be a non nil pointer to the type created by the client or +// handler. Use the Schema field of the [Spec] to determine the type of the +// message. +type InitializerFunc func(spec Spec, message any) error // WithRequestInitializer provides a function that initializes a new message. // It may be used to dynamically construct request messages. -// -// By default, an initializer is provided to support [dynamicpb.Message] -// messages. This initializer sets the input descriptor from the Schema field -// of the [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] -// to use dynamicpb.Message. func WithRequestInitializer(initializer InitializerFunc) HandlerOption { return &requestInitializerOption{Initializer: initializer} } // WithResponseInitializer provides a function that initializes a new message. // It may be used to dynamically construct response messages. -// -// By default, an initializer is provided to support [dynamicpb.Message] -// messages. This initializer sets the output descriptor from the Schema field -// of the [Spec]. The Schema must be of type [protoreflect.MethodDescriptor] -// to use dynamicpb.Message. func WithResponseInitializer(initializer InitializerFunc) ClientOption { return &responseInitializerOption{Initializer: initializer} } From 98ff1f91f57021e9e1b8ddc401c6191391db0610 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 9 Nov 2023 10:37:51 -0500 Subject: [PATCH 09/13] Fix unary handler 2nd msg --- connect.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/connect.go b/connect.go index b1c48bba..fd0dbb43 100644 --- a/connect.go +++ b/connect.go @@ -371,7 +371,14 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) // In a well-formed stream, the response message may be followed by a block // of in-stream trailers or HTTP trailers. To ensure that we receive the // trailers, try to read another message from the stream. - if err := conn.Receive(nil); err == nil { + // TODO: optimise unary calls to avoid this extra receive. + var msg2 T + if config.Initializer != nil { + if err := config.Initializer(conn.Spec(), &msg2); err != nil { + return nil, err + } + } + if err := conn.Receive(&msg2); err == nil { return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) } else if err != nil && !errors.Is(err, io.EOF) { return nil, NewError(CodeUnknown, err) From 01ceb88f6e8d72451a6fb93b1a02ec7fec736bfa Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 20 Nov 2023 09:51:07 -0500 Subject: [PATCH 10/13] Unexport InitializerFunc --- client.go | 2 +- handler.go | 2 +- option.go | 27 ++++++++++++--------------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index c658e86b..78206415 100644 --- a/client.go +++ b/client.go @@ -199,7 +199,7 @@ type clientConfig struct { Protocol protocol Procedure string Schema any - Initializer InitializerFunc + Initializer func(Spec, any) error CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool diff --git a/handler.go b/handler.go index 91f0bcf9..e9ab2b75 100644 --- a/handler.go +++ b/handler.go @@ -260,7 +260,7 @@ type handlerConfig struct { Interceptor Interceptor Procedure string Schema any - Initializer InitializerFunc + Initializer func(Spec, any) error HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool diff --git a/option.go b/option.go index d96284bf..e756416a 100644 --- a/option.go +++ b/option.go @@ -194,24 +194,21 @@ func WithSchema(schema any) Option { return &schemaOption{Schema: schema} } -// InitializerFunc is a function that initializes a message. It may be used to -// dynamically construct messages. It is called on client and handler receives -// to construct the message to be unmarshaled into. -// -// The message will be a non nil pointer to the type created by the client or -// handler. Use the Schema field of the [Spec] to determine the type of the -// message. -type InitializerFunc func(spec Spec, message any) error - // WithRequestInitializer provides a function that initializes a new message. -// It may be used to dynamically construct request messages. -func WithRequestInitializer(initializer InitializerFunc) HandlerOption { +// It may be used to dynamically construct request messages. It is called on +// server receives to construct the message to be unmarshaled into. The message +// will be a non nil pointer to the type created by the handler. Use the Schema +// field of the [Spec] to determine the type of the message. +func WithRequestInitializer(initializer func(spec Spec, message any) error) HandlerOption { return &requestInitializerOption{Initializer: initializer} } // WithResponseInitializer provides a function that initializes a new message. -// It may be used to dynamically construct response messages. -func WithResponseInitializer(initializer InitializerFunc) ClientOption { +// It may be used to dynamically construct response messages. It is called on +// client receives to construct the message to be unmarshaled into. The message +// will be a non nil pointer to the type created by the client. Use the Schema +// field of the [Spec] to determine the type of the message. +func WithResponseInitializer(initializer func(spec Spec, message any) error) ClientOption { return &responseInitializerOption{Initializer: initializer} } @@ -372,7 +369,7 @@ func (o *schemaOption) applyToHandler(config *handlerConfig) { } type requestInitializerOption struct { - Initializer InitializerFunc + Initializer func(spec Spec, message any) error } func (o *requestInitializerOption) applyToHandler(config *handlerConfig) { @@ -380,7 +377,7 @@ func (o *requestInitializerOption) applyToHandler(config *handlerConfig) { } type responseInitializerOption struct { - Initializer InitializerFunc + Initializer func(spec Spec, message any) error } func (o *responseInitializerOption) applyToClient(config *clientConfig) { From cbafb36de27560c2cb7d494df957352e2712439d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 20 Nov 2023 10:04:21 -0500 Subject: [PATCH 11/13] Pass initializer as a func --- client.go | 14 +++++++------- client_stream.go | 24 ++++++++++++------------ client_stream_test.go | 5 +---- connect.go | 10 +++++----- handler.go | 8 ++++---- handler_stream.go | 20 ++++++++++---------- handler_stream_test.go | 4 +--- 7 files changed, 40 insertions(+), 45 deletions(-) diff --git a/client.go b/client.go index 78206415..87a2da32 100644 --- a/client.go +++ b/client.go @@ -92,7 +92,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien _ = conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](conn, config) + response, err := receiveUnaryResponse[Res](conn, config.Initializer) if err != nil { _ = conn.CloseResponse() return nil, err @@ -136,8 +136,8 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo return &ClientStreamForClient[Req, Res]{err: c.err} } return &ClientStreamForClient[Req, Res]{ - conn: c.newConn(ctx, StreamTypeClient, nil), - config: c.config, + conn: c.newConn(ctx, StreamTypeClient, nil), + initializer: c.config.Initializer, } } @@ -164,8 +164,8 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques return nil, err } return &ServerStreamForClient[Res]{ - conn: conn, - config: c.config, + conn: conn, + initializer: c.config.Initializer, }, nil } @@ -175,8 +175,8 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli return &BidiStreamForClient[Req, Res]{err: c.err} } return &BidiStreamForClient[Req, Res]{ - conn: c.newConn(ctx, StreamTypeBidi, nil), - config: c.config, + conn: c.newConn(ctx, StreamTypeBidi, nil), + initializer: c.config.Initializer, } } diff --git a/client_stream.go b/client_stream.go index 9536215f..2eec9bd7 100644 --- a/client_stream.go +++ b/client_stream.go @@ -25,8 +25,8 @@ import ( // It's returned from [Client].CallClientStream, but doesn't currently have an // exported constructor function. type ClientStreamForClient[Req, Res any] struct { - conn StreamingClientConn - config *clientConfig + conn StreamingClientConn + initializer func(Spec, any) error // Error from client construction. If non-nil, return for all calls. err error } @@ -79,7 +79,7 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err _ = c.conn.CloseResponse() return nil, err } - response, err := receiveUnaryResponse[Res](c.conn, c.config) + response, err := receiveUnaryResponse[Res](c.conn, c.initializer) if err != nil { _ = c.conn.CloseResponse() return nil, err @@ -98,9 +98,9 @@ func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) { // It's returned from [Client].CallServerStream, but doesn't currently have an // exported constructor function. type ServerStreamForClient[Res any] struct { - conn StreamingClientConn - config *clientConfig - msg *Res + conn StreamingClientConn + initializer func(Spec, any) error + msg *Res // Error from client construction. If non-nil, return for all calls. constructErr error // Error from conn.Receive(). @@ -117,8 +117,8 @@ func (s *ServerStreamForClient[Res]) Receive() bool { return false } s.msg = new(Res) - if s.config.Initializer != nil { - if err := s.config.Initializer(s.conn.Spec(), s.msg); err != nil { + if s.initializer != nil { + if err := s.initializer(s.conn.Spec(), s.msg); err != nil { s.receiveErr = err return false } @@ -183,8 +183,8 @@ func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) { // It's returned from [Client].CallBidiStream, but doesn't currently have an // exported constructor function. type BidiStreamForClient[Req, Res any] struct { - conn StreamingClientConn - config *clientConfig + conn StreamingClientConn + initializer func(Spec, any) error // Error from client construction. If non-nil, return for all calls. err error } @@ -243,8 +243,8 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { return nil, b.err } var msg Res - if b.config.Initializer != nil { - if err := b.config.Initializer(b.conn.Spec(), &msg); err != nil { + if b.initializer != nil { + if err := b.initializer(b.conn.Spec(), &msg); err != nil { return nil, err } } diff --git a/client_stream_test.go b/client_stream_test.go index 557f597e..e2b91598 100644 --- a/client_stream_test.go +++ b/client_stream_test.go @@ -55,11 +55,8 @@ func TestServerStreamForClient_NoPanics(t *testing.T) { func TestServerStreamForClient(t *testing.T) { t.Parallel() - config, cerr := newClientConfig("http://localhost:1234", nil) - assert.Nil(t, cerr) stream := &ServerStreamForClient[pingv1.PingResponse]{ - conn: &nopStreamingClientConn{}, - config: config, + conn: &nopStreamingClientConn{}, } // Ensure that each call to Receive allocates a new message. This helps // vtprotobuf, which doesn't automatically zero messages before unmarshaling diff --git a/connect.go b/connect.go index fd0dbb43..5479e78b 100644 --- a/connect.go +++ b/connect.go @@ -358,10 +358,10 @@ type handlerConnCloser interface { // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple // messages. -func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) (*Response[T], error) { +func receiveUnaryResponse[T any](conn StreamingClientConn, initializer func(Spec, any) error) (*Response[T], error) { var msg T - if config.Initializer != nil { - if err := config.Initializer(conn.Spec(), &msg); err != nil { + if initializer != nil { + if err := initializer(conn.Spec(), &msg); err != nil { return nil, err } } @@ -373,8 +373,8 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) // trailers, try to read another message from the stream. // TODO: optimise unary calls to avoid this extra receive. var msg2 T - if config.Initializer != nil { - if err := config.Initializer(conn.Spec(), &msg2); err != nil { + if initializer != nil { + if err := initializer(conn.Spec(), &msg2); err != nil { return nil, err } } diff --git a/handler.go b/handler.go index e9ab2b75..fbb2b69e 100644 --- a/handler.go +++ b/handler.go @@ -113,8 +113,8 @@ func NewClientStreamHandler[Req, Res any]( config, func(ctx context.Context, conn StreamingHandlerConn) error { stream := &ClientStream[Req]{ - conn: conn, - config: config, + conn: conn, + initializer: config.Initializer, } res, err := implementation(ctx, stream) if err != nil { @@ -179,8 +179,8 @@ func NewBidiStreamHandler[Req, Res any]( return implementation( ctx, &BidiStream[Req, Res]{ - conn: conn, - config: config, + conn: conn, + initializer: config.Initializer, }, ) }, diff --git a/handler_stream.go b/handler_stream.go index 0755baa8..9bbb965a 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -25,10 +25,10 @@ import ( // It's constructed as part of [Handler] invocation, but doesn't currently have // an exported constructor. type ClientStream[Req any] struct { - conn StreamingHandlerConn - config *handlerConfig - msg *Req - err error + conn StreamingHandlerConn + initializer func(Spec, any) error + msg *Req + err error } // Spec returns the specification for the RPC. @@ -56,8 +56,8 @@ func (c *ClientStream[Req]) Receive() bool { return false } c.msg = new(Req) - if c.config.Initializer != nil { - if err := c.config.Initializer(c.Spec(), c.msg); err != nil { + if c.initializer != nil { + if err := c.initializer(c.Spec(), c.msg); err != nil { c.err = err return false } @@ -134,8 +134,8 @@ func (s *ServerStream[Res]) Conn() StreamingHandlerConn { // It's constructed as part of [Handler] invocation, but doesn't currently have // an exported constructor. type BidiStream[Req, Res any] struct { - conn StreamingHandlerConn - config *handlerConfig + conn StreamingHandlerConn + initializer func(Spec, any) error } // Spec returns the specification for the RPC. @@ -157,8 +157,8 @@ func (b *BidiStream[Req, Res]) RequestHeader() http.Header { // return an error that wraps [io.EOF]. func (b *BidiStream[Req, Res]) Receive() (*Req, error) { var req Req - if b.config.Initializer != nil { - if err := b.config.Initializer(b.Spec(), &req); err != nil { + if b.initializer != nil { + if err := b.initializer(b.Spec(), &req); err != nil { return nil, err } } diff --git a/handler_stream_test.go b/handler_stream_test.go index 25b0abf9..6aa0587c 100644 --- a/handler_stream_test.go +++ b/handler_stream_test.go @@ -27,10 +27,8 @@ func TestClientStreamIterator(t *testing.T) { // The server's view of a client streaming RPC is an iterator. For safety, // and to match grpc-go's behavior, we should allocate a new message for each // iteration. - config := newHandlerConfig("/connect.ping.v1.PingService/Ping", StreamTypeUnary, nil) stream := &ClientStream[pingv1.PingRequest]{ - conn: &nopStreamingHandlerConn{}, - config: config, + conn: &nopStreamingHandlerConn{}, } assert.True(t, stream.Receive()) first := fmt.Sprintf("%p", stream.Msg()) From 83f06cb05e7bfe92b694593abb5e8429eed29180 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 20 Nov 2023 18:04:29 -0500 Subject: [PATCH 12/13] Unify initializer options --- option.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/option.go b/option.go index e756416a..e54ce1a9 100644 --- a/option.go +++ b/option.go @@ -200,7 +200,7 @@ func WithSchema(schema any) Option { // will be a non nil pointer to the type created by the handler. Use the Schema // field of the [Spec] to determine the type of the message. func WithRequestInitializer(initializer func(spec Spec, message any) error) HandlerOption { - return &requestInitializerOption{Initializer: initializer} + return &initializerOption{Initializer: initializer} } // WithResponseInitializer provides a function that initializes a new message. @@ -209,7 +209,7 @@ func WithRequestInitializer(initializer func(spec Spec, message any) error) Hand // will be a non nil pointer to the type created by the client. Use the Schema // field of the [Spec] to determine the type of the message. func WithResponseInitializer(initializer func(spec Spec, message any) error) ClientOption { - return &responseInitializerOption{Initializer: initializer} + return &initializerOption{Initializer: initializer} } // WithCodec registers a serialization method with a client or handler. @@ -368,19 +368,15 @@ func (o *schemaOption) applyToHandler(config *handlerConfig) { config.Schema = o.Schema } -type requestInitializerOption struct { +type initializerOption struct { Initializer func(spec Spec, message any) error } -func (o *requestInitializerOption) applyToHandler(config *handlerConfig) { +func (o *initializerOption) applyToHandler(config *handlerConfig) { config.Initializer = o.Initializer } -type responseInitializerOption struct { - Initializer func(spec Spec, message any) error -} - -func (o *responseInitializerOption) applyToClient(config *clientConfig) { +func (o *initializerOption) applyToClient(config *clientConfig) { config.Initializer = o.Initializer } From a6f83172a0720e024a4094d861882617128b1c3d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 20 Nov 2023 18:25:13 -0500 Subject: [PATCH 13/13] Add maybeInitializer to avoid nil checks --- client.go | 2 +- client_stream.go | 20 ++++++++------------ connect.go | 14 +++++--------- handler.go | 14 +++++--------- handler_stream.go | 18 +++++++----------- option.go | 15 +++++++++++++-- 6 files changed, 39 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 87a2da32..84b84c0a 100644 --- a/client.go +++ b/client.go @@ -199,7 +199,7 @@ type clientConfig struct { Protocol protocol Procedure string Schema any - Initializer func(Spec, any) error + Initializer maybeInitializer CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool diff --git a/client_stream.go b/client_stream.go index 2eec9bd7..d3a8948b 100644 --- a/client_stream.go +++ b/client_stream.go @@ -26,7 +26,7 @@ import ( // exported constructor function. type ClientStreamForClient[Req, Res any] struct { conn StreamingClientConn - initializer func(Spec, any) error + initializer maybeInitializer // Error from client construction. If non-nil, return for all calls. err error } @@ -99,7 +99,7 @@ func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) { // exported constructor function. type ServerStreamForClient[Res any] struct { conn StreamingClientConn - initializer func(Spec, any) error + initializer maybeInitializer msg *Res // Error from client construction. If non-nil, return for all calls. constructErr error @@ -117,11 +117,9 @@ func (s *ServerStreamForClient[Res]) Receive() bool { return false } s.msg = new(Res) - if s.initializer != nil { - if err := s.initializer(s.conn.Spec(), s.msg); err != nil { - s.receiveErr = err - return false - } + if err := s.initializer.maybe(s.conn.Spec(), s.msg); err != nil { + s.receiveErr = err + return false } s.receiveErr = s.conn.Receive(s.msg) return s.receiveErr == nil @@ -184,7 +182,7 @@ func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) { // exported constructor function. type BidiStreamForClient[Req, Res any] struct { conn StreamingClientConn - initializer func(Spec, any) error + initializer maybeInitializer // Error from client construction. If non-nil, return for all calls. err error } @@ -243,10 +241,8 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) { return nil, b.err } var msg Res - if b.initializer != nil { - if err := b.initializer(b.conn.Spec(), &msg); err != nil { - return nil, err - } + if err := b.initializer.maybe(b.conn.Spec(), &msg); err != nil { + return nil, err } if err := b.conn.Receive(&msg); err != nil { return nil, err diff --git a/connect.go b/connect.go index 5479e78b..a97bd63d 100644 --- a/connect.go +++ b/connect.go @@ -358,12 +358,10 @@ type handlerConnCloser interface { // envelopes the message and attaches headers and trailers. It attempts to // consume the response stream and isn't appropriate when receiving multiple // messages. -func receiveUnaryResponse[T any](conn StreamingClientConn, initializer func(Spec, any) error) (*Response[T], error) { +func receiveUnaryResponse[T any](conn StreamingClientConn, initializer maybeInitializer) (*Response[T], error) { var msg T - if initializer != nil { - if err := initializer(conn.Spec(), &msg); err != nil { - return nil, err - } + if err := initializer.maybe(conn.Spec(), &msg); err != nil { + return nil, err } if err := conn.Receive(&msg); err != nil { return nil, err @@ -373,10 +371,8 @@ func receiveUnaryResponse[T any](conn StreamingClientConn, initializer func(Spec // trailers, try to read another message from the stream. // TODO: optimise unary calls to avoid this extra receive. var msg2 T - if initializer != nil { - if err := initializer(conn.Spec(), &msg2); err != nil { - return nil, err - } + if err := initializer.maybe(conn.Spec(), &msg2); err != nil { + return nil, err } if err := conn.Receive(&msg2); err == nil { return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) diff --git a/handler.go b/handler.go index fbb2b69e..05906fe2 100644 --- a/handler.go +++ b/handler.go @@ -64,10 +64,8 @@ func NewUnaryHandler[Req, Res any]( // Given a stream, how should we call the unary function? implementation := func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req - if config.Initializer != nil { - if err := config.Initializer(conn.Spec(), &msg); err != nil { - return err - } + if err := config.Initializer.maybe(conn.Spec(), &msg); err != nil { + return err } if err := conn.Receive(&msg); err != nil { return err @@ -143,10 +141,8 @@ func NewServerStreamHandler[Req, Res any]( config, func(ctx context.Context, conn StreamingHandlerConn) error { var msg Req - if config.Initializer != nil { - if err := config.Initializer(conn.Spec(), &msg); err != nil { - return err - } + if err := config.Initializer.maybe(conn.Spec(), &msg); err != nil { + return err } if err := conn.Receive(&msg); err != nil { return err @@ -260,7 +256,7 @@ type handlerConfig struct { Interceptor Interceptor Procedure string Schema any - Initializer func(Spec, any) error + Initializer maybeInitializer HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool diff --git a/handler_stream.go b/handler_stream.go index 9bbb965a..b68f704f 100644 --- a/handler_stream.go +++ b/handler_stream.go @@ -26,7 +26,7 @@ import ( // an exported constructor. type ClientStream[Req any] struct { conn StreamingHandlerConn - initializer func(Spec, any) error + initializer maybeInitializer msg *Req err error } @@ -56,11 +56,9 @@ func (c *ClientStream[Req]) Receive() bool { return false } c.msg = new(Req) - if c.initializer != nil { - if err := c.initializer(c.Spec(), c.msg); err != nil { - c.err = err - return false - } + if err := c.initializer.maybe(c.Spec(), c.msg); err != nil { + c.err = err + return false } c.err = c.conn.Receive(c.msg) return c.err == nil @@ -135,7 +133,7 @@ func (s *ServerStream[Res]) Conn() StreamingHandlerConn { // an exported constructor. type BidiStream[Req, Res any] struct { conn StreamingHandlerConn - initializer func(Spec, any) error + initializer maybeInitializer } // Spec returns the specification for the RPC. @@ -157,10 +155,8 @@ func (b *BidiStream[Req, Res]) RequestHeader() http.Header { // return an error that wraps [io.EOF]. func (b *BidiStream[Req, Res]) Receive() (*Req, error) { var req Req - if b.initializer != nil { - if err := b.initializer(b.Spec(), &req); err != nil { - return nil, err - } + if err := b.initializer.maybe(b.Spec(), &req); err != nil { + return nil, err } if err := b.conn.Receive(&req); err != nil { return nil, err diff --git a/option.go b/option.go index e54ce1a9..35256cf5 100644 --- a/option.go +++ b/option.go @@ -373,11 +373,22 @@ type initializerOption struct { } func (o *initializerOption) applyToHandler(config *handlerConfig) { - config.Initializer = o.Initializer + config.Initializer = maybeInitializer{initializer: o.Initializer} } func (o *initializerOption) applyToClient(config *clientConfig) { - config.Initializer = o.Initializer + config.Initializer = maybeInitializer{initializer: o.Initializer} +} + +type maybeInitializer struct { + initializer func(spec Spec, message any) error +} + +func (o maybeInitializer) maybe(spec Spec, message any) error { + if o.initializer != nil { + return o.initializer(spec, message) + } + return nil } type clientOptionsOption struct {