Skip to content

Commit

Permalink
reconcile with grpc-go behavior
Browse files Browse the repository at this point in the history
* This largely undoes a recent change to do more validation of trailers-only
  responses, which would disallow a body or trailers in what appeared to be
  a trailers-only response.
* Instead, a trailers-only response is now defined by the lack of body and
  trailers, not the presence of a "grpc-status" header.
* This also tweaks some other error scenarios. If trailers (or an end-stream
  message) is completely missing from a response, it's considered an internal
  error. But if trailers are present, but the "grpc-status" key is missing,
  it's considered an issue determining the status, which is an unknown error.
* Similarly, if a response content-type doesn't appear to be the right
  protocol (like it may have come from a non-RPC server), the error code is
  now unknown. But if it looks like the right protocol but uses the wrong
  sub-format/codec, it's an internal error.
* This is also now more strict about the "compressed" flag in a streaming
  protocol when there was no compression algorithm negotiated. This was
  previously not considered an error if the message in question was empty
  (zero bytes).
  • Loading branch information
jhump committed Feb 16, 2024
1 parent 064c61e commit 4d6745d
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 206 deletions.
140 changes: 45 additions & 95 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ func TestGRPCMissingTrailersError(t *testing.T) {
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok)
assert.Equal(t, connectErr.Code(), connect.CodeInternal)
assert.Equal(t, connectErr.Code(), connect.CodeUnknown)
assert.True(
t,
strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"),
Expand Down Expand Up @@ -1838,7 +1838,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
t.Parallel()
assertIsFlusherErr := func(t *testing.T, err error) {
t.Helper()
assert.NotNil(t, err)
if !assert.NotNil(t, err) {
return
}
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal, assert.Sprintf("got %v", err))
assert.True(
t,
Expand Down Expand Up @@ -1875,8 +1877,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
assertIsFlusherErr(t, err)
return
}
assert.False(t, stream.Receive())
assertIsFlusherErr(t, stream.Err())
if assert.False(t, stream.Receive()) {
assertIsFlusherErr(t, stream.Err())
}
})
}
}
Expand Down Expand Up @@ -2146,6 +2149,21 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
responseWriter.Header().Set(http.TrailerPrefix+"grpc-message", "foo")
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
Expand All @@ -2159,6 +2177,29 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
_, err = responseWriter.Write([]byte{128}) // end-stream flag
assert.Nil(t, err)
endStream := "grpc-message: foo\r\n"
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(endStream)))
_, err = responseWriter.Write(length[:])
assert.Nil(t, err)
_, err = responseWriter.Write([]byte(endStream))
assert.Nil(t, err)
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "connect_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON()},
Expand Down Expand Up @@ -2442,97 +2483,6 @@ func TestClientDisconnect(t *testing.T) {
})
}

func TestTrailersOnlyErrors(t *testing.T) {
t.Parallel()

head := [3]byte{}
testcases := []struct {
name string
handler http.HandlerFunc
options []connect.ClientOption
expectCode connect.Code
expectMsg string
}{{
name: "grpc_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc-web_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}, {
name: "grpc-web_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}}
for _, testcase := range testcases {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = io.Copy(io.Discard, request.Body)
testcase.handler(responseWriter, request)
})
server := memhttptest.NewServer(t, mux)
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
testcase.options...,
)
const upTo = 2
request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo})
request.Header().Set("Test-Case", t.Name())
stream, err := client.CountUp(context.Background(), request)
assert.Nil(t, err)
for i := 0; stream.Receive() && i < upTo; i++ {
assert.Equal(t, stream.Msg().GetNumber(), 42)
}
assert.NotNil(t, stream.Err())
assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode)
assert.Equal(t, stream.Err().Error(), testcase.expectMsg)
})
}
}

