Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dynamic message types #640

Merged
merged 15 commits into from
Nov 27, 2023
18 changes: 14 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
_ = conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](conn)
response, err := receiveUnaryResponse[Res](conn, config.Initializer)
if err != nil {
_ = conn.CloseResponse()
return nil, err
Expand Down Expand Up @@ -135,7 +135,10 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)}
return &ClientStreamForClient[Req, Res]{
conn: c.newConn(ctx, StreamTypeClient, nil),
initializer: c.config.Initializer,
}
}

// CallServerStream calls a server streaming procedure.
Expand All @@ -160,15 +163,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
if err := conn.CloseRequest(); err != nil {
return nil, err
}
return &ServerStreamForClient[Res]{conn: conn}, nil
return &ServerStreamForClient[Res]{
conn: conn,
initializer: c.config.Initializer,
}, nil
}

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)}
return &BidiStreamForClient[Req, Res]{
conn: c.newConn(ctx, StreamTypeBidi, nil),
initializer: c.config.Initializer,
}
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
Expand All @@ -190,6 +199,7 @@ type clientConfig struct {
Protocol protocol
Procedure string
Schema any
Initializer func(Spec, any) error
CompressMinBytes int
Interceptor Interceptor
CompressionPools map[string]*compressionPool
Expand Down
169 changes: 169 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect"
"connectrpc.com/connect/internal/memhttp/memhttptest"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
)

func TestNewClient_InitFailure(t *testing.T) {
Expand Down Expand Up @@ -226,6 +228,173 @@ func TestSpecSchema(t *testing.T) {
})
}

func TestDynamicClient(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := memhttptest.NewServer(t, mux)
ctx := context.Background()
initializer := func(spec connect.Spec, msg any) error {
dynamic, ok := msg.(*dynamicpb.Message)
if !ok {
return nil
}
desc, ok := spec.Schema.(protoreflect.MethodDescriptor)
if !ok {
return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic)
}
if spec.IsClient {
*dynamic = *dynamicpb.NewMessage(desc.Output())
} else {
*dynamic = *dynamicpb.NewMessage(desc.Input())
}
return nil
}
t.Run("unary", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Ping",
connect.WithSchema(methodDesc),
connect.WithIdempotency(connect.IdempotencyNoSideEffects),
connect.WithResponseInitializer(initializer),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
res, err := client.CallUnary(ctx, connect.NewRequest(msg))
assert.Nil(t, err)
got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("clientStream", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Sum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream := client.CallClientStream(ctx)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.Send(msg))
rsp, err := stream.CloseAndReceive()
if !assert.Nil(t, err) {
return
}
got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int()
assert.Equal(t, got, 42*2)
})
t.Run("serverStream", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/CountUp",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(2),
)
req := connect.NewRequest(msg)
stream, err := client.CallServerStream(ctx, req)
if !assert.Nil(t, err) {
return
}
for i := 1; stream.Receive(); i++ {
out := stream.Msg()
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, int64(i))
}
assert.Nil(t, stream.Close())
})
t.Run("bidi", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/CumSum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream := client.CallBidiStream(ctx)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.CloseRequest())
out, err := stream.Receive()
if assert.Nil(t, err) {
return
}
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("option", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
optionCalled := false
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Ping",
connect.WithSchema(methodDesc),
connect.WithIdempotency(connect.IdempotencyNoSideEffects),
connect.WithResponseInitializer(
func(spec connect.Spec, msg any) error {
assert.NotNil(t, spec)
assert.NotNil(t, msg)
dynamic, ok := msg.(*dynamicpb.Message)
if !assert.True(t, ok) {
return fmt.Errorf("unexpected message type: %T", msg)
}
*dynamic = *dynamicpb.NewMessage(methodDesc.Output())
optionCalled = true
return nil
},
),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
res, err := client.CallUnary(ctx, connect.NewRequest(msg))
assert.Nil(t, err)
got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
assert.True(t, optionCalled)
})
}

type notModifiedPingServer struct {
pingv1connect.UnimplementedPingServiceHandler

Expand Down
24 changes: 19 additions & 5 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import (
// It's returned from [Client].CallClientStream, but doesn't currently have an
// exported constructor function.
type ClientStreamForClient[Req, Res any] struct {
conn StreamingClientConn
conn StreamingClientConn
initializer func(Spec, any) error
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
_ = c.conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](c.conn)
response, err := receiveUnaryResponse[Res](c.conn, c.initializer)
if err != nil {
_ = c.conn.CloseResponse()
return nil, err
Expand All @@ -97,8 +98,9 @@ func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
// It's returned from [Client].CallServerStream, but doesn't currently have an
// exported constructor function.
type ServerStreamForClient[Res any] struct {
conn StreamingClientConn
msg *Res
conn StreamingClientConn
initializer func(Spec, any) error
msg *Res
// Error from client construction. If non-nil, return for all calls.
constructErr error
// Error from conn.Receive().
Expand All @@ -115,6 +117,12 @@ func (s *ServerStreamForClient[Res]) Receive() bool {
return false
}
s.msg = new(Res)
if s.initializer != nil {
if err := s.initializer(s.conn.Spec(), s.msg); err != nil {
s.receiveErr = err
return false
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nil checks is made in 8 separate places. What if instead you added an unexported initialize function to config and then pass config.initialize instead of config.Initializer? That way, that one method could do the nil check (and just return nil error if the func is nil).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having the function branch off the two configs I've added a maybeInitializer that wraps the func and does the nil check. Wdyt?

s.receiveErr = s.conn.Receive(s.msg)
return s.receiveErr == nil
}
Expand Down Expand Up @@ -175,7 +183,8 @@ func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) {
// It's returned from [Client].CallBidiStream, but doesn't currently have an
// exported constructor function.
type BidiStreamForClient[Req, Res any] struct {
conn StreamingClientConn
conn StreamingClientConn
initializer func(Spec, any) error
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand Down Expand Up @@ -234,6 +243,11 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
return nil, b.err
}
var msg Res
if b.initializer != nil {
if err := b.initializer(b.conn.Spec(), &msg); err != nil {
return nil, err
}
}
if err := b.conn.Receive(&msg); err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ func TestServerStreamForClient_NoPanics(t *testing.T) {

func TestServerStreamForClient(t *testing.T) {
t.Parallel()
stream := &ServerStreamForClient[pingv1.PingResponse]{conn: &nopStreamingClientConn{}}
stream := &ServerStreamForClient[pingv1.PingResponse]{
conn: &nopStreamingClientConn{},
}
// Ensure that each call to Receive allocates a new message. This helps
// vtprotobuf, which doesn't automatically zero messages before unmarshaling
// (see https://connectrpc.com/connect/issues/345), and it's also
Expand Down Expand Up @@ -104,3 +106,7 @@ type nopStreamingClientConn struct {
func (c *nopStreamingClientConn) Receive(msg any) error {
return nil
}

func (c *nopStreamingClientConn) Spec() Spec {
return Spec{}
}
16 changes: 14 additions & 2 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,27 @@ type handlerConnCloser interface {
// envelopes the message and attaches headers and trailers. It attempts to
// consume the response stream and isn't appropriate when receiving multiple
// messages.
func receiveUnaryResponse[T any](conn StreamingClientConn) (*Response[T], error) {
func receiveUnaryResponse[T any](conn StreamingClientConn, initializer func(Spec, any) error) (*Response[T], error) {
var msg T
if initializer != nil {
if err := initializer(conn.Spec(), &msg); err != nil {
return nil, err
}
}
if err := conn.Receive(&msg); err != nil {
return nil, err
}
// In a well-formed stream, the response message may be followed by a block
// of in-stream trailers or HTTP trailers. To ensure that we receive the
// trailers, try to read another message from the stream.
if err := conn.Receive(new(T)); err == nil {
// TODO: optimise unary calls to avoid this extra receive.
var msg2 T
if initializer != nil {
if err := initializer(conn.Spec(), &msg2); err != nil {
return nil, err
}
}
if err := conn.Receive(&msg2); err == nil {
return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages"))
} else if err != nil && !errors.Is(err, io.EOF) {
return nil, NewError(CodeUnknown, err)
Expand Down
1 change: 0 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2364,7 +2364,6 @@ func (p *pluggablePingServer) CumSum(

func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
tb.Helper()

if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(tb, err, io.EOF)
assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown)
Expand Down
Loading