Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap errors with context cancellation codes #659

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (
"compress/flate"
"compress/gzip"
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"net/http"
"runtime"
"strings"
Expand Down Expand Up @@ -2307,6 +2309,139 @@ func TestStreamUnexpectedEOF(t *testing.T) {
}
}

// TestClientDisconnect tests that the handler receives a CodeCanceled error when
// the client abruptly disconnects.
func TestClientDisconnect(t *testing.T) {
t.Parallel()
type httpRoundTripFunc func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper
http1RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper {
transport := server.TransportHTTP1()
dialContext := transport.DialContext
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
http2RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper {
transport := server.Transport()
dialContext := transport.DialTLSContext
transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := dialContext(ctx, network, addr, cfg)
if err != nil {
close(onError)
return nil, err
}
*clientConn = conn // Capture the client connection.
return conn, nil
}
return transport
}
testTransportClosure := func(t *testing.T, captureTransport httpRoundTripFunc) { //nolint:thelper
t.Run("handler_reads", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
close(gotRequest)
for stream.Receive() {
// Do nothing
}
handlerReceiveErr = stream.Err()
handlerContextErr = ctx.Err()
close(gotResponse)
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransport(server, &clientConn, gotRequest)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream := client.Sum(context.Background())
// Send header.
assert.Nil(t, stream.Send(nil))
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
_, err := stream.CloseAndReceive()
assert.NotNil(t, err)
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
t.Run("handler_writes", func(t *testing.T) {
var (
handlerReceiveErr error
handlerContextErr error
gotRequest = make(chan struct{})
gotResponse = make(chan struct{})
)
pingServer := &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
close(gotRequest)
var err error
for err == nil {
err = stream.Send(&pingv1.CountUpResponse{})
}
handlerReceiveErr = err
handlerContextErr = ctx.Err()
close(gotResponse)
return nil
},
}
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
var clientConn net.Conn
transport := captureTransport(server, &clientConn, gotRequest)
serverClient := &http.Client{Transport: transport}
client := pingv1connect.NewPingServiceClient(serverClient, server.URL())
stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{}))
if !assert.Nil(t, err) {
return
}
<-gotRequest
// Client abruptly disconnects.
if !assert.NotNil(t, clientConn) {
return
}
assert.Nil(t, clientConn.Close())
for stream.Receive() {
// Do nothing
}
assert.NotNil(t, stream.Err())
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled)
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
}
t.Run("http1", func(t *testing.T) {
t.Parallel()
testTransportClosure(t, http1RoundTripper)
})
t.Run("http2", func(t *testing.T) {
t.Parallel()
testTransportClosure(t, http2RoundTripper)
})
}

func TestTrailersOnlyErrors(t *testing.T) {
t.Parallel()

Expand Down
21 changes: 9 additions & 12 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
Expand Down Expand Up @@ -117,6 +118,7 @@ func (e *envelope) Len() int {
}

type envelopeWriter struct {
ctx context.Context //nolint:containedctx
sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -208,7 +210,7 @@ func (w *envelopeWriter) marshal(message any) *Error {

func (w *envelopeWriter) write(env *envelope) *Error {
if _, err := w.sender.Send(env); err != nil {
err = wrapIfContextError(err)
err = wrapIfContextDone(w.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand All @@ -218,6 +220,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
}

type envelopeReader struct {
ctx context.Context //nolint:containedctx
reader io.Reader
codec Codec
last envelope
Expand Down Expand Up @@ -312,15 +315,12 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// add any alarming text about protocol errors, though.
return NewError(CodeUnknown, err)
}
err = wrapIfContextError(err)
// Something else has gone wrong - the stream didn't end cleanly.
err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
err = wrapIfContextDone(r.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
// Something else has gone wrong - the stream didn't end cleanly.
return errorf(
CodeInvalidArgument,
"protocol error: incomplete envelope: %w", err,
Expand All @@ -338,10 +338,6 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
}
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand All @@ -352,7 +348,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
readN,
)
}
err = wrapIfContextError(err)
err = wrapIfMaxBytesError(err, "read %d byte message", size)
err = wrapIfContextDone(r.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down
2 changes: 2 additions & 0 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package connect

import (
"bytes"
"context"
"io"
"testing"

Expand Down Expand Up @@ -44,6 +45,7 @@ func TestEnvelope(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
ctx: context.Background(),
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
Expand Down
31 changes: 29 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,25 @@ func wrapIfContextError(err error) error {
return err
}

// wrapIfContextDone wraps errors with CodeCanceled or CodeDeadlineExceeded
// if the context is done. It leaves already-wrapped errors unchanged.
func wrapIfContextDone(ctx context.Context, err error) error {
if err == nil {
return nil
}
err = wrapIfContextError(err)
if _, ok := asError(err); ok {
return err
}
ctxErr := ctx.Err()
if errors.Is(ctxErr, context.Canceled) {
return NewError(CodeCanceled, err)
} else if errors.Is(ctxErr, context.DeadlineExceeded) {
return NewError(CodeDeadlineExceeded, err)
}
return err
}

// wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message
// telling the caller that they likely need to use h2c but are using a raw http.Client{}.
//
Expand Down Expand Up @@ -414,10 +433,18 @@ func wrapIfRSTError(err error) error {
}
}

func asMaxBytesError(err error, tmpl string, args ...any) *Error {
// wrapIfMaxBytesError wraps errors returned reading from a http.MaxBytesHandler
// whose limit has been exceeded.
func wrapIfMaxBytesError(err error, tmpl string, args ...any) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converted wrapIfMaxBytesError to the same style as other error handling for consistency.

if err == nil {
return nil
}
if _, ok := asError(err); ok {
return err
}
var maxBytesErr *http.MaxBytesError
if ok := errors.As(err, &maxBytesErr); !ok {
return nil
return err
}
prefix := fmt.Sprintf(tmpl, args...)
return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit)
Expand Down
17 changes: 13 additions & 4 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func (h *connectHandler) NewConn(
responseWriter http.ResponseWriter,
request *http.Request,
) (handlerConnCloser, bool) {
ctx := request.Context()
query := request.URL.Query()
// We need to parse metadata before entering the interceptor stack; we'll
// send the error to the client later on.
Expand Down Expand Up @@ -255,6 +256,7 @@ func (h *connectHandler) NewConn(
request: request,
responseWriter: responseWriter,
marshaler: connectUnaryMarshaler{
ctx: ctx,
sender: writeSender{writer: responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -265,6 +267,7 @@ func (h *connectHandler) NewConn(
sendMaxBytes: h.SendMaxBytes,
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand All @@ -281,6 +284,7 @@ func (h *connectHandler) NewConn(
responseWriter: responseWriter,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: writeSender{responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
Expand All @@ -291,6 +295,7 @@ func (h *connectHandler) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: requestBody,
codec: codec,
compressionPool: h.CompressionPools.Get(requestCompression),
Expand Down Expand Up @@ -376,6 +381,7 @@ func (c *connectClient) NewConn(
bufferPool: c.BufferPool,
marshaler: connectUnaryRequestMarshaler{
connectUnaryMarshaler: connectUnaryMarshaler{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -387,6 +393,7 @@ func (c *connectClient) NewConn(
},
},
unmarshaler: connectUnaryUnmarshaler{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -416,6 +423,7 @@ func (c *connectClient) NewConn(
codec: c.Codec,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
ctx: ctx,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
Expand All @@ -426,6 +434,7 @@ func (c *connectClient) NewConn(
},
unmarshaler: connectStreamingUnmarshaler{
envelopeReader: envelopeReader{
ctx: ctx,
reader: duplexCall,
codec: c.Codec,
bufferPool: c.BufferPool,
Expand Down Expand Up @@ -912,6 +921,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error {
}

type connectUnaryMarshaler struct {
ctx context.Context //nolint:containedctx
sender messageSender
codec Codec
compressMinBytes int
Expand Down Expand Up @@ -1077,6 +1087,7 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error {
}

type connectUnaryUnmarshaler struct {
ctx context.Context //nolint:containedctx
reader io.Reader
codec Codec
compressionPool *compressionPool
Expand All @@ -1103,13 +1114,11 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
// ReadFrom ignores io.EOF, so any error here is real.
bytesRead, err := data.ReadFrom(reader)
if err != nil {
err = wrapIfContextError(err)
err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead)
err = wrapIfContextDone(u.ctx, err)
if connectErr, ok := asError(err); ok {
return connectErr
}
if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil {
return readMaxBytesErr
}
return errorf(CodeUnknown, "read message: %w", err)
}
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
Expand Down
Loading
Loading