From 716794bb28bca01d77ca4416dcd75d2b8eb31bd4 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Sun, 22 Oct 2023 15:43:58 -0400 Subject: [PATCH] Ensure response errors are reported consistently Removes SetError in favour of reporting errors on BlockUntilResponseReady. ensureRequestMade removes sync.Once as the contract of ClientStream state calling Send with CloseRequest is not safe to call concurrently. --- client_ext_test.go | 4 +- connect_ext_test.go | 14 ++++--- duplex_http_call.go | 95 ++++++++++++++++++--------------------------- protocol_connect.go | 9 +++-- protocol_grpc.go | 13 ++++--- 5 files changed, 58 insertions(+), 77 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 36908a1d..110edb1b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -106,10 +106,8 @@ func TestClientPeer(t *testing.T) { err = clientStream.Send(&pingv1.SumRequest{}) assert.Nil(t, err) // server streaming - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) t.Cleanup(func() { - // TODO(emcfarlane): debug flaky test close with error: - // "unknown: io: read/write on closed pipe" assert.Nil(t, serverStream.Close()) }) assert.Nil(t, err) diff --git a/connect_ext_test.go b/connect_ext_test.go index e8437da9..24d85e26 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -774,15 +774,15 @@ func TestBidiRequiresHTTP2(t *testing.T) { server.URL(), ) stream := client.CumSum(context.Background()) - // Stream creates an async request, can error on Send or Receive. - err := stream.Send(&pingv1.CumSumRequest{}) - if err == nil { - assert.Nil(t, stream.CloseRequest()) - _, err = stream.Receive() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(t, err, io.EOF) } + assert.Nil(t, stream.CloseRequest()) + _, err := stream.Receive() assert.NotNil(t, err) var connectErr *connect.Error assert.True(t, errors.As(err, &connectErr)) + t.Log(err) assert.Equal(t, connectErr.Code(), connect.CodeUnimplemented) assert.True( t, @@ -1988,13 +1988,14 @@ func TestBidiOverHTTP1(t *testing.T) { server.URL(), ) stream := client.CumSum(context.Background()) + // Stream creates an async request, can error on Send or Receive. if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil { assert.ErrorIs(t, err, io.EOF) } _, err := stream.Receive() assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) - assert.Equal(t, err.Error(), "unknown: HTTP status 505 HTTP Version Not Supported") + assert.True(t, strings.HasSuffix(err.Error(), "HTTP status 505 HTTP Version Not Supported")) assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) } @@ -2342,6 +2343,7 @@ func (p *pluggablePingServer) CumSum( func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { assert.ErrorIs(tb, err, io.EOF) assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) diff --git a/duplex_http_call.go b/duplex_http_call.go index ab1a6db4..f57c8d82 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -42,13 +42,11 @@ type duplexHTTPCall struct { requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter - sendRequestOnce sync.Once - responseReady chan struct{} - request *http.Request - response *http.Response - - errMu sync.Mutex - err error + requestSent bool + responseReady sync.WaitGroup + request *http.Request + response *http.Response + responseErr error } func newDuplexHTTPCall( @@ -80,24 +78,23 @@ func newDuplexHTTPCall( Body: pipeReader, Host: url.Host, }).WithContext(ctx) - return &duplexHTTPCall{ + call := &duplexHTTPCall{ ctx: ctx, httpClient: httpClient, streamType: spec.StreamType, requestBodyReader: pipeReader, requestBodyWriter: pipeWriter, request: request, - responseReady: make(chan struct{}), } + call.responseReady.Add(1) + return call } -// Write to the request body. Returns an error wrapping io.EOF after SetError -// is called. +// Write to the request body. func (d *duplexHTTPCall) Write(data []byte) (int, error) { d.ensureRequestMade() // Before we send any data, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } // It's safe to write to this side of the pipe while net/http concurrently @@ -157,14 +154,12 @@ func (d *duplexHTTPCall) SetMethod(method string) { func (d *duplexHTTPCall) Read(data []byte) (int, error) { // First, we wait until we've gotten the response headers and established the // server-to-client side of the stream. - d.BlockUntilResponseReady() - if err := d.getError(); err != nil { + if err := d.BlockUntilResponseReady(); err != nil { // The stream is already closed or corrupted. return 0, err } // Before we read, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } if d.response == nil { @@ -175,7 +170,7 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) { } func (d *duplexHTTPCall) CloseRead() error { - d.BlockUntilResponseReady() + d.responseReady.Wait() if d.response == nil { return nil } @@ -188,7 +183,9 @@ func (d *duplexHTTPCall) CloseRead() error { // ResponseStatusCode is the response's HTTP status code. func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { - d.BlockUntilResponseReady() + if err := d.BlockUntilResponseReady(); err != nil { + return 0, err + } if d.response == nil { return 0, fmt.Errorf("nil response from %v", d.request.URL) } @@ -197,7 +194,7 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { // ResponseHeader returns the response HTTP headers. func (d *duplexHTTPCall) ResponseHeader() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Header } @@ -206,56 +203,39 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header { // ResponseTrailer returns the response HTTP trailers. func (d *duplexHTTPCall) ResponseTrailer() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Trailer } return make(http.Header) } -// SetError stores any error encountered processing the response. All -// subsequent calls to Read return this error, and all subsequent calls to -// Write return an error wrapping io.EOF. It's safe to call concurrently with -// any other method. -func (d *duplexHTTPCall) SetError(err error) { - d.errMu.Lock() - if d.err == nil { - d.err = wrapIfContextError(err) - } - // Closing the read side of the request body pipe acquires an internal lock, - // so we want to scope errMu's usage narrowly and avoid defer. - d.errMu.Unlock() - - // We've already hit an error, so we should stop writing to the request body. - // It's safe to call Close more than once and/or concurrently (calls after - // the first are no-ops), so it's okay for us to call this even though - // net/http sometimes closes the reader too. - // - // It's safe to ignore the returned error here. Under the hood, Close calls - // CloseWithError, which is documented to always return nil. - _ = d.requestBodyReader.Close() -} - // SetValidateResponse sets the response validation function. The function runs // in a background goroutine. func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) { d.validateResponse = validate } -func (d *duplexHTTPCall) BlockUntilResponseReady() { - <-d.responseReady +func (d *duplexHTTPCall) BlockUntilResponseReady() error { + d.responseReady.Wait() + return d.responseErr } +// ensureRequestMade sends the request headers and starts the response stream. +// It is not safe to call this concurrently. Write and CloseWrite call this but +// ensure that they're not called concurrently. func (d *duplexHTTPCall) ensureRequestMade() { - d.sendRequestOnce.Do(func() { - go d.makeRequest() - }) + if d.requestSent { + return // already sent + } + d.requestSent = true + go d.makeRequest() } func (d *duplexHTTPCall) makeRequest() { // This runs concurrently with Write and CloseWrite. Read and CloseRead wait // on d.responseReady, so we can't race with them. - defer close(d.responseReady) + defer d.responseReady.Done() // Promote the header Host to the request object. if host := d.request.Header.Get(headerHost); len(host) > 0 { @@ -276,33 +256,32 @@ func (d *duplexHTTPCall) makeRequest() { if _, ok := asError(err); !ok { err = NewError(CodeUnavailable, err) } - d.SetError(err) + d.responseErr = err + d.requestBodyReader.CloseWithError(io.EOF) return } d.response = response if err := d.validateResponse(response); err != nil { - d.SetError(err) + d.responseErr = err + d.response.Body.Close() + d.requestBodyReader.CloseWithError(io.EOF) return } if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { // If we somehow dialed an HTTP/1.x server, fail with an explicit message // rather than returning a more cryptic error later on. - d.SetError(errorf( + d.responseErr = errorf( CodeUnimplemented, "response from %v is HTTP/%d.%d: bidi streams require at least HTTP/2", d.request.URL, response.ProtoMajor, response.ProtoMinor, - )) + ) + d.response.Body.Close() + d.requestBodyReader.CloseWithError(io.EOF) } } -func (d *duplexHTTPCall) getError() error { - d.errMu.Lock() - defer d.errMu.Unlock() - return d.err -} - // See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33 func cloneURL(oldURL *url.URL) *url.URL { if oldURL == nil { diff --git a/protocol_connect.go b/protocol_connect.go index e3c74923..cab89c08 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -589,7 +589,9 @@ func (cc *connectStreamingClientConn) CloseRequest() error { } func (cc *connectStreamingClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -603,7 +605,6 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) return serverErr } // If the error is EOF but not from a last message, we want to return @@ -614,8 +615,8 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // There's no error in the trailers, so this was probably an error // converting the bytes to a message, an error reading from the network, or // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // close the writer so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } diff --git a/protocol_grpc.go b/protocol_grpc.go index d3cb0062..7cdc411f 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -381,7 +381,9 @@ func (cc *grpcClientConn) CloseRequest() error { } func (cc *grpcClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -409,23 +411,22 @@ func (cc *grpcClientConn) Receive(msg any) error { // the stream has ended, Receive must return an error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) return serverErr } // This was probably an error converting the bytes to a message or an error // reading from the network. We're going to return it to the - // user, but we also want to setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // user, but we also want to close writes so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } func (cc *grpcClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *grpcClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer }