diff --git a/examples/internal/integration/integration_test.go b/examples/internal/integration/integration_test.go index 12a33f01d9f..f7caa5fdf43 100644 --- a/examples/internal/integration/integration_test.go +++ b/examples/internal/integration/integration_test.go @@ -1460,7 +1460,7 @@ func testABEBulkEchoDurationError(t *testing.T, port int) { defer reqw.Close() for i := 0; i < 10; i++ { s := fmt.Sprintf("%d.123s", i) - if i == 9 { + if i == 5 { s = "invalidDurationFormat" } buf, err := marshaler.Marshal(s) @@ -1494,6 +1494,7 @@ func testABEBulkEchoDurationError(t *testing.T, port int) { } var got []*durationpb.Duration + var invalidArgumentCount int wg.Add(1) go func() { defer wg.Done() @@ -1515,7 +1516,9 @@ func testABEBulkEchoDurationError(t *testing.T, port int) { code, ok := item.Error["code"].(float64) if !ok { t.Errorf("item.Error[code] not found or not a number: %#v; i = %d", item.Error, i) - } else if int32(code) != 3 { + } else if int32(code) == 3 { + invalidArgumentCount++ + } else { t.Errorf("item.Error[code] = %v; want 3; i = %d", code, i) } continue @@ -1527,11 +1530,14 @@ func testABEBulkEchoDurationError(t *testing.T, port int) { } got = append(got, msg) } + + if invalidArgumentCount != 1 { + t.Errorf("got %d errors with code 3; want exactly 1", invalidArgumentCount) + } }() wg.Wait() - - if diff := cmp.Diff(got, want[:len(got)], protocmp.Transform()); diff != "" { + if diff := cmp.Diff(got, want[:5], protocmp.Transform()); diff != "" { t.Error(diff) } } diff --git a/examples/internal/server/a_bit_of_everything.go b/examples/internal/server/a_bit_of_everything.go index 1a4b548f500..fbf5847b146 100644 --- a/examples/internal/server/a_bit_of_everything.go +++ b/examples/internal/server/a_bit_of_everything.go @@ -326,18 +326,6 @@ func (s *_ABitOfEverythingServer) BulkEcho(stream examples.StreamService_BulkEch } func (s *_ABitOfEverythingServer) BulkEchoDuration(stream examples.StreamService_BulkEchoDurationServer) error { - var msgs []*durationpb.Duration - for { - msg, err := stream.Recv() - if err == io.EOF { - break - } - if err != nil { - return err - } - msgs = append(msgs, msg) - } - hmd := metadata.New(map[string]string{ "foo": "foo1", "bar": "bar1", @@ -346,18 +334,47 @@ func (s *_ABitOfEverythingServer) BulkEchoDuration(stream examples.StreamService return err } - for _, msg := range msgs { - grpclog.Info(msg) - if err := stream.Send(msg); err != nil { - return err + // Channel to coordinate between read and write goroutines + msgChan := make(chan *durationpb.Duration) + errChan := make(chan error) + + go func() { + defer close(msgChan) + for { + msg, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + errChan <- err + return + } + msgChan <- msg } - } + }() + + go func() { + for msg := range msgChan { + grpclog.Info(msg) + if err := stream.Send(msg); err != nil { + errChan <- err + return + } + } + // Sleep to mock the delay in receiving the request close. + // Accommodates the integration test client which is not a true + // bidirectional streaming client that supports request streaming. + time.Sleep(1 * time.Second) + close(errChan) + }() + + err := <-errChan stream.SetTrailer(metadata.New(map[string]string{ "foo": "foo2", "bar": "bar2", })) - return nil + return err } func (s *_ABitOfEverythingServer) DeepPathEcho(ctx context.Context, msg *examples.ABitOfEverything) (*examples.ABitOfEverything, error) { diff --git a/runtime/errors.go b/runtime/errors.go index 003944d6565..41cd4f5030e 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -81,7 +81,7 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R mux.errorHandler(ctx, mux, marshaler, w, r, err) } -// HttpStreamError uses the mux-configured stream error handler to notify error to the client without closing the connection. +// HTTPStreamError uses the mux-configured stream error handler to notify error to the client without closing the connection. func HTTPStreamError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { st := mux.streamErrorHandler(ctx, err) msg := errorChunk(st)