// TestBlankImportCodeGeneration tests that services.connect.go is generated with
// blank import statements to services.pb.go so that the service's Descriptor is
// available in the global proto registry.
Expand Down
27 changes: 17 additions & 10 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
type envelopeReader struct {
ctx context.Context //nolint:containedctx
reader io.Reader
bytesRead int64
codec Codec
last envelope
compressionPool *compressionPool
Expand All @@ -241,6 +242,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
env := &envelope{Data: buffer}
err := r.Read(env)
switch {
case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
return errorf(
CodeInternal,
"protocol error: sent compressed message without compression support",
)
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
env.Data.Len() == 0:
Expand All @@ -257,12 +263,6 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

data := env.Data
if data.Len() > 0 && env.IsSet(flagEnvelopeCompressed) {
if r.compressionPool == nil {
return errorf(
CodeInvalidArgument,
"protocol error: sent compressed message without compression support",
)
}
decompressed := r.bufferPool.Get()
defer func() {
if decompressed != dontRelease {
Expand All @@ -277,7 +277,9 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if numBytes, err := discard(r.reader); err != nil {
numBytes, err := discard(r.reader)
r.bytesRead += numBytes
if err != nil {
err = wrapIfContextError(err)
if connErr, ok := asError(err); ok {
return connErr
Expand Down Expand Up @@ -308,7 +310,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
prefixes := [5]byte{}
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
n, err := io.ReadFull(r.reader, prefixes[:])
r.bytesRead += int64(n)
if err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
Expand All @@ -328,7 +332,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
n, err := io.CopyN(io.Discard, r.reader, size)
r.bytesRead += n
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
}
Expand All @@ -337,7 +342,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
readN, err := io.CopyN(env.Data, r.reader, size)
r.bytesRead += readN
if err != nil {
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand Down
10 changes: 8 additions & 2 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ func DecodeBinaryHeader(data string) ([]byte, error) {
}

func mergeHeaders(into, from http.Header) {
for k, vals := range from {
into[k] = append(into[k], vals...)
for key, vals := range from {
if len(vals) == 0 {
// For response trailers, net/http will pre-populate entries
// with nil values based on the "Trailer" header. But if there
// are no actual values for those keys, we skip them.
continue
}
into[key] = append(into[key], vals...)
}
}

Expand Down
1 change: 0 additions & 1 deletion header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ func TestHeaderMerge(t *testing.T) {
expect := http.Header{
"Foo": []string{"one", "two"},
"Bar": []string{"one"},
"Baz": nil,
}
assert.Equal(t, header, expect)
}
18 changes: 18 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,15 @@ func connectValidateUnaryResponseContentType(
)
}
// Normal responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectUnaryContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
StreamTypeUnary,
responseContentType,
Expand All @@ -1410,6 +1419,15 @@ func connectValidateUnaryResponseContentType(

func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error {
// Responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectStreamingContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
streamType,
responseContentType,
Expand Down
37 changes: 24 additions & 13 deletions protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func TestConnectValidateUnaryResponseContentType(t *testing.T) {
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "some/garbage",
expectCode: CodeInternal,
expectCode: CodeUnknown, // doesn't even look like it could be connect protocol
expectBadContentType: true,
},
// Error status, invalid content-type, returns code based on HTTP status code
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
testCases := []struct {
codecName string
responseContentType string
expectErr bool
expectCode Code
}{
// Allowed content-types
{
Expand All @@ -307,31 +307,42 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
codecName: codecNameJSON,
responseContentType: "application/connect+json",
},
// Mismatched response codec
{
codecName: codecNameProto,
responseContentType: "application/connect+json",
expectCode: CodeInternal,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+proto",
expectCode: CodeInternal,
},
// Disallowed content-types
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectCode: CodeInternal, // *almost* looks right
},
{
codecName: codecNameProto,
responseContentType: "application/proto",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json; charset=utf-8",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameProto,
responseContentType: "some/garbage",
expectErr: true,
expectCode: CodeUnknown,
},
}
for _, testCase := range testCases {
Expand All @@ -344,10 +355,10 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
StreamTypeServer,
testCase.responseContentType,
)
if !testCase.expectErr {
if testCase.expectCode == 0 {
assert.Nil(t, err)
} else if assert.NotNil(t, err) {
assert.Equal(t, CodeOf(err), CodeInternal)
assert.Equal(t, CodeOf(err), testCase.expectCode)
assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType)))
}
})
Expand Down
Loading

0 comments on commit 4d6745d

Please sign in to comment.