From 51e2a9177992b9ce8b37e5103f38edd06c9be940 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 23 Feb 2024 17:44:50 +0000 Subject: [PATCH 1/5] Fix ErrorWriter to be codec agnostic This PR changes the ErrorWriter to be more lenient with classifying protocols. Errors codecs are agnostic to the codec used. Therefore we avoid checking the codec in classifying the request. IsSupported will return true for an unknown codec which allows the server to encode a better error message to the client. If not supported a 415 error response could be used to match gRPC server like handling. If not supported and trying to write an error the ErrorWriter will default to connects unary encoding. --- connect_ext_test.go | 6 ++- error_writer.go | 91 ++++++++++++++++++-------------------------- error_writer_test.go | 62 +++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 58 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 0ada5bbe..5714a090 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2423,7 +2423,9 @@ func TestClientDisconnect(t *testing.T) { assert.NotNil(t, err) <-gotResponse assert.NotNil(t, handlerReceiveErr) - assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) + if !assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) { + t.Logf("handlerReceiveErr: %v", handlerReceiveErr) + } assert.ErrorIs(t, handlerContextErr, context.Canceled) }) t.Run("handler_writes", func(t *testing.T) { @@ -2434,7 +2436,7 @@ func TestClientDisconnect(t *testing.T) { gotResponse = make(chan struct{}) ) pingServer := &pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, _ *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { close(gotRequest) var err error for err == nil { diff --git a/error_writer.go b/error_writer.go index 629918ea..6ffeb0b6 100644 --- a/error_writer.go +++ b/error_writer.go @@ -41,10 +41,8 @@ const ( type ErrorWriter struct { bufferPool *bufferPool protobuf Codec - grpcContentTypes map[string]struct{} - grpcWebContentTypes map[string]struct{} - unaryConnectContentTypes map[string]struct{} - streamingConnectContentTypes map[string]struct{} + handleGRPC bool + handleGRPCWeb bool requireConnectProtocolHeader bool } @@ -54,71 +52,56 @@ type ErrorWriter struct { // Options supplied via [WithConditionalHandlerOptions] are ignored. func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { config := newHandlerConfig("", StreamTypeUnary, opts) - writer := &ErrorWriter{ + codecs := newReadOnlyCodecs(config.Codecs) + return &ErrorWriter{ bufferPool: config.BufferPool, - protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(), - grpcContentTypes: make(map[string]struct{}), - grpcWebContentTypes: make(map[string]struct{}), - unaryConnectContentTypes: make(map[string]struct{}), - streamingConnectContentTypes: make(map[string]struct{}), + protobuf: codecs.Protobuf(), + handleGRPC: config.HandleGRPC, + handleGRPCWeb: config.HandleGRPCWeb, requireConnectProtocolHeader: config.RequireConnectProtocolHeader, } - for name := range config.Codecs { - unary := connectContentTypeFromCodecName(StreamTypeUnary, name) - writer.unaryConnectContentTypes[unary] = struct{}{} - streaming := connectContentTypeFromCodecName(StreamTypeBidi, name) - writer.streamingConnectContentTypes[streaming] = struct{}{} - } - if config.HandleGRPC { - writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{} - for name := range config.Codecs { - ct := grpcContentTypeFromCodecName(false /* web */, name) - writer.grpcContentTypes[ct] = struct{}{} - } - } - if config.HandleGRPCWeb { - writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{} - for name := range config.Codecs { - ct := grpcContentTypeFromCodecName(true /* web */, name) - writer.grpcWebContentTypes[ct] = struct{}{} - } - } - return writer } func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - if _, ok := w.unaryConnectContentTypes[ctype]; ok { - if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - return unknownProtocol + method := request.Method + switch { + case w.handleGRPC && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): + if method != http.MethodPost { + break + } + return grpcProtocol + case w.handleGRPCWeb && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): + if method != http.MethodPost { + break + } + return grpcWebProtocol + case strings.HasPrefix(ctype, connectStreamingContentTypePrefix): + if method != http.MethodPost { + break } - return connectUnaryProtocol - } - if _, ok := w.streamingConnectContentTypes[ctype]; ok { // Streaming ignores the requireConnectProtocolHeader option as the // Content-Type is enough to determine the protocol. if err := connectCheckProtocolVersion(request, false /* required */); err != nil { - return unknownProtocol + break } return connectStreamProtocol - } - if _, ok := w.grpcContentTypes[ctype]; ok { - return grpcProtocol - } - if _, ok := w.grpcWebContentTypes[ctype]; ok { - return grpcWebProtocol - } - // Check for Connect-Protocol-Version header or connect protocol query - // parameter to support connect GET requests. - if request.Method == http.MethodGet { - connectVersion := getHeaderCanonical(request.Header, connectProtocolVersion) - if connectVersion == connectProtocolVersion { - return connectUnaryProtocol + case strings.HasPrefix(ctype, connectUnaryContentTypePrefix): + if method != http.MethodPost { + break } - connectVersion = request.URL.Query().Get(connectUnaryConnectQueryParameter) - if connectVersion == connectUnaryConnectQueryValue { - return connectUnaryProtocol + if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { + break } + return connectUnaryProtocol + case ctype == "": + if method != http.MethodGet { + break + } + if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { + break + } + return connectUnaryProtocol } return unknownProtocol } diff --git a/error_writer_test.go b/error_writer_test.go index 0b3be022..913b5669 100644 --- a/error_writer_test.go +++ b/error_writer_test.go @@ -24,11 +24,9 @@ import ( func TestErrorWriter(t *testing.T) { t.Parallel() - t.Run("RequireConnectProtocolHeader", func(t *testing.T) { t.Parallel() writer := NewErrorWriter(WithRequireConnectProtocolHeader()) - t.Run("Unary", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON) @@ -52,4 +50,64 @@ func TestErrorWriter(t *testing.T) { assert.True(t, writer.IsSupported(req)) }) }) + t.Run("Protocols", func(t *testing.T) { + t.Parallel() + writer := NewErrorWriter() // All supported by default + t.Run("ConnectUnary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectUnaryContentTypePrefix+codecNameJSON) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectUnaryGET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectStream", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectStreamingContentTypePrefix+codecNameJSON) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPC", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcContentTypeDefault) + assert.True(t, writer.IsSupported(req)) + req.Header.Set("Content-Type", grpcContentTypePrefix+"json") + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPCWeb", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcWebContentTypeDefault) + assert.True(t, writer.IsSupported(req)) + req.Header.Set("Content-Type", grpcWebContentTypePrefix+"json") + assert.True(t, writer.IsSupported(req)) + }) + }) + t.Run("UnknownCodec", func(t *testing.T) { + // An Unknown codec should return supported as the protocol is known and + // the error codec is agnostic to the codec used. The server can respond + // with a protocol error for the unknown codec. + t.Parallel() + writer := NewErrorWriter() + unknownCodec := "invalid" + t.Run("ConnectUnary", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectUnaryContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("ConnectStream", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", connectStreamingContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPC", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + t.Run("GRPCWeb", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + req.Header.Set("Content-Type", grpcWebContentTypePrefix+unknownCodec) + assert.True(t, writer.IsSupported(req)) + }) + }) } From f2216ade18a0ac469ec25a3e453e33196d07a47a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 5 Mar 2024 15:28:11 -0500 Subject: [PATCH 2/5] Feedback --- connect_ext_test.go | 4 +--- error_writer.go | 45 +++++++++++++++++---------------------------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 5714a090..257f18db 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2423,9 +2423,7 @@ func TestClientDisconnect(t *testing.T) { assert.NotNil(t, err) <-gotResponse assert.NotNil(t, handlerReceiveErr) - if !assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) { - t.Logf("handlerReceiveErr: %v", handlerReceiveErr) - } + assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled, assert.Sprintf("got %v", handlerReceiveErr)) assert.ErrorIs(t, handlerContextErr, context.Canceled) }) t.Run("handler_writes", func(t *testing.T) { diff --git a/error_writer.go b/error_writer.go index 6ffeb0b6..88444dcc 100644 --- a/error_writer.go +++ b/error_writer.go @@ -46,9 +46,11 @@ type ErrorWriter struct { requireConnectProtocolHeader bool } -// NewErrorWriter constructs an ErrorWriter. To properly recognize supported -// RPC Content-Types in net/http middleware, you must pass the same -// HandlerOptions to NewErrorWriter and any wrapped Connect handlers. +// NewErrorWriter constructs an ErrorWriter. Handler options may be passed to +// configure the error writer behaviour to match the handlers. +// [WithRequiredConnectProtocolHeader] will assert that Connect protocol +// requests include the version header allowing the error writer to correctly +// classify the request. // Options supplied via [WithConditionalHandlerOptions] are ignored. func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { config := newHandlerConfig("", StreamTypeUnary, opts) @@ -64,46 +66,33 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - method := request.Method + isPost := request.Method == http.MethodPost + isGet := request.Method == http.MethodGet switch { - case w.handleGRPC && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): - if method != http.MethodPost { - break - } + case w.handleGRPC && isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): return grpcProtocol - case w.handleGRPCWeb && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): - if method != http.MethodPost { - break - } + case w.handleGRPCWeb && isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): return grpcWebProtocol - case strings.HasPrefix(ctype, connectStreamingContentTypePrefix): - if method != http.MethodPost { - break - } + case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix): // Streaming ignores the requireConnectProtocolHeader option as the // Content-Type is enough to determine the protocol. if err := connectCheckProtocolVersion(request, false /* required */); err != nil { - break + return unknownProtocol } return connectStreamProtocol - case strings.HasPrefix(ctype, connectUnaryContentTypePrefix): - if method != http.MethodPost { - break - } + case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - break + return unknownProtocol } return connectUnaryProtocol - case ctype == "": - if method != http.MethodGet { - break - } + case isGet && ctype == "": if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { - break + return unknownProtocol } return connectUnaryProtocol + default: + return unknownProtocol } - return unknownProtocol } // IsSupported checks whether a request is using one of the ErrorWriter's From 36448d9b8342f186fb59248f3f65e38a74c193d3 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 6 Mar 2024 11:51:23 -0500 Subject: [PATCH 3/5] Drop contentType check on GET and fix fallthrough --- error_writer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/error_writer.go b/error_writer.go index 88444dcc..078e1add 100644 --- a/error_writer.go +++ b/error_writer.go @@ -80,12 +80,12 @@ func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { return unknownProtocol } return connectStreamProtocol - case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): + case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix) && !strings.HasPrefix(ctype, grpcContentTypeDefault): if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { return unknownProtocol } return connectUnaryProtocol - case isGet && ctype == "": + case isGet: if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { return unknownProtocol } From d0dde2d41c92cd539aa7f4a04c7da0a9cb0ca86b Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 11 Mar 2024 15:11:51 -0400 Subject: [PATCH 4/5] Drop stubbed config for protocol selection --- error_writer.go | 8 +++----- handler.go | 14 ++++---------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/error_writer.go b/error_writer.go index 078e1add..39f7397f 100644 --- a/error_writer.go +++ b/error_writer.go @@ -58,8 +58,6 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { return &ErrorWriter{ bufferPool: config.BufferPool, protobuf: codecs.Protobuf(), - handleGRPC: config.HandleGRPC, - handleGRPCWeb: config.HandleGRPCWeb, requireConnectProtocolHeader: config.RequireConnectProtocolHeader, } } @@ -69,9 +67,9 @@ func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { isPost := request.Method == http.MethodPost isGet := request.Method == http.MethodGet switch { - case w.handleGRPC && isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): + case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)): return grpcProtocol - case w.handleGRPCWeb && isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): + case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)): return grpcWebProtocol case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix): // Streaming ignores the requireConnectProtocolHeader option as the @@ -80,7 +78,7 @@ func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { return unknownProtocol } return connectStreamProtocol - case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix) && !strings.HasPrefix(ctype, grpcContentTypeDefault): + case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix): if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil { return unknownProtocol } diff --git a/handler.go b/handler.go index 88e80360..77724bdf 100644 --- a/handler.go +++ b/handler.go @@ -274,8 +274,6 @@ type handlerConfig struct { Procedure string Schema any Initializer maybeInitializer - HandleGRPC bool - HandleGRPCWeb bool RequireConnectProtocolHeader bool IdempotencyLevel IdempotencyLevel BufferPool *bufferPool @@ -290,8 +288,6 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler Procedure: protoPath, CompressionPools: make(map[string]*compressionPool), Codecs: make(map[string]Codec), - HandleGRPC: true, - HandleGRPCWeb: true, BufferPool: newBufferPool(), StreamType: streamType, } @@ -314,12 +310,10 @@ func (c *handlerConfig) newSpec() Spec { } func (c *handlerConfig) newProtocolHandlers() []protocolHandler { - protocols := []protocol{&protocolConnect{}} - if c.HandleGRPC { - protocols = append(protocols, &protocolGRPC{web: false}) - } - if c.HandleGRPCWeb { - protocols = append(protocols, &protocolGRPC{web: true}) + protocols := []protocol{ + &protocolConnect{}, + &protocolGRPC{web: false}, + &protocolGRPC{web: true}, } handlers := make([]protocolHandler, 0, len(protocols)) codecs := newReadOnlyCodecs(c.Codecs) From e1b71669cde16f520dadb8a03984374ac253d55e Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 11 Mar 2024 15:13:17 -0400 Subject: [PATCH 5/5] Fix drop --- error_writer.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/error_writer.go b/error_writer.go index 39f7397f..466c3b8e 100644 --- a/error_writer.go +++ b/error_writer.go @@ -41,8 +41,6 @@ const ( type ErrorWriter struct { bufferPool *bufferPool protobuf Codec - handleGRPC bool - handleGRPCWeb bool requireConnectProtocolHeader bool }