From 7db421e1369489c5d88299484fee708d17061f6c Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 22 Nov 2023 12:57:12 -0500 Subject: [PATCH 1/6] Wrap errors with context cancellation codes This PR wraps errors with the appropariate connect code of Cancelled or DeadlineExceeded if the context error is not nil. Improves error handling for some well known error cases that do not surface context.Cancelled errors. For example HTTP2 "client disconnect" string errors are now raised with a Cancelled code not an Unknwon. This lets handlers check the error code for better handling and reporting of errors. --- .golangci.yml | 6 ++ connect_ext_test.go | 128 +++++++++++++++++++++++++++++++++++++++ envelope.go | 19 +++--- envelope_test.go | 2 + error.go | 32 +++++++++- protocol_connect.go | 16 ++++- protocol_connect_test.go | 2 + protocol_grpc.go | 5 ++ 8 files changed, 194 insertions(+), 16 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 48828844..22707139 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -132,3 +132,9 @@ issues: # We want to show examples with http.Get - linters: [noctx] path: internal/memhttp/memhttp_test.go + # We need envelope readers and writers to have access to the context for error handling. + - linters: [containedctx] + path: envelope.go + # We need marshallers and unmarshallers to have access to the context for error handling. + - linters: [containedctx] + path: protocol_connect.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 2cb04a07..e3ca3106 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -19,12 +19,14 @@ import ( "compress/flate" "compress/gzip" "context" + "crypto/tls" "encoding/binary" "errors" "fmt" "io" "math" "math/rand" + "net" "net/http" "runtime" "strings" @@ -2276,6 +2278,132 @@ func TestStreamUnexpectedEOF(t *testing.T) { } } +// TestClientDisconnect tests that the handler receives a CodeCanceled error when +// the client abruptly disconnects. +func TestClientDisconnect(t *testing.T) { + t.Parallel() + captureTransportConn := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}, http2 bool) http.RoundTripper { + if http2 { + transport := server.Transport() + dialContext := transport.DialTLSContext + transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := dialContext(ctx, network, addr, cfg) + if err != nil { + close(onError) + return nil, err + } + *clientConn = conn // Capture the client connection. + return conn, nil + } + return transport + } + transport := server.TransportHTTP1() + dialContext := transport.DialContext + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := dialContext(ctx, network, addr) + if err != nil { + close(onError) + return nil, err + } + *clientConn = conn // Capture the client connection. + return conn, nil + } + return transport + } + testTransportClosure := func(t *testing.T, http2 bool) { //nolint:thelper + t.Run("handler_reads", func(t *testing.T) { + var ( + handlerReceiveErr error + handlerContextErr error + gotRequest = make(chan struct{}) + gotResponse = make(chan struct{}) + ) + pingServer := &pluggablePingServer{ + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + close(gotRequest) + for stream.Receive() { + // Do nothing + } + handlerReceiveErr = stream.Err() + handlerContextErr = ctx.Err() + close(gotResponse) + return connect.NewResponse(&pingv1.SumResponse{}), nil + }, + } + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := memhttptest.NewServer(t, mux) + var clientConn net.Conn + transport := captureTransportConn(server, &clientConn, gotRequest, http2) + serverClient := &http.Client{Transport: transport} + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + stream := client.Sum(context.Background()) + // Send header. + assert.Nil(t, stream.Send(nil)) + <-gotRequest + // Client abruptly disconnects. + if !assert.NotNil(t, clientConn) { + return + } + assert.Nil(t, clientConn.Close()) + _, err := stream.CloseAndReceive() + assert.NotNil(t, err) + <-gotResponse + assert.NotNil(t, handlerReceiveErr) + assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) + assert.ErrorIs(t, handlerContextErr, context.Canceled) + }) + t.Run("handler_writes", func(t *testing.T) { + var ( + handlerReceiveErr error + handlerContextErr error + gotRequest = make(chan struct{}) + gotResponse = make(chan struct{}) + ) + pingServer := &pluggablePingServer{ + countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + close(gotRequest) + var err error + for err == nil { + err = stream.Send(&pingv1.CountUpResponse{}) + } + handlerReceiveErr = err + handlerContextErr = ctx.Err() + close(gotResponse) + return nil + }, + } + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := memhttptest.NewServer(t, mux) + var clientConn net.Conn + transport := captureTransportConn(server, &clientConn, gotRequest, http2) + serverClient := &http.Client{Transport: transport} + client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) + stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) + if !assert.Nil(t, err) { + return + } + <-gotRequest + // Client abruptly disconnects. + if !assert.NotNil(t, clientConn) { + return + } + assert.Nil(t, clientConn.Close()) + for stream.Receive() { + // Do nothing + } + assert.NotNil(t, stream.Err()) + <-gotResponse + assert.NotNil(t, handlerReceiveErr) + assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) + assert.ErrorIs(t, handlerContextErr, context.Canceled) + }) + } + testTransportClosure(t, true) + testTransportClosure(t, false) +} + // TestBlankImportCodeGeneration tests that services.connect.go is generated with // blank import statements to services.pb.go so that the service's Descriptor is // available in the global proto registry. diff --git a/envelope.go b/envelope.go index c9f67b36..c976fb6d 100644 --- a/envelope.go +++ b/envelope.go @@ -16,6 +16,7 @@ package connect import ( "bytes" + "context" "encoding/binary" "errors" "io" @@ -117,6 +118,7 @@ func (e *envelope) Len() int { } type envelopeWriter struct { + ctx context.Context sender messageSender codec Codec compressMinBytes int @@ -209,6 +211,7 @@ func (w *envelopeWriter) marshal(message any) *Error { func (w *envelopeWriter) write(env *envelope) *Error { if _, err := w.sender.Send(env); err != nil { err = wrapIfContextError(err) + err = wrapWithContextError(w.ctx, err) if connectErr, ok := asError(err); ok { return connectErr } @@ -218,6 +221,7 @@ func (w *envelopeWriter) write(env *envelope) *Error { } type envelopeReader struct { + ctx context.Context reader io.Reader codec Codec last envelope @@ -305,17 +309,12 @@ func (r *envelopeReader) Read(env *envelope) *Error { return NewError(CodeUnknown, err) } err = wrapIfContextError(err) + err = wrapWithContextError(r.ctx, err) + err = wrapIfMaxBytesError(err, "read 5 byte message prefix") if connectErr, ok := asError(err); ok { return connectErr } // Something else has gone wrong - the stream didn't end cleanly. - if connectErr, ok := asError(err); ok { - return connectErr - } - if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil { - // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. - return maxBytesErr - } return errorf( CodeInvalidArgument, "protocol error: incomplete envelope: %w", err, @@ -333,10 +332,6 @@ func (r *envelopeReader) Read(env *envelope) *Error { // CopyN will return an error if it doesn't read the requested // number of bytes. if readN, err := io.CopyN(env.Data, r.reader, size); err != nil { - if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { - // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. - return maxBytesErr - } if errors.Is(err, io.EOF) { // We've gotten fewer bytes than we expected, so the stream has ended // unexpectedly. @@ -348,6 +343,8 @@ func (r *envelopeReader) Read(env *envelope) *Error { ) } err = wrapIfContextError(err) + err = wrapWithContextError(r.ctx, err) + err = wrapIfMaxBytesError(err, "read %d byte message", size) if connectErr, ok := asError(err); ok { return connectErr } diff --git a/envelope_test.go b/envelope_test.go index f153b8c3..8f4a96d7 100644 --- a/envelope_test.go +++ b/envelope_test.go @@ -16,6 +16,7 @@ package connect import ( "bytes" + "context" "io" "testing" @@ -44,6 +45,7 @@ func TestEnvelope(t *testing.T) { t.Parallel() env := &envelope{Data: &bytes.Buffer{}} rdr := envelopeReader{ + ctx: context.Background(), reader: byteByByteReader{ reader: bytes.NewReader(buf.Bytes()), }, diff --git a/error.go b/error.go index ea080e3d..bbca34aa 100644 --- a/error.go +++ b/error.go @@ -302,6 +302,26 @@ func wrapIfContextError(err error) error { return err } +// wrapWithContextError wraps errors with CodeCanceled or CodeDeadlineExceeded +// if the context is done. It leaves already-wrapped errors unchanged. +func wrapWithContextError(ctx context.Context, err error) error { + if err == nil { + return nil + } + if _, ok := asError(err); ok { + return err + } + ctxErr := ctx.Err() + switch { + case errors.Is(ctxErr, context.Canceled): + return NewError(CodeCanceled, err) + case errors.Is(ctxErr, context.DeadlineExceeded): + return NewError(CodeDeadlineExceeded, err) + default: + return err + } +} + // wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message // telling the caller that they likely need to use h2c but are using a raw http.Client{}. // @@ -408,10 +428,18 @@ func wrapIfRSTError(err error) error { } } -func asMaxBytesError(err error, tmpl string, args ...any) *Error { +// wrapIfMaxBytesError wraps errors returned reading from a http.MaxBytesHandler +// whose limit has been exceeded. +func wrapIfMaxBytesError(err error, tmpl string, args ...any) error { + if err == nil { + return nil + } + if _, ok := asError(err); ok { + return err + } var maxBytesErr *http.MaxBytesError if ok := errors.As(err, &maxBytesErr); !ok { - return nil + return err } prefix := fmt.Sprintf(tmpl, args...) return errorf(CodeResourceExhausted, "%s: exceeded %d byte http.MaxBytesReader limit", prefix, maxBytesErr.Limit) diff --git a/protocol_connect.go b/protocol_connect.go index 6f8b63cc..58cdaaab 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -147,6 +147,7 @@ func (h *connectHandler) NewConn( responseWriter http.ResponseWriter, request *http.Request, ) (handlerConnCloser, bool) { + ctx := request.Context() query := request.URL.Query() // We need to parse metadata before entering the interceptor stack; we'll // send the error to the client later on. @@ -254,6 +255,7 @@ func (h *connectHandler) NewConn( request: request, responseWriter: responseWriter, marshaler: connectUnaryMarshaler{ + ctx: ctx, sender: writeSender{writer: responseWriter}, codec: codec, compressMinBytes: h.CompressMinBytes, @@ -264,6 +266,7 @@ func (h *connectHandler) NewConn( sendMaxBytes: h.SendMaxBytes, }, unmarshaler: connectUnaryUnmarshaler{ + ctx: ctx, reader: requestBody, codec: codec, compressionPool: h.CompressionPools.Get(requestCompression), @@ -280,6 +283,7 @@ func (h *connectHandler) NewConn( responseWriter: responseWriter, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ + ctx: ctx, sender: writeSender{responseWriter}, codec: codec, compressMinBytes: h.CompressMinBytes, @@ -290,6 +294,7 @@ func (h *connectHandler) NewConn( }, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ + ctx: ctx, reader: requestBody, codec: codec, compressionPool: h.CompressionPools.Get(requestCompression), @@ -375,6 +380,7 @@ func (c *connectClient) NewConn( bufferPool: c.BufferPool, marshaler: connectUnaryRequestMarshaler{ connectUnaryMarshaler: connectUnaryMarshaler{ + ctx: ctx, sender: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, @@ -386,6 +392,7 @@ func (c *connectClient) NewConn( }, }, unmarshaler: connectUnaryUnmarshaler{ + ctx: ctx, reader: duplexCall, codec: c.Codec, bufferPool: c.BufferPool, @@ -415,6 +422,7 @@ func (c *connectClient) NewConn( codec: c.Codec, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ + ctx: ctx, sender: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, @@ -425,6 +433,7 @@ func (c *connectClient) NewConn( }, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ + ctx: ctx, reader: duplexCall, codec: c.Codec, bufferPool: c.BufferPool, @@ -892,6 +901,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error { } type connectUnaryMarshaler struct { + ctx context.Context sender messageSender codec Codec compressMinBytes int @@ -1057,6 +1067,7 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { } type connectUnaryUnmarshaler struct { + ctx context.Context reader io.Reader codec Codec compressionPool *compressionPool @@ -1084,12 +1095,11 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by bytesRead, err := data.ReadFrom(reader) if err != nil { err = wrapIfContextError(err) + err = wrapWithContextError(u.ctx, err) + err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead) if connectErr, ok := asError(err); ok { return connectErr } - if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil { - return readMaxBytesErr - } return errorf(CodeUnknown, "read message: %w", err) } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { diff --git a/protocol_connect_test.go b/protocol_connect_test.go index ae1bedea..92d86896 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -16,6 +16,7 @@ package connect import ( "bytes" + "context" "encoding/json" "net/http" "strings" @@ -83,6 +84,7 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { unmarshaler := connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ + ctx: context.Background(), reader: &buffer, bufferPool: bufferPool, }, diff --git a/protocol_grpc.go b/protocol_grpc.go index ce8ceb03..68fde10a 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -144,6 +144,7 @@ func (g *grpcHandler) NewConn( responseWriter http.ResponseWriter, request *http.Request, ) (handlerConnCloser, bool) { + ctx := request.Context() // We need to parse metadata before entering the interceptor stack; we'll // send the error to the client later on. requestCompression, responseCompression, failed := negotiateCompression( @@ -186,6 +187,7 @@ func (g *grpcHandler) NewConn( protobuf: g.Codecs.Protobuf(), // for errors marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ + ctx: ctx, sender: writeSender{writer: responseWriter}, compressionPool: g.CompressionPools.Get(responseCompression), codec: codec, @@ -200,6 +202,7 @@ func (g *grpcHandler) NewConn( request: request, unmarshaler: grpcUnmarshaler{ envelopeReader: envelopeReader{ + ctx: ctx, reader: request.Body, codec: codec, compressionPool: g.CompressionPools.Get(requestCompression), @@ -284,6 +287,7 @@ func (g *grpcClient) NewConn( protobuf: g.Protobuf, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ + ctx: ctx, sender: duplexCall, compressionPool: g.CompressionPools.Get(g.CompressionName), codec: g.Codec, @@ -294,6 +298,7 @@ func (g *grpcClient) NewConn( }, unmarshaler: grpcUnmarshaler{ envelopeReader: envelopeReader{ + ctx: ctx, reader: duplexCall, codec: g.Codec, bufferPool: g.BufferPool, From 59632af457679a4eff1f24bab69663ff0e179490 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 27 Dec 2023 13:28:45 +0100 Subject: [PATCH 2/6] Tidy up client disconnect test --- connect_ext_test.go | 47 ++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index e3ca3106..f1cef5fb 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2282,21 +2282,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { // the client abruptly disconnects. func TestClientDisconnect(t *testing.T) { t.Parallel() - captureTransportConn := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}, http2 bool) http.RoundTripper { - if http2 { - transport := server.Transport() - dialContext := transport.DialTLSContext - transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { - conn, err := dialContext(ctx, network, addr, cfg) - if err != nil { - close(onError) - return nil, err - } - *clientConn = conn // Capture the client connection. - return conn, nil - } - return transport - } + type httpRoundTripFunc func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper + http1RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper { transport := server.TransportHTTP1() dialContext := transport.DialContext transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -2310,7 +2297,21 @@ func TestClientDisconnect(t *testing.T) { } return transport } - testTransportClosure := func(t *testing.T, http2 bool) { //nolint:thelper + http2RoundTripper := func(server *memhttp.Server, clientConn *net.Conn, onError chan struct{}) http.RoundTripper { + transport := server.Transport() + dialContext := transport.DialTLSContext + transport.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := dialContext(ctx, network, addr, cfg) + if err != nil { + close(onError) + return nil, err + } + *clientConn = conn // Capture the client connection. + return conn, nil + } + return transport + } + testTransportClosure := func(t *testing.T, captureTransport httpRoundTripFunc) { //nolint:thelper t.Run("handler_reads", func(t *testing.T) { var ( handlerReceiveErr error @@ -2334,7 +2335,7 @@ func TestClientDisconnect(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn - transport := captureTransportConn(server, &clientConn, gotRequest, http2) + transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream := client.Sum(context.Background()) @@ -2377,7 +2378,7 @@ func TestClientDisconnect(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := memhttptest.NewServer(t, mux) var clientConn net.Conn - transport := captureTransportConn(server, &clientConn, gotRequest, http2) + transport := captureTransport(server, &clientConn, gotRequest) serverClient := &http.Client{Transport: transport} client := pingv1connect.NewPingServiceClient(serverClient, server.URL()) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) @@ -2400,8 +2401,14 @@ func TestClientDisconnect(t *testing.T) { assert.ErrorIs(t, handlerContextErr, context.Canceled) }) } - testTransportClosure(t, true) - testTransportClosure(t, false) + t.Run("http1", func(t *testing.T) { + t.Parallel() + testTransportClosure(t, http1RoundTripper) + }) + t.Run("http2", func(t *testing.T) { + t.Parallel() + testTransportClosure(t, http2RoundTripper) + }) } // TestBlankImportCodeGeneration tests that services.connect.go is generated with From 1113d81294ab8ba3d22c88574e3cfea3beda1d7c Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 6 Feb 2024 13:26:07 -0500 Subject: [PATCH 3/6] Feedback lint and ordering --- .golangci.yml | 6 ------ envelope.go | 8 ++++---- error.go | 8 +++----- protocol_connect.go | 6 +++--- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 22707139..48828844 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -132,9 +132,3 @@ issues: # We want to show examples with http.Get - linters: [noctx] path: internal/memhttp/memhttp_test.go - # We need envelope readers and writers to have access to the context for error handling. - - linters: [containedctx] - path: envelope.go - # We need marshallers and unmarshallers to have access to the context for error handling. - - linters: [containedctx] - path: protocol_connect.go diff --git a/envelope.go b/envelope.go index c976fb6d..8cb8b796 100644 --- a/envelope.go +++ b/envelope.go @@ -118,7 +118,7 @@ func (e *envelope) Len() int { } type envelopeWriter struct { - ctx context.Context + ctx context.Context //nolint:containedctx sender messageSender codec Codec compressMinBytes int @@ -221,7 +221,7 @@ func (w *envelopeWriter) write(env *envelope) *Error { } type envelopeReader struct { - ctx context.Context + ctx context.Context //nolint:containedctx reader io.Reader codec Codec last envelope @@ -308,9 +308,9 @@ func (r *envelopeReader) Read(env *envelope) *Error { // add any alarming text about protocol errors, though. return NewError(CodeUnknown, err) } + err = wrapIfMaxBytesError(err, "read 5 byte message prefix") err = wrapIfContextError(err) err = wrapWithContextError(r.ctx, err) - err = wrapIfMaxBytesError(err, "read 5 byte message prefix") if connectErr, ok := asError(err); ok { return connectErr } @@ -342,9 +342,9 @@ func (r *envelopeReader) Read(env *envelope) *Error { readN, ) } + err = wrapIfMaxBytesError(err, "read %d byte message", size) err = wrapIfContextError(err) err = wrapWithContextError(r.ctx, err) - err = wrapIfMaxBytesError(err, "read %d byte message", size) if connectErr, ok := asError(err); ok { return connectErr } diff --git a/error.go b/error.go index bbca34aa..7d5a9f1c 100644 --- a/error.go +++ b/error.go @@ -312,14 +312,12 @@ func wrapWithContextError(ctx context.Context, err error) error { return err } ctxErr := ctx.Err() - switch { - case errors.Is(ctxErr, context.Canceled): + if errors.Is(ctxErr, context.Canceled) { return NewError(CodeCanceled, err) - case errors.Is(ctxErr, context.DeadlineExceeded): + } else if errors.Is(ctxErr, context.DeadlineExceeded) { return NewError(CodeDeadlineExceeded, err) - default: - return err } + return err } // wrapIfLikelyH2CNotConfiguredError adds a wrapping error that has a message diff --git a/protocol_connect.go b/protocol_connect.go index 58cdaaab..22b67273 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -901,7 +901,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error { } type connectUnaryMarshaler struct { - ctx context.Context + ctx context.Context //nolint:containedctx sender messageSender codec Codec compressMinBytes int @@ -1067,7 +1067,7 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { } type connectUnaryUnmarshaler struct { - ctx context.Context + ctx context.Context //nolint:containedctx reader io.Reader codec Codec compressionPool *compressionPool @@ -1094,9 +1094,9 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by // ReadFrom ignores io.EOF, so any error here is real. bytesRead, err := data.ReadFrom(reader) if err != nil { + err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead) err = wrapIfContextError(err) err = wrapWithContextError(u.ctx, err) - err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead) if connectErr, ok := asError(err); ok { return connectErr } From 83d97eb50a554eed95e13070be78186909ab47eb Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 15 Feb 2024 16:05:00 -0500 Subject: [PATCH 4/6] Fix bad merge --- connect_ext_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index bd34cf03..8136db9c 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2309,7 +2309,6 @@ func TestStreamUnexpectedEOF(t *testing.T) { } } -<<<<<<< HEAD // TestClientDisconnect tests that the handler receives a CodeCanceled error when // the client abruptly disconnects. func TestClientDisconnect(t *testing.T) { From f925b2e8cddbf94ecb7677483e794fb8cebb60d7 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 15 Feb 2024 17:11:22 -0500 Subject: [PATCH 5/6] Change name to wrapIfContextDone --- envelope.go | 6 +++--- error.go | 4 ++-- protocol_connect.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/envelope.go b/envelope.go index 0cab39c8..481b89ba 100644 --- a/envelope.go +++ b/envelope.go @@ -211,7 +211,7 @@ func (w *envelopeWriter) marshal(message any) *Error { func (w *envelopeWriter) write(env *envelope) *Error { if _, err := w.sender.Send(env); err != nil { err = wrapIfContextError(err) - err = wrapWithContextError(w.ctx, err) + err = wrapIfContextDone(w.ctx, err) if connectErr, ok := asError(err); ok { return connectErr } @@ -318,7 +318,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { } err = wrapIfMaxBytesError(err, "read 5 byte message prefix") err = wrapIfContextError(err) - err = wrapWithContextError(r.ctx, err) + err = wrapIfContextDone(r.ctx, err) if connectErr, ok := asError(err); ok { return connectErr } @@ -352,7 +352,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { } err = wrapIfMaxBytesError(err, "read %d byte message", size) err = wrapIfContextError(err) - err = wrapWithContextError(r.ctx, err) + err = wrapIfContextDone(r.ctx, err) if connectErr, ok := asError(err); ok { return connectErr } diff --git a/error.go b/error.go index e2099599..5aa141a8 100644 --- a/error.go +++ b/error.go @@ -308,9 +308,9 @@ func wrapIfContextError(err error) error { return err } -// wrapWithContextError wraps errors with CodeCanceled or CodeDeadlineExceeded +// wrapIfContextDone wraps errors with CodeCanceled or CodeDeadlineExceeded // if the context is done. It leaves already-wrapped errors unchanged. -func wrapWithContextError(ctx context.Context, err error) error { +func wrapIfContextDone(ctx context.Context, err error) error { if err == nil { return nil } diff --git a/protocol_connect.go b/protocol_connect.go index 72d39c4e..bb3f670a 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1116,7 +1116,7 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by if err != nil { err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead) err = wrapIfContextError(err) - err = wrapWithContextError(u.ctx, err) + err = wrapIfContextDone(u.ctx, err) if connectErr, ok := asError(err); ok { return connectErr } From c3324072f50f3882a98e4ee846ee7511d3911efa Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Thu, 15 Feb 2024 17:25:03 -0500 Subject: [PATCH 6/6] Call wrapIfContextError from wrapIfContextDone --- envelope.go | 3 --- error.go | 1 + protocol_connect.go | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/envelope.go b/envelope.go index 481b89ba..9bc15555 100644 --- a/envelope.go +++ b/envelope.go @@ -210,7 +210,6 @@ func (w *envelopeWriter) marshal(message any) *Error { func (w *envelopeWriter) write(env *envelope) *Error { if _, err := w.sender.Send(env); err != nil { - err = wrapIfContextError(err) err = wrapIfContextDone(w.ctx, err) if connectErr, ok := asError(err); ok { return connectErr @@ -317,7 +316,6 @@ func (r *envelopeReader) Read(env *envelope) *Error { return NewError(CodeUnknown, err) } err = wrapIfMaxBytesError(err, "read 5 byte message prefix") - err = wrapIfContextError(err) err = wrapIfContextDone(r.ctx, err) if connectErr, ok := asError(err); ok { return connectErr @@ -351,7 +349,6 @@ func (r *envelopeReader) Read(env *envelope) *Error { ) } err = wrapIfMaxBytesError(err, "read %d byte message", size) - err = wrapIfContextError(err) err = wrapIfContextDone(r.ctx, err) if connectErr, ok := asError(err); ok { return connectErr diff --git a/error.go b/error.go index 5aa141a8..26544d95 100644 --- a/error.go +++ b/error.go @@ -314,6 +314,7 @@ func wrapIfContextDone(ctx context.Context, err error) error { if err == nil { return nil } + err = wrapIfContextError(err) if _, ok := asError(err); ok { return err } diff --git a/protocol_connect.go b/protocol_connect.go index bb3f670a..1726aed4 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1115,7 +1115,6 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by bytesRead, err := data.ReadFrom(reader) if err != nil { err = wrapIfMaxBytesError(err, "read first %d bytes of message", bytesRead) - err = wrapIfContextError(err) err = wrapIfContextDone(u.ctx, err) if connectErr, ok := asError(err); ok { return connectErr