From d7334c477d1ea670fa8d5fa12f06a2bfe4f41d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sercan=20De=C4=9Firmenci?= Date: Wed, 21 Feb 2024 02:12:22 +0300 Subject: [PATCH] fix enabling compression by trimming whitespaces in accept encoding header (#6952) --- internal/transport/transport.go | 9 +++++++-- server.go | 6 +++--- test/compressor_test.go | 7 +++++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index b7b8fec18046..d3796c256e2f 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -28,6 +28,7 @@ import ( "fmt" "io" "net" + "strings" "sync" "sync/atomic" "time" @@ -362,8 +363,12 @@ func (s *Stream) SendCompress() string { // ClientAdvertisedCompressors returns the compressor names advertised by the // client via grpc-accept-encoding header. -func (s *Stream) ClientAdvertisedCompressors() string { - return s.clientAdvertisedCompressors +func (s *Stream) ClientAdvertisedCompressors() []string { + values := strings.Split(s.clientAdvertisedCompressors, ",") + for i, v := range values { + values[i] = strings.TrimSpace(v) + } + return values } // Done returns a channel which is closed when it receives the final status diff --git a/server.go b/server.go index 155a512bc3e7..a6a11704b34d 100644 --- a/server.go +++ b/server.go @@ -2120,7 +2120,7 @@ func ClientSupportedCompressors(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) } - return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil + return stream.ClientAdvertisedCompressors(), nil } // SetTrailer sets the trailer metadata that will be sent when an RPC returns. @@ -2160,7 +2160,7 @@ func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { // validateSendCompressor returns an error when given compressor name cannot be // handled by the server or the client based on the advertised compressors. -func validateSendCompressor(name, clientCompressors string) error { +func validateSendCompressor(name string, clientCompressors []string) error { if name == encoding.Identity { return nil } @@ -2169,7 +2169,7 @@ func validateSendCompressor(name, clientCompressors string) error { return fmt.Errorf("compressor not registered %q", name) } - for _, c := range strings.Split(clientCompressors, ",") { + for _, c := range clientCompressors { if c == name { return nil // found match } diff --git a/test/compressor_test.go b/test/compressor_test.go index a18d14f4ac73..7f3abb908c2e 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -566,6 +566,13 @@ func (s) TestClientSupportedCompressors(t *testing.T) { ), want: []string{"gzip"}, }, + { + desc: "With additional grpc-accept-encoding header with spaces between values", + ctx: metadata.AppendToOutgoingContext(ctx, + "grpc-accept-encoding", "identity, deflate", + ), + want: []string{"gzip", "identity", "deflate"}, + }, } { t.Run(tt.desc, func(t *testing.T) { ss := &stubserver.StubServer{