Skip to content

Commit

Permalink
Ensure response errors are reported consistently
Browse files Browse the repository at this point in the history
Removes SetError in favour of reporting errors on BlockUntilResponseReady.
ensureRequestMade removes sync.Once as the contract of ClientStream state
calling Send with CloseRequest is not safe to call concurrently.
  • Loading branch information
emcfarlane committed Oct 22, 2023
1 parent dd98dd4 commit 716794b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 77 deletions.
4 changes: 1 addition & 3 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ func TestClientPeer(t *testing.T) {
err = clientStream.Send(&pingv1.SumRequest{})
assert.Nil(t, err)
// server streaming
serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{}))
serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1}))
t.Cleanup(func() {
// TODO(emcfarlane): debug flaky test close with error:
// "unknown: io: read/write on closed pipe"
assert.Nil(t, serverStream.Close())
})
assert.Nil(t, err)
Expand Down
14 changes: 8 additions & 6 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -774,15 +774,15 @@ func TestBidiRequiresHTTP2(t *testing.T) {
server.URL(),
)
stream := client.CumSum(context.Background())
// Stream creates an async request, can error on Send or Receive.
err := stream.Send(&pingv1.CumSumRequest{})
if err == nil {
assert.Nil(t, stream.CloseRequest())
_, err = stream.Receive()
if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(t, err, io.EOF)
}
assert.Nil(t, stream.CloseRequest())
_, err := stream.Receive()
assert.NotNil(t, err)
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
t.Log(err)
assert.Equal(t, connectErr.Code(), connect.CodeUnimplemented)
assert.True(
t,
Expand Down Expand Up @@ -1988,13 +1988,14 @@ func TestBidiOverHTTP1(t *testing.T) {
server.URL(),
)
stream := client.CumSum(context.Background())
// Stream creates an async request, can error on Send or Receive.
if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil {
assert.ErrorIs(t, err, io.EOF)
}
_, err := stream.Receive()
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
assert.Equal(t, err.Error(), "unknown: HTTP status 505 HTTP Version Not Supported")
assert.True(t, strings.HasSuffix(err.Error(), "HTTP status 505 HTTP Version Not Supported"))
assert.Nil(t, stream.CloseRequest())
assert.Nil(t, stream.CloseResponse())
}
Expand Down Expand Up @@ -2342,6 +2343,7 @@ func (p *pluggablePingServer) CumSum(

func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) {
tb.Helper()

if err := stream.Send(&pingv1.CumSumRequest{}); err != nil {
assert.ErrorIs(tb, err, io.EOF)
assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown)
Expand Down
95 changes: 37 additions & 58 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ type duplexHTTPCall struct {
requestBodyReader *io.PipeReader
requestBodyWriter *io.PipeWriter

sendRequestOnce sync.Once
responseReady chan struct{}
request *http.Request
response *http.Response

errMu sync.Mutex
err error
requestSent bool
responseReady sync.WaitGroup
request *http.Request
response *http.Response
responseErr error
}

func newDuplexHTTPCall(
Expand Down Expand Up @@ -80,24 +78,23 @@ func newDuplexHTTPCall(
Body: pipeReader,
Host: url.Host,
}).WithContext(ctx)
return &duplexHTTPCall{
call := &duplexHTTPCall{
ctx: ctx,
httpClient: httpClient,
streamType: spec.StreamType,
requestBodyReader: pipeReader,
requestBodyWriter: pipeWriter,
request: request,
responseReady: make(chan struct{}),
}
call.responseReady.Add(1)
return call
}

// Write to the request body. Returns an error wrapping io.EOF after SetError
// is called.
// Write to the request body.
func (d *duplexHTTPCall) Write(data []byte) (int, error) {
d.ensureRequestMade()
// Before we send any data, check if the context has been canceled.
if err := d.ctx.Err(); err != nil {
d.SetError(err)
return 0, wrapIfContextError(err)
}
// It's safe to write to this side of the pipe while net/http concurrently
Expand Down Expand Up @@ -157,14 +154,12 @@ func (d *duplexHTTPCall) SetMethod(method string) {
func (d *duplexHTTPCall) Read(data []byte) (int, error) {
// First, we wait until we've gotten the response headers and established the
// server-to-client side of the stream.
d.BlockUntilResponseReady()
if err := d.getError(); err != nil {
if err := d.BlockUntilResponseReady(); err != nil {
// The stream is already closed or corrupted.
return 0, err
}
// Before we read, check if the context has been canceled.
if err := d.ctx.Err(); err != nil {
d.SetError(err)
return 0, wrapIfContextError(err)
}
if d.response == nil {
Expand All @@ -175,7 +170,7 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) {
}

func (d *duplexHTTPCall) CloseRead() error {
d.BlockUntilResponseReady()
d.responseReady.Wait()
if d.response == nil {
return nil
}
Expand All @@ -188,7 +183,9 @@ func (d *duplexHTTPCall) CloseRead() error {

// ResponseStatusCode is the response's HTTP status code.
func (d *duplexHTTPCall) ResponseStatusCode() (int, error) {
d.BlockUntilResponseReady()
if err := d.BlockUntilResponseReady(); err != nil {
return 0, err
}
if d.response == nil {
return 0, fmt.Errorf("nil response from %v", d.request.URL)
}
Expand All @@ -197,7 +194,7 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) {

// ResponseHeader returns the response HTTP headers.
func (d *duplexHTTPCall) ResponseHeader() http.Header {
d.BlockUntilResponseReady()
_ = d.BlockUntilResponseReady()
if d.response != nil {
return d.response.Header
}
Expand All @@ -206,56 +203,39 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header {

// ResponseTrailer returns the response HTTP trailers.
func (d *duplexHTTPCall) ResponseTrailer() http.Header {
d.BlockUntilResponseReady()
_ = d.BlockUntilResponseReady()
if d.response != nil {
return d.response.Trailer
}
return make(http.Header)
}

// SetError stores any error encountered processing the response. All
// subsequent calls to Read return this error, and all subsequent calls to
// Write return an error wrapping io.EOF. It's safe to call concurrently with
// any other method.
func (d *duplexHTTPCall) SetError(err error) {
d.errMu.Lock()
if d.err == nil {
d.err = wrapIfContextError(err)
}
// Closing the read side of the request body pipe acquires an internal lock,
// so we want to scope errMu's usage narrowly and avoid defer.
d.errMu.Unlock()

// We've already hit an error, so we should stop writing to the request body.
// It's safe to call Close more than once and/or concurrently (calls after
// the first are no-ops), so it's okay for us to call this even though
// net/http sometimes closes the reader too.
//
// It's safe to ignore the returned error here. Under the hood, Close calls
// CloseWithError, which is documented to always return nil.
_ = d.requestBodyReader.Close()
}

// SetValidateResponse sets the response validation function. The function runs
// in a background goroutine.
func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) {
d.validateResponse = validate
}

func (d *duplexHTTPCall) BlockUntilResponseReady() {
<-d.responseReady
func (d *duplexHTTPCall) BlockUntilResponseReady() error {
d.responseReady.Wait()
return d.responseErr
}

// ensureRequestMade sends the request headers and starts the response stream.
// It is not safe to call this concurrently. Write and CloseWrite call this but
// ensure that they're not called concurrently.
func (d *duplexHTTPCall) ensureRequestMade() {
d.sendRequestOnce.Do(func() {
go d.makeRequest()
})
if d.requestSent {
return // already sent
}
d.requestSent = true
go d.makeRequest()
}

func (d *duplexHTTPCall) makeRequest() {
// This runs concurrently with Write and CloseWrite. Read and CloseRead wait
// on d.responseReady, so we can't race with them.
defer close(d.responseReady)
defer d.responseReady.Done()

// Promote the header Host to the request object.
if host := d.request.Header.Get(headerHost); len(host) > 0 {
Expand All @@ -276,33 +256,32 @@ func (d *duplexHTTPCall) makeRequest() {
if _, ok := asError(err); !ok {
err = NewError(CodeUnavailable, err)
}
d.SetError(err)
d.responseErr = err
d.requestBodyReader.CloseWithError(io.EOF)
return
}
d.response = response
if err := d.validateResponse(response); err != nil {
d.SetError(err)
d.responseErr = err
d.response.Body.Close()
d.requestBodyReader.CloseWithError(io.EOF)
return
}
if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 {
// If we somehow dialed an HTTP/1.x server, fail with an explicit message
// rather than returning a more cryptic error later on.
d.SetError(errorf(
d.responseErr = errorf(
CodeUnimplemented,
"response from %v is HTTP/%d.%d: bidi streams require at least HTTP/2",
d.request.URL,
response.ProtoMajor,
response.ProtoMinor,
))
)
d.response.Body.Close()
d.requestBodyReader.CloseWithError(io.EOF)
}
}

func (d *duplexHTTPCall) getError() error {
d.errMu.Lock()
defer d.errMu.Unlock()
return d.err
}

// See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33
func cloneURL(oldURL *url.URL) *url.URL {
if oldURL == nil {
Expand Down
9 changes: 5 additions & 4 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,9 @@ func (cc *connectStreamingClientConn) CloseRequest() error {
}

func (cc *connectStreamingClientConn) Receive(msg any) error {
cc.duplexCall.BlockUntilResponseReady()
if err := cc.duplexCall.BlockUntilResponseReady(); err != nil {
return err
}
err := cc.unmarshaler.Unmarshal(msg)
if err == nil {
return nil
Expand All @@ -603,7 +605,6 @@ func (cc *connectStreamingClientConn) Receive(msg any) error {
// error.
serverErr.meta = cc.responseHeader.Clone()
mergeHeaders(serverErr.meta, cc.responseTrailer)
cc.duplexCall.SetError(serverErr)
return serverErr
}
// If the error is EOF but not from a last message, we want to return
Expand All @@ -614,8 +615,8 @@ func (cc *connectStreamingClientConn) Receive(msg any) error {
// There's no error in the trailers, so this was probably an error
// converting the bytes to a message, an error reading from the network, or
// just an EOF. We're going to return it to the user, but we also want to
// setResponseError so Send errors out.
cc.duplexCall.SetError(err)
// close the writer so Send errors out.
_ = cc.duplexCall.CloseWrite()
return err
}

Expand Down
13 changes: 7 additions & 6 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ func (cc *grpcClientConn) CloseRequest() error {
}

func (cc *grpcClientConn) Receive(msg any) error {
cc.duplexCall.BlockUntilResponseReady()
if err := cc.duplexCall.BlockUntilResponseReady(); err != nil {
return err
}
err := cc.unmarshaler.Unmarshal(msg)
if err == nil {
return nil
Expand Down Expand Up @@ -409,23 +411,22 @@ func (cc *grpcClientConn) Receive(msg any) error {
// the stream has ended, Receive must return an error.
serverErr.meta = cc.responseHeader.Clone()
mergeHeaders(serverErr.meta, cc.responseTrailer)
cc.duplexCall.SetError(serverErr)
return serverErr
}
// This was probably an error converting the bytes to a message or an error
// reading from the network. We're going to return it to the
// user, but we also want to setResponseError so Send errors out.
cc.duplexCall.SetError(err)
// user, but we also want to close writes so Send errors out.
_ = cc.duplexCall.CloseWrite()
return err
}

func (cc *grpcClientConn) ResponseHeader() http.Header {
cc.duplexCall.BlockUntilResponseReady()
_ = cc.duplexCall.BlockUntilResponseReady()
return cc.responseHeader
}

func (cc *grpcClientConn) ResponseTrailer() http.Header {
cc.duplexCall.BlockUntilResponseReady()
_ = cc.duplexCall.BlockUntilResponseReady()
return cc.responseTrailer
}

Expand Down

0 comments on commit 716794b

Please sign in to comment.