Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dynamic messages with Initializer options #630

Closed
wants to merge 10 commits into from
18 changes: 14 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
_ = conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](conn)
response, err := receiveUnaryResponse[Res](conn, config)
if err != nil {
_ = conn.CloseResponse()
return nil, err
Expand Down Expand Up @@ -135,7 +135,10 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)}
return &ClientStreamForClient[Req, Res]{
conn: c.newConn(ctx, StreamTypeClient, nil),
config: c.config,
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
}
}

// CallServerStream calls a server streaming procedure.
Expand All @@ -160,15 +163,21 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques
if err := conn.CloseRequest(); err != nil {
return nil, err
}
return &ServerStreamForClient[Res]{conn: conn}, nil
return &ServerStreamForClient[Res]{
conn: conn,
config: c.config,
}, nil
}

// CallBidiStream calls a bidirectional streaming procedure.
func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForClient[Req, Res] {
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)}
return &BidiStreamForClient[Req, Res]{
conn: c.newConn(ctx, StreamTypeBidi, nil),
config: c.config,
}
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
Expand All @@ -190,6 +199,7 @@ type clientConfig struct {
Protocol protocol
Procedure string
Schema any
Initializer InitializerFunc
CompressMinBytes int
Interceptor Interceptor
CompressionPools map[string]*compressionPool
Expand Down
169 changes: 169 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect"
"connectrpc.com/connect/internal/memhttp/memhttptest"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
)

func TestNewClient_InitFailure(t *testing.T) {
Expand Down Expand Up @@ -226,6 +228,173 @@ func TestSpecSchema(t *testing.T) {
})
}

func TestDynamicClient(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}))
server := memhttptest.NewServer(t, mux)
ctx := context.Background()
initializer := func(spec connect.Spec, msg any) error {
dynamic, ok := msg.(*dynamicpb.Message)
if !ok {
return nil
}
desc, ok := spec.Schema.(protoreflect.MethodDescriptor)
if !ok {
return fmt.Errorf("invalid schema type %T for %T message", spec.Schema, dynamic)
}
if spec.IsClient {
*dynamic = *dynamicpb.NewMessage(desc.Output())
} else {
*dynamic = *dynamicpb.NewMessage(desc.Input())
}
return nil
}
t.Run("unary", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Ping",
connect.WithSchema(methodDesc),
connect.WithIdempotency(connect.IdempotencyNoSideEffects),
connect.WithResponseInitializer(initializer),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
res, err := client.CallUnary(ctx, connect.NewRequest(msg))
assert.Nil(t, err)
got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("clientStream", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Sum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Sum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream := client.CallClientStream(ctx)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.Send(msg))
rsp, err := stream.CloseAndReceive()
if !assert.Nil(t, err) {
return
}
got := rsp.Msg.Get(methodDesc.Output().Fields().ByName("sum")).Int()
assert.Equal(t, got, 42*2)
})
t.Run("serverStream", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CountUp")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/CountUp",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(2),
)
req := connect.NewRequest(msg)
stream, err := client.CallServerStream(ctx, req)
if !assert.Nil(t, err) {
return
}
for i := 1; stream.Receive(); i++ {
out := stream.Msg()
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, int64(i))
}
assert.Nil(t, stream.Close())
})
t.Run("bidi", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.CumSum")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/CumSum",
connect.WithSchema(methodDesc),
connect.WithResponseInitializer(initializer),
)
stream := client.CallBidiStream(ctx)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
assert.Nil(t, stream.Send(msg))
assert.Nil(t, stream.CloseRequest())
out, err := stream.Receive()
if assert.Nil(t, err) {
return
}
got := out.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
})
t.Run("option", func(t *testing.T) {
t.Parallel()
desc, err := protoregistry.GlobalFiles.FindDescriptorByName("connect.ping.v1.PingService.Ping")
assert.Nil(t, err)
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
assert.True(t, ok)
optionCalled := false
client := connect.NewClient[dynamicpb.Message, dynamicpb.Message](
server.Client(),
server.URL()+"/connect.ping.v1.PingService/Ping",
connect.WithSchema(methodDesc),
connect.WithIdempotency(connect.IdempotencyNoSideEffects),
connect.WithResponseInitializer(
func(spec connect.Spec, msg any) error {
assert.NotNil(t, spec)
assert.NotNil(t, msg)
dynamic, ok := msg.(*dynamicpb.Message)
if !assert.True(t, ok) {
return fmt.Errorf("unexpected message type: %T", msg)
}
*dynamic = *dynamicpb.NewMessage(methodDesc.Output())
optionCalled = true
return nil
},
),
)
msg := dynamicpb.NewMessage(methodDesc.Input())
msg.Set(
methodDesc.Input().Fields().ByName("number"),
protoreflect.ValueOfInt64(42),
)
res, err := client.CallUnary(ctx, connect.NewRequest(msg))
assert.Nil(t, err)
got := res.Msg.Get(methodDesc.Output().Fields().ByName("number")).Int()
assert.Equal(t, got, 42)
assert.True(t, optionCalled)
})
}

type notModifiedPingServer struct {
pingv1connect.UnimplementedPingServiceHandler

Expand Down
24 changes: 19 additions & 5 deletions client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import (
// It's returned from [Client].CallClientStream, but doesn't currently have an
// exported constructor function.
type ClientStreamForClient[Req, Res any] struct {
conn StreamingClientConn
conn StreamingClientConn
config *clientConfig
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (c *ClientStreamForClient[Req, Res]) CloseAndReceive() (*Response[Res], err
_ = c.conn.CloseResponse()
return nil, err
}
response, err := receiveUnaryResponse[Res](c.conn)
response, err := receiveUnaryResponse[Res](c.conn, c.config)
if err != nil {
_ = c.conn.CloseResponse()
return nil, err
Expand All @@ -97,8 +98,9 @@ func (c *ClientStreamForClient[Req, Res]) Conn() (StreamingClientConn, error) {
// It's returned from [Client].CallServerStream, but doesn't currently have an
// exported constructor function.
type ServerStreamForClient[Res any] struct {
conn StreamingClientConn
msg *Res
conn StreamingClientConn
config *clientConfig
msg *Res
// Error from client construction. If non-nil, return for all calls.
constructErr error
// Error from conn.Receive().
Expand All @@ -115,6 +117,12 @@ func (s *ServerStreamForClient[Res]) Receive() bool {
return false
}
s.msg = new(Res)
if s.config.Initializer != nil {
if err := s.config.Initializer(s.conn.Spec(), s.msg); err != nil {
s.receiveErr = err
return false
}
}
s.receiveErr = s.conn.Receive(s.msg)
return s.receiveErr == nil
}
Expand Down Expand Up @@ -175,7 +183,8 @@ func (s *ServerStreamForClient[Res]) Conn() (StreamingClientConn, error) {
// It's returned from [Client].CallBidiStream, but doesn't currently have an
// exported constructor function.
type BidiStreamForClient[Req, Res any] struct {
conn StreamingClientConn
conn StreamingClientConn
config *clientConfig
// Error from client construction. If non-nil, return for all calls.
err error
}
Expand Down Expand Up @@ -234,6 +243,11 @@ func (b *BidiStreamForClient[Req, Res]) Receive() (*Res, error) {
return nil, b.err
}
var msg Res
if b.config.Initializer != nil {
if err := b.config.Initializer(b.conn.Spec(), &msg); err != nil {
return nil, err
}
}
if err := b.conn.Receive(&msg); err != nil {
return nil, err
}
Expand Down
11 changes: 10 additions & 1 deletion client_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ func TestServerStreamForClient_NoPanics(t *testing.T) {

func TestServerStreamForClient(t *testing.T) {
t.Parallel()
stream := &ServerStreamForClient[pingv1.PingResponse]{conn: &nopStreamingClientConn{}}
config, cerr := newClientConfig("http://localhost:1234", nil)
assert.Nil(t, cerr)
stream := &ServerStreamForClient[pingv1.PingResponse]{
conn: &nopStreamingClientConn{},
config: config,
}
// Ensure that each call to Receive allocates a new message. This helps
// vtprotobuf, which doesn't automatically zero messages before unmarshaling
// (see https://connectrpc.com/connect/issues/345), and it's also
Expand Down Expand Up @@ -104,3 +109,7 @@ type nopStreamingClientConn struct {
func (c *nopStreamingClientConn) Receive(msg any) error {
return nil
}

func (c *nopStreamingClientConn) Spec() Spec {
return Spec{}
}
16 changes: 14 additions & 2 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,27 @@ type handlerConnCloser interface {
// envelopes the message and attaches headers and trailers. It attempts to
// consume the response stream and isn't appropriate when receiving multiple
// messages.
func receiveUnaryResponse[T any](conn StreamingClientConn) (*Response[T], error) {
func receiveUnaryResponse[T any](conn StreamingClientConn, config *clientConfig) (*Response[T], error) {
var msg T
if config.Initializer != nil {
if err := config.Initializer(conn.Spec(), &msg); err != nil {
return nil, err
}
}
if err := conn.Receive(&msg); err != nil {
return nil, err
}
// In a well-formed stream, the response message may be followed by a block
// of in-stream trailers or HTTP trailers. To ensure that we receive the
// trailers, try to read another message from the stream.
if err := conn.Receive(new(T)); err == nil {
// TODO: optimise unary calls to avoid this extra receive.
var msg2 T
if config.Initializer != nil {
if err := config.Initializer(conn.Spec(), &msg2); err != nil {
return nil, err
}
}
if err := conn.Receive(&msg2); err == nil {
return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages"))
} else if err != nil && !errors.Is(err, io.EOF) {
return nil, NewError(CodeUnknown, err)
Expand Down
1 change: 0 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2360,7 +2360,6 @@ func (p *pluggablePingServer) CumSum(

func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
tb.Helper()

if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(tb, err, io.EOF)
assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown)
Expand Down
Loading