From 9b22f41d8535fa3e40908c78ae66066c7972b6d9 Mon Sep 17 00:00:00 2001 From: Michal Witkowski Date: Wed, 1 Mar 2017 10:28:33 +0000 Subject: [PATCH 1/3] fix full duplex streaming --- proxy/handler.go | 16 +++++-- proxy/handler_test.go | 53 +++++++++++++++++------ testservice/test.pb.go | 98 +++++++++++++++++++++++++++++++++++------- testservice/test.proto | 3 ++ 4 files changed, 137 insertions(+), 33 deletions(-) diff --git a/proxy/handler.go b/proxy/handler.go index bacf7d8..f66abef 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/transport" + "golang.org/x/net/context" ) var ( @@ -64,21 +65,28 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") } fullMethodName := lowLevelServerStream.Method() + clientCtx, clientCancel := context.WithCancel(serverStream.Context()) backendConn, err := s.director(serverStream.Context(), fullMethodName) if err != nil { return err } // TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For. - clientStream, err := grpc.NewClientStream(serverStream.Context(), clientStreamDescForProxying, backendConn, fullMethodName) + clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName) if err != nil { return err } - defer clientStream.CloseSend() // always close this! - s2cErr := <-s.forwardServerToClient(serverStream, clientStream) - c2sErr := <-s.forwardClientToServer(clientStream, serverStream) + + s2cErrChan := s.forwardServerToClient(serverStream, clientStream) + c2sErrChan := s.forwardClientToServer(clientStream, serverStream) + s2cErr := <-s2cErrChan if s2cErr != io.EOF { + clientCancel() return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + } else { + clientStream.CloseSend() } + c2sErr := <-c2sErrChan + serverStream.SetTrailer(clientStream.Trailer()) // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. if c2sErr != io.EOF { diff --git a/proxy/handler_test.go b/proxy/handler_test.go index 96c42ce..bbd51ef 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -22,6 +22,8 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" + "fmt" + pb "github.com/mwitkow/grpc-proxy/testservice" ) @@ -72,6 +74,27 @@ func (s *assertingService) PingList(ping *pb.PingRequest, stream pb.TestService_ return nil } +func (s *assertingService) PingStream(stream pb.TestService_PingStreamServer) error { + stream.SendHeader(metadata.Pairs(serverHeaderMdKey, "I like turtles.")) + counter := int32(0) + for { + ping, err := stream.Recv() + if err == io.EOF { + break + } else if err != nil { + require.NoError(s.t, err, "can't fail reading stream") + return err + } + pong := &pb.PingResponse{Value: ping.Value, Counter: counter} + if err := stream.Send(pong); err != nil { + require.NoError(s.t, err, "can't fail sending back a pong") + } + counter += 1 + } + stream.SetTrailer(metadata.Pairs(serverTrailerMdKey, "I like ending turtles.")) + return nil +} + // ProxyHappySuite tests the "happy" path of handling: that everything works in absence of connection issues. type ProxyHappySuite struct { suite.Suite @@ -125,24 +148,28 @@ func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() { assert.Equal(s.T(), "testing rejection", grpc.ErrorDesc(err)) } -func (s *ProxyHappySuite) TestPingListStreamsAll() { - stream, err := s.testClient.PingList(s.ctx(), &pb.PingRequest{Value: "foo"}) - require.NoError(s.T(), err, "PingList request should be successful.") - // Check that the header arrives before all entries. - headerMd, err := stream.Header() - require.NoError(s.T(), err, "PingList headers should not error.") - assert.Len(s.T(), headerMd, 1, "PingList response headers user contain metadata") - count := 0 - for { +func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() { + stream, err := s.testClient.PingStream(s.ctx()) + require.NoError(s.T(), err, "PingStream request should be successful.") + + for i := 0; i < countListResponses; i++ { + ping := &pb.PingRequest{Value: fmt.Sprintf("foo:%d", i)} + require.NoError(s.T(), stream.Send(ping), "sending to PingStream must not fail") resp, err := stream.Recv() if err == io.EOF { break } - require.NoError(s.T(), err, "PingList stream should not be interrupted.") - require.Equal(s.T(), "foo", resp.Value) - count = count + 1 + if i == 0 { + // Check that the header arrives before all entries. + headerMd, err := stream.Header() + require.NoError(s.T(), err, "PingStream headers should not error.") + assert.Len(s.T(), headerMd, 1, "PingStream response headers user contain metadata") + } + assert.EqualValues(s.T(), i, resp.Counter, "ping roundtrip must succeed with the correct id") } - assert.Equal(s.T(), countListResponses, count, "PingList must successfully return all outputs") + require.NoError(s.T(), stream.CloseSend(), "no error on close send") + _, err = stream.Recv() + require.Equal(s.T(), io.EOF, err, "stream should close with io.EOF, meaining OK") // Check that the trailer headers are here. trailerMd := stream.Trailer() assert.Len(s.T(), trailerMd, 1, "PingList trailer headers user contain metadata") diff --git a/testservice/test.pb.go b/testservice/test.pb.go index acc40a2..1f1f482 100644 --- a/testservice/test.pb.go +++ b/testservice/test.pb.go @@ -60,7 +60,7 @@ func (m *PingRequest) GetValue() string { } type PingResponse struct { - Value string `protobuf:"bytes,1,opt,name=Value,json=value" json:"Value,omitempty"` + Value string `protobuf:"bytes,1,opt,name=Value" json:"Value,omitempty"` Counter int32 `protobuf:"varint,2,opt,name=counter" json:"counter,omitempty"` } @@ -104,6 +104,7 @@ type TestServiceClient interface { Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) PingError(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*Empty, error) PingList(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (TestService_PingListClient, error) + PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) } type testServiceClient struct { @@ -173,6 +174,37 @@ func (x *testServicePingListClient) Recv() (*PingResponse, error) { return m, nil } +func (c *testServiceClient) PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/mwitkow.testproto.TestService/PingStream", opts...) + if err != nil { + return nil, err + } + x := &testServicePingStreamClient{stream} + return x, nil +} + +type TestService_PingStreamClient interface { + Send(*PingRequest) error + Recv() (*PingResponse, error) + grpc.ClientStream +} + +type testServicePingStreamClient struct { + grpc.ClientStream +} + +func (x *testServicePingStreamClient) Send(m *PingRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServicePingStreamClient) Recv() (*PingResponse, error) { + m := new(PingResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // Server API for TestService service type TestServiceServer interface { @@ -180,6 +212,7 @@ type TestServiceServer interface { Ping(context.Context, *PingRequest) (*PingResponse, error) PingError(context.Context, *PingRequest) (*Empty, error) PingList(*PingRequest, TestService_PingListServer) error + PingStream(TestService_PingStreamServer) error } func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { @@ -261,6 +294,32 @@ func (x *testServicePingListServer) Send(m *PingResponse) error { return x.ServerStream.SendMsg(m) } +func _TestService_PingStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).PingStream(&testServicePingStreamServer{stream}) +} + +type TestService_PingStreamServer interface { + Send(*PingResponse) error + Recv() (*PingRequest, error) + grpc.ServerStream +} + +type testServicePingStreamServer struct { + grpc.ServerStream +} + +func (x *testServicePingStreamServer) Send(m *PingResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServicePingStreamServer) Recv() (*PingRequest, error) { + m := new(PingRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + var _TestService_serviceDesc = grpc.ServiceDesc{ ServiceName: "mwitkow.testproto.TestService", HandlerType: (*TestServiceServer)(nil), @@ -284,6 +343,12 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ Handler: _TestService_PingList_Handler, ServerStreams: true, }, + { + StreamName: "PingStream", + Handler: _TestService_PingStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "test.proto", } @@ -291,19 +356,20 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ func init() { proto.RegisterFile("test.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 218 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, - 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0xcc, 0x2d, 0xcf, 0x2c, 0xc9, 0xce, 0x2f, 0xd7, - 0x03, 0x89, 0x81, 0x85, 0x94, 0xd8, 0xb9, 0x58, 0x5d, 0x73, 0x0b, 0x4a, 0x2a, 0x95, 0x94, 0xb9, - 0xb8, 0x03, 0x32, 0xf3, 0xd2, 0x83, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x44, 0xb8, 0x58, - 0xcb, 0x12, 0x73, 0x4a, 0x53, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x20, 0x1c, 0x25, 0x3b, - 0x2e, 0x1e, 0x88, 0xa2, 0xe2, 0x82, 0xfc, 0xbc, 0xe2, 0x54, 0x90, 0xaa, 0x30, 0x0c, 0x55, 0x42, - 0x12, 0x5c, 0xec, 0xc9, 0xf9, 0xa5, 0x79, 0x25, 0xa9, 0x45, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0xac, - 0x41, 0x30, 0xae, 0xd1, 0x1e, 0x26, 0x2e, 0xee, 0x90, 0xd4, 0xe2, 0x92, 0xe0, 0xd4, 0xa2, 0xb2, - 0xcc, 0xe4, 0x54, 0x21, 0x0f, 0x2e, 0x4e, 0x90, 0x79, 0x60, 0x17, 0x08, 0x49, 0xe8, 0x61, 0x38, - 0x4f, 0x0f, 0x2c, 0x23, 0x25, 0x8f, 0x45, 0x06, 0xd9, 0x1d, 0x4a, 0x0c, 0x42, 0x9e, 0x5c, 0x2c, - 0x20, 0x11, 0x21, 0x39, 0x9c, 0x4a, 0xc1, 0xfe, 0x22, 0xc6, 0x28, 0x77, 0xa8, 0xa3, 0x8a, 0x8a, - 0xf2, 0x8b, 0x08, 0x9a, 0x87, 0xd3, 0xd1, 0x4a, 0x0c, 0x42, 0xfe, 0x5c, 0x1c, 0x20, 0xa5, 0x3e, - 0x99, 0xc5, 0x25, 0x54, 0x70, 0x97, 0x01, 0x63, 0x12, 0x1b, 0x58, 0xdc, 0x18, 0x10, 0x00, 0x00, - 0xff, 0xff, 0x7b, 0xc9, 0x16, 0xf1, 0xd4, 0x01, 0x00, 0x00, + // 237 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xac, 0x8f, 0x31, 0x4b, 0xc4, 0x40, + 0x10, 0x85, 0x6f, 0xd5, 0x78, 0xde, 0x9c, 0x8d, 0x83, 0xc5, 0x62, 0xa1, 0xc7, 0xda, 0xa4, 0x5a, + 0x0e, 0xed, 0xed, 0x44, 0x05, 0x41, 0x49, 0xc4, 0xfe, 0x0c, 0x83, 0x2c, 0x9a, 0x6c, 0xdc, 0x9d, + 0x24, 0xf8, 0x33, 0xfc, 0xc7, 0xb2, 0x1b, 0x85, 0x80, 0x06, 0x2d, 0x52, 0xce, 0x7b, 0x1f, 0x8f, + 0x6f, 0x00, 0x98, 0x3c, 0xeb, 0xda, 0x59, 0xb6, 0x78, 0x50, 0x76, 0x86, 0x5f, 0x6c, 0xa7, 0x43, + 0x16, 0x23, 0x35, 0x87, 0xe4, 0xb2, 0xac, 0xf9, 0x5d, 0x9d, 0xc2, 0xf2, 0xde, 0x54, 0xcf, 0x19, + 0xbd, 0x35, 0xe4, 0x19, 0x0f, 0x21, 0x69, 0x37, 0xaf, 0x0d, 0x49, 0xb1, 0x12, 0xe9, 0x22, 0xeb, + 0x0f, 0x75, 0x01, 0xfb, 0x3d, 0xe4, 0x6b, 0x5b, 0x79, 0x0a, 0xd4, 0xe3, 0x90, 0x8a, 0x07, 0x4a, + 0x98, 0x17, 0xb6, 0xa9, 0x98, 0x9c, 0xdc, 0x5a, 0x89, 0x34, 0xc9, 0xbe, 0xcf, 0xb3, 0x8f, 0x6d, + 0x58, 0x3e, 0x90, 0xe7, 0x9c, 0x5c, 0x6b, 0x0a, 0xc2, 0x6b, 0x58, 0x84, 0xbd, 0x68, 0x80, 0x52, + 0xff, 0xd0, 0xd3, 0xb1, 0x39, 0x3a, 0xf9, 0xa5, 0x19, 0x7a, 0xa8, 0x19, 0xde, 0xc0, 0x4e, 0x48, + 0xf0, 0x78, 0x14, 0x8d, 0x7f, 0xfd, 0x67, 0xea, 0xea, 0x4b, 0xca, 0x39, 0xeb, 0xfe, 0xdc, 0x1b, + 0x95, 0x56, 0x33, 0xbc, 0x83, 0xbd, 0x80, 0xde, 0x1a, 0xcf, 0x13, 0x78, 0xad, 0x05, 0xe6, 0x00, + 0x21, 0xcb, 0xd9, 0xd1, 0xa6, 0x9c, 0x60, 0x32, 0x15, 0x6b, 0xf1, 0xb4, 0x1b, 0x9b, 0xf3, 0xcf, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x4a, 0xc0, 0x8e, 0xe7, 0x29, 0x02, 0x00, 0x00, } diff --git a/testservice/test.proto b/testservice/test.proto index 3ee34d0..54e3cf5 100644 --- a/testservice/test.proto +++ b/testservice/test.proto @@ -22,5 +22,8 @@ service TestService { rpc PingError(PingRequest) returns (Empty) {} rpc PingList(PingRequest) returns (stream PingResponse) {} + + rpc PingStream(stream PingRequest) returns (stream PingResponse) {} + } From 84242c4e690da18d16d2ab8f2fa47e45986220b6 Mon Sep 17 00:00:00 2001 From: Michal Witkowski Date: Wed, 1 Mar 2017 15:24:22 +0000 Subject: [PATCH 2/3] fix the "i don't know who finished" case --- proxy/handler.go | 51 ++++++++++++++++++++++++++++--------------- proxy/handler_test.go | 4 ++-- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/proxy/handler.go b/proxy/handler.go index f66abef..3c088da 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -6,10 +6,10 @@ package proxy import ( "io" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) var ( @@ -75,24 +75,39 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error if err != nil { return err } - s2cErrChan := s.forwardServerToClient(serverStream, clientStream) + defer close(s2cErrChan) c2sErrChan := s.forwardClientToServer(clientStream, serverStream) - s2cErr := <-s2cErrChan - if s2cErr != io.EOF { - clientCancel() - return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) - } else { - clientStream.CloseSend() - } - c2sErr := <-c2sErrChan - - serverStream.SetTrailer(clientStream.Trailer()) - // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. - if c2sErr != io.EOF { - return c2sErr + defer close(c2sErrChan) + // We don't know which side is going to stop sending first, so we need a select between the two. + for i := 0; i < 2; i++ { + select { + case s2cErr := <-s2cErrChan: + if s2cErr == io.EOF { + // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./ + // the clientStream>serverStream may continue pumping though. + clientStream.CloseSend() + break + } else { + // however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need + // to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and + // exit with an error to the stack + clientCancel() + return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + } + case c2sErr := <-c2sErrChan: + // This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two + // cases we may have received Trailers as part of the call. In case of other errors (stream closed) the trailers + // will be nil. + serverStream.SetTrailer(clientStream.Trailer()) + // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. + if c2sErr != io.EOF { + return c2sErr + } + return nil + } } - return nil + return grpc.Errorf(codes.Internal, "gRPC proxying should never reach this stage.") } func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { @@ -123,7 +138,6 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt break } } - close(ret) }() return ret } @@ -134,15 +148,16 @@ func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientSt f := &frame{} for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { + //grpclog.Printf("s2c err: %v", err) ret <- err // this can be io.EOF which is happy case break } if err := dst.SendMsg(f); err != nil { + //grpclog.Printf("s2c err: %v", err) ret <- err break } } - close(ret) }() return ret } diff --git a/proxy/handler_test.go b/proxy/handler_test.go index bbd51ef..ea67ab2 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -210,12 +210,12 @@ func (s *ProxyHappySuite) SetupSuite() { "Ping") // Start the serving loops. + s.T().Logf("starting grpc.Server at: %v", s.serverListener.Addr().String()) go func() { - s.T().Logf("starting grpc.Server at: %v", s.serverListener.Addr().String()) s.server.Serve(s.serverListener) }() + s.T().Logf("starting grpc.Proxy at: %v", s.proxyListener.Addr().String()) go func() { - s.T().Logf("starting grpc.Proxy at: %v", s.proxyListener.Addr().String()) s.proxy.Serve(s.proxyListener) }() From de4d3db538565636e1e977102f6f0bd1ed0ce9c2 Mon Sep 17 00:00:00 2001 From: Michal Witkowski Date: Wed, 1 Mar 2017 15:25:51 +0000 Subject: [PATCH 3/3] remove spurious printfs --- proxy/handler.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/proxy/handler.go b/proxy/handler.go index 3c088da..f5868d9 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -148,12 +148,10 @@ func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientSt f := &frame{} for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - //grpclog.Printf("s2c err: %v", err) ret <- err // this can be io.EOF which is happy case break } if err := dst.SendMsg(f); err != nil { - //grpclog.Printf("s2c err: %v", err) ret <- err break }