diff --git a/duplex_http_call.go b/duplex_http_call.go index 35075bd5..a40775d5 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -94,21 +94,21 @@ func newDuplexHTTPCall( } } -// Write to the request body. -func (d *duplexHTTPCall) Write(data []byte) (int, error) { +// Send sends a message to the server. +func (d *duplexHTTPCall) Send(payload messsagePayload) (int64, error) { isFirst := d.ensureRequestMade() // Before we send any data, check if the context has been canceled. if err := d.ctx.Err(); err != nil { return 0, wrapIfContextError(err) } - if isFirst && data == nil { + if isFirst && payload.Len() == 0 { // On first write a nil Send is used to send request headers. Avoid // writing a zero-length payload to avoid superfluous errors with close. return 0, nil } // It's safe to write to this side of the pipe while net/http concurrently // reads from the other side. - bytesWritten, err := d.requestBodyWriter.Write(data) + bytesWritten, err := payload.WriteTo(d.requestBodyWriter) if err != nil && errors.Is(err, io.ErrClosedPipe) { // Signal that the stream is closed with the more-typical io.EOF instead of // io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to @@ -295,6 +295,52 @@ func (d *duplexHTTPCall) makeRequest() { } } +// messsagePayload is a sized and seekable message payload. The interface is +// implemented by [*bytes.Reader] and *envelope. +type messsagePayload interface { + io.Reader + io.WriterTo + io.Seeker + Len() int +} + +// nopPayload is a message payload that does nothing. It's used to send headers +// to the server. +type nopPayload struct{} + +var _ messsagePayload = nopPayload{} + +func (nopPayload) Read([]byte) (int, error) { + return 0, io.EOF +} +func (nopPayload) WriteTo(io.Writer) (int64, error) { + return 0, nil +} +func (nopPayload) Seek(int64, int) (int64, error) { + return 0, nil +} +func (nopPayload) Len() int { + return 0 +} + +// messageSender sends a message payload. The interface is implemented by +// [*duplexHTTPCall] and writeSender. +type messageSender interface { + Send(messsagePayload) (int64, error) +} + +// writeSender is a sender that writes to an [io.Writer]. Useful for wrapping +// [http.ResponseWriter]. +type writeSender struct { + writer io.Writer +} + +var _ messageSender = writeSender{} + +func (w writeSender) Send(payload messsagePayload) (int64, error) { + return payload.WriteTo(w.writer) +} + // See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33 func cloneURL(oldURL *url.URL) *url.URL { if oldURL == nil { diff --git a/envelope.go b/envelope.go index 7452e1a6..6b333827 100644 --- a/envelope.go +++ b/envelope.go @@ -40,16 +40,84 @@ var errSpecialEnvelope = errorf( // message length. gRPC and Connect interpret the bitwise flags differently, so // envelope leaves their interpretation up to the caller. type envelope struct { - Data *bytes.Buffer - Flags uint8 + Data *bytes.Buffer + Flags uint8 + offset int64 } +var _ messsagePayload = (*envelope)(nil) + func (e *envelope) IsSet(flag uint8) bool { return e.Flags&flag == flag } +// Read implements [io.Reader]. +func (e *envelope) Read(data []byte) (readN int, err error) { + if e.offset < 5 { + prefix := makeEnvelopePrefix(e.Flags, e.Data.Len()) + readN = copy(data, prefix[e.offset:]) + e.offset += int64(readN) + if e.offset < 5 { + return readN, nil + } + data = data[readN:] + } + n := copy(data, e.Data.Bytes()[e.offset-5:]) + e.offset += int64(n) + readN += n + if readN == 0 && e.offset == int64(e.Data.Len()+5) { + err = io.EOF + } + return readN, err +} + +// WriteTo implements [io.WriterTo]. +func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) { + if e.offset < 5 { + prefix := makeEnvelopePrefix(e.Flags, e.Data.Len()) + prefixN, err := dst.Write(prefix[e.offset:]) + e.offset += int64(prefixN) + wroteN += int64(prefixN) + if e.offset < 5 { + return wroteN, err + } + } + n, err := dst.Write(e.Data.Bytes()[e.offset-5:]) + e.offset += int64(n) + wroteN += int64(n) + return wroteN, err +} + +// Seek implements [io.Seeker]. Based on the implementation of [bytes.Reader]. +func (e *envelope) Seek(offset int64, whence int) (int64, error) { + var abs int64 + switch whence { + case io.SeekStart: + abs = offset + case io.SeekCurrent: + abs = e.offset + offset + case io.SeekEnd: + abs = int64(e.Data.Len()) + offset + default: + return 0, errors.New("connect.envelope.Seek: invalid whence") + } + if abs < 0 { + return 0, errors.New("connect.envelope.Seek: negative position") + } + e.offset = abs + return abs, nil +} + +// Len returns the number of bytes of the unread portion of the envelope. +func (e *envelope) Len() int { + if length := int(int64(e.Data.Len()) + 5 - e.offset); length > 0 { + return length + } + return 0 +} + type envelopeWriter struct { - writer io.Writer + sender messageSender codec Codec compressMinBytes int compressionPool *compressionPool @@ -59,7 +127,9 @@ type envelopeWriter struct { func (w *envelopeWriter) Marshal(message any) *Error { if message == nil { - if _, err := w.writer.Write(nil); err != nil { + // Send no-op message to create the request and send headers. + payload := nopPayload{} + if _, err := w.sender.Send(payload); err != nil { if connectErr, ok := asError(err); ok { return connectErr } @@ -137,18 +207,12 @@ func (w *envelopeWriter) marshal(message any) *Error { } func (w *envelopeWriter) write(env *envelope) *Error { - prefix := [5]byte{} - prefix[0] = env.Flags - binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len())) - if _, err := w.writer.Write(prefix[:]); err != nil { + if _, err := w.sender.Send(env); err != nil { if connectErr, ok := asError(err); ok { return connectErr } return errorf(CodeUnknown, "write envelope: %w", err) } - if _, err := io.Copy(w.writer, env.Data); err != nil { - return errorf(CodeUnknown, "write message: %w", err) - } return nil } @@ -279,3 +343,10 @@ func (r *envelopeReader) Read(env *envelope) *Error { env.Flags = prefixes[0] return nil } + +func makeEnvelopePrefix(flags uint8, size int) [5]byte { + prefix := [5]byte{} + prefix[0] = flags + binary.BigEndian.PutUint32(prefix[1:5], uint32(size)) + return prefix +} diff --git a/envelope_test.go b/envelope_test.go index bf187934..f153b8c3 100644 --- a/envelope_test.go +++ b/envelope_test.go @@ -16,43 +16,87 @@ package connect import ( "bytes" - "encoding/binary" "io" "testing" "connectrpc.com/connect/internal/assert" ) -func TestEnvelope_read(t *testing.T) { +func TestEnvelope(t *testing.T) { t.Parallel() - - head := [5]byte{} payload := []byte(`{"number": 42}`) - binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) - + head := makeEnvelopePrefix(0, len(payload)) buf := &bytes.Buffer{} buf.Write(head[:]) buf.Write(payload) - - t.Run("full", func(t *testing.T) { + t.Run("read", func(t *testing.T) { t.Parallel() - env := &envelope{Data: &bytes.Buffer{}} - rdr := envelopeReader{ - reader: bytes.NewReader(buf.Bytes()), - } - assert.Nil(t, rdr.Read(env)) - assert.Equal(t, payload, env.Data.Bytes()) + t.Run("full", func(t *testing.T) { + t.Parallel() + env := &envelope{Data: &bytes.Buffer{}} + rdr := envelopeReader{ + reader: bytes.NewReader(buf.Bytes()), + } + assert.Nil(t, rdr.Read(env)) + assert.Equal(t, payload, env.Data.Bytes()) + }) + t.Run("byteByByte", func(t *testing.T) { + t.Parallel() + env := &envelope{Data: &bytes.Buffer{}} + rdr := envelopeReader{ + reader: byteByByteReader{ + reader: bytes.NewReader(buf.Bytes()), + }, + } + assert.Nil(t, rdr.Read(env)) + assert.Equal(t, payload, env.Data.Bytes()) + }) }) - t.Run("byteByByte", func(t *testing.T) { + t.Run("write", func(t *testing.T) { t.Parallel() - env := &envelope{Data: &bytes.Buffer{}} - rdr := envelopeReader{ - reader: byteByByteReader{ - reader: bytes.NewReader(buf.Bytes()), - }, - } - assert.Nil(t, rdr.Read(env)) - assert.Equal(t, payload, env.Data.Bytes()) + t.Run("full", func(t *testing.T) { + t.Parallel() + dst := &bytes.Buffer{} + wtr := envelopeWriter{ + sender: writeSender{writer: dst}, + } + env := &envelope{Data: bytes.NewBuffer(payload)} + err := wtr.Write(env) + assert.Nil(t, err) + assert.Equal(t, buf.Bytes(), dst.Bytes()) + }) + t.Run("partial", func(t *testing.T) { + t.Parallel() + dst := &bytes.Buffer{} + env := &envelope{Data: bytes.NewBuffer(payload)} + _, err := io.CopyN(dst, env, 2) + assert.Nil(t, err) + _, err = env.WriteTo(dst) + assert.Nil(t, err) + assert.Equal(t, buf.Bytes(), dst.Bytes()) + }) + }) + t.Run("seek", func(t *testing.T) { + t.Parallel() + t.Run("start", func(t *testing.T) { + t.Parallel() + dst1 := &bytes.Buffer{} + dst2 := &bytes.Buffer{} + env := &envelope{Data: bytes.NewBuffer(payload)} + _, err := io.CopyN(dst1, env, 2) + assert.Nil(t, err) + assert.Equal(t, env.Len(), len(payload)+3) + _, err = env.Seek(0, io.SeekStart) + assert.Nil(t, err) + assert.Equal(t, env.Len(), len(payload)+5) + _, err = io.CopyN(dst2, env, 2) + assert.Nil(t, err) + assert.Equal(t, dst1.Bytes(), dst2.Bytes()) + _, err = env.WriteTo(dst2) + assert.Nil(t, err) + assert.Equal(t, dst2.Bytes(), buf.Bytes()) + assert.Equal(t, env.Len(), 0) + }) }) } diff --git a/error_writer.go b/error_writer.go index 81d5bb9f..773aa4e9 100644 --- a/error_writer.go +++ b/error_writer.go @@ -133,7 +133,7 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er response.WriteHeader(http.StatusOK) marshaler := &connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: response, + sender: writeSender{writer: response}, bufferPool: w.bufferPool, }, } diff --git a/protocol_connect.go b/protocol_connect.go index 299cb830..477b20f1 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -254,7 +254,7 @@ func (h *connectHandler) NewConn( request: request, responseWriter: responseWriter, marshaler: connectUnaryMarshaler{ - writer: responseWriter, + sender: writeSender{writer: responseWriter}, codec: codec, compressMinBytes: h.CompressMinBytes, compressionName: responseCompression, @@ -280,7 +280,7 @@ func (h *connectHandler) NewConn( responseWriter: responseWriter, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, + sender: writeSender{responseWriter}, codec: codec, compressMinBytes: h.CompressMinBytes, compressionPool: h.CompressionPools.Get(responseCompression), @@ -375,7 +375,7 @@ func (c *connectClient) NewConn( bufferPool: c.BufferPool, marshaler: connectUnaryRequestMarshaler{ connectUnaryMarshaler: connectUnaryMarshaler{ - writer: duplexCall, + sender: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, compressionName: c.CompressionName, @@ -415,7 +415,7 @@ func (c *connectClient) NewConn( codec: c.Codec, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: duplexCall, + sender: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, compressionPool: c.CompressionPools.Get(c.CompressionName), @@ -892,7 +892,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error { } type connectUnaryMarshaler struct { - writer io.Writer + sender messageSender codec Codec compressMinBytes int compressionName string @@ -938,7 +938,8 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { } func (m *connectUnaryMarshaler) write(data []byte) *Error { - if _, err := m.writer.Write(data); err != nil { + payload := bytes.NewReader(data) + if _, err := m.sender.Send(payload); err != nil { if connectErr, ok := asError(err); ok { return connectErr } diff --git a/protocol_connect_test.go b/protocol_connect_test.go index f7e8b5cb..ae1bedea 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -72,7 +72,7 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { assert.Nil(t, err) writer := envelopeWriter{ - writer: &buffer, + sender: writeSender{writer: &buffer}, bufferPool: bufferPool, } err = writer.Write(&envelope{ diff --git a/protocol_grpc.go b/protocol_grpc.go index adaf72b0..ce8ceb03 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -186,7 +186,7 @@ func (g *grpcHandler) NewConn( protobuf: g.Codecs.Protobuf(), // for errors marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, + sender: writeSender{writer: responseWriter}, compressionPool: g.CompressionPools.Get(responseCompression), codec: codec, compressMinBytes: g.CompressMinBytes, @@ -284,7 +284,7 @@ func (g *grpcClient) NewConn( protobuf: g.Protobuf, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: duplexCall, + sender: duplexCall, compressionPool: g.CompressionPools.Get(g.CompressionName), codec: g.Codec, compressMinBytes: g.CompressMinBytes, diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 7dd8e587..95886b09 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -48,7 +48,7 @@ func TestGRPCHandlerSender(t *testing.T) { protobuf: protobufCodec, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, + sender: writeSender{writer: responseWriter}, codec: protobufCodec, bufferPool: bufferPool, }, @@ -181,7 +181,7 @@ func TestGRPCWebTrailerMarshalling(t *testing.T) { responseWriter := httptest.NewRecorder() marshaler := grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, + sender: writeSender{writer: responseWriter}, bufferPool: newBufferPool(), }, }