Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor http calls to support message payloads #646

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,21 @@ func newDuplexHTTPCall(
}
}

// Write to the request body.
func (d *duplexHTTPCall) Write(data []byte) (int, error) {
// Send sends a message to the server.
func (d *duplexHTTPCall) Send(payload messsagePayload) (int64, error) {
isFirst := d.ensureRequestMade()
// Before we send any data, check if the context has been canceled.
if err := d.ctx.Err(); err != nil {
return 0, wrapIfContextError(err)
}
if isFirst && data == nil {
if isFirst && payload.Len() == 0 {
// On first write a nil Send is used to send request headers. Avoid
// writing a zero-length payload to avoid superfluous errors with close.
return 0, nil
}
// It's safe to write to this side of the pipe while net/http concurrently
// reads from the other side.
bytesWritten, err := d.requestBodyWriter.Write(data)
bytesWritten, err := payload.WriteTo(d.requestBodyWriter)
if err != nil && errors.Is(err, io.ErrClosedPipe) {
// Signal that the stream is closed with the more-typical io.EOF instead of
// io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to
Expand Down Expand Up @@ -295,6 +295,52 @@ func (d *duplexHTTPCall) makeRequest() {
}
}

// messsagePayload is a sized and seekable message payload. The interface is
// implemented by [*bytes.Reader] and *envelope.
type messsagePayload interface {
io.Reader
io.WriterTo
io.Seeker
Len() int
}

// nopPayload is a message payload that does nothing. It's used to send headers
// to the server.
type nopPayload struct{}

var _ messsagePayload = nopPayload{}

func (nopPayload) Read([]byte) (int, error) {
return 0, io.EOF
}
func (nopPayload) WriteTo(io.Writer) (int64, error) {
return 0, nil
}
func (nopPayload) Seek(int64, int) (int64, error) {
return 0, nil
}
func (nopPayload) Len() int {
return 0
}

// messageSender sends a message payload. The interface is implemented by
// [*duplexHTTPCall] and writeSender.
type messageSender interface {
Send(messsagePayload) (int64, error)
}

// writeSender is a sender that writes to an [io.Writer]. Useful for wrapping
// [http.ResponseWriter].
type writeSender struct {
writer io.Writer
}

var _ messageSender = writeSender{}

func (w writeSender) Send(payload messsagePayload) (int64, error) {
return payload.WriteTo(w.writer)
}

// See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33
func cloneURL(oldURL *url.URL) *url.URL {
if oldURL == nil {
Expand Down
93 changes: 82 additions & 11 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,84 @@ var errSpecialEnvelope = errorf(
// message length. gRPC and Connect interpret the bitwise flags differently, so
// envelope leaves their interpretation up to the caller.
type envelope struct {
Data *bytes.Buffer
Flags uint8
Data *bytes.Buffer
Flags uint8
offset int64
}

var _ messsagePayload = (*envelope)(nil)

func (e *envelope) IsSet(flag uint8) bool {
return e.Flags&flag == flag
}

// Read implements [io.Reader].
func (e *envelope) Read(data []byte) (readN int, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
readN = copy(data, prefix[e.offset:])
e.offset += int64(readN)
if e.offset < 5 {
return readN, nil
}
data = data[readN:]
}
n := copy(data, e.Data.Bytes()[e.offset-5:])
e.offset += int64(n)
readN += n
if readN == 0 && e.offset == int64(e.Data.Len()+5) {
err = io.EOF
}
return readN, err
}

// WriteTo implements [io.WriterTo].
func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
prefixN, err := dst.Write(prefix[e.offset:])
e.offset += int64(prefixN)
wroteN += int64(prefixN)
if e.offset < 5 {
return wroteN, err
}
}
n, err := dst.Write(e.Data.Bytes()[e.offset-5:])
e.offset += int64(n)
wroteN += int64(n)
return wroteN, err
}

// Seek implements [io.Seeker]. Based on the implementation of [bytes.Reader].
func (e *envelope) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = e.offset + offset
case io.SeekEnd:
abs = int64(e.Data.Len()) + offset
default:
return 0, errors.New("connect.envelope.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("connect.envelope.Seek: negative position")
}
e.offset = abs
return abs, nil
}

// Len returns the number of bytes of the unread portion of the envelope.
func (e *envelope) Len() int {
if length := int(int64(e.Data.Len()) + 5 - e.offset); length > 0 {
return length
}
return 0
}

type envelopeWriter struct {
writer io.Writer
sender messageSender
codec Codec
compressMinBytes int
compressionPool *compressionPool
Expand All @@ -59,7 +127,9 @@ type envelopeWriter struct {

func (w *envelopeWriter) Marshal(message any) *Error {
if message == nil {
if _, err := w.writer.Write(nil); err != nil {
// Send no-op message to create the request and send headers.
payload := nopPayload{}
if _, err := w.sender.Send(payload); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down Expand Up @@ -137,18 +207,12 @@ func (w *envelopeWriter) marshal(message any) *Error {
}

func (w *envelopeWriter) write(env *envelope) *Error {
prefix := [5]byte{}
prefix[0] = env.Flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len()))
if _, err := w.writer.Write(prefix[:]); err != nil {
if _, err := w.sender.Send(env); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return errorf(CodeUnknown, "write envelope: %w", err)
}
if _, err := io.Copy(w.writer, env.Data); err != nil {
return errorf(CodeUnknown, "write message: %w", err)
}
return nil
}

Expand Down Expand Up @@ -279,3 +343,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
env.Flags = prefixes[0]
return nil
}

func makeEnvelopePrefix(flags uint8, size int) [5]byte {
prefix := [5]byte{}
prefix[0] = flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(size))
return prefix
}
90 changes: 67 additions & 23 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,87 @@ package connect

import (
"bytes"
"encoding/binary"
"io"
"testing"

"connectrpc.com/connect/internal/assert"
)

func TestEnvelope_read(t *testing.T) {
func TestEnvelope(t *testing.T) {
t.Parallel()

head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))

head := makeEnvelopePrefix(0, len(payload))
buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("full", func(t *testing.T) {
t.Run("read", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
t.Run("full", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
})
t.Run("byteByByte", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
})
})
t.Run("byteByByte", func(t *testing.T) {
t.Run("write", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
t.Run("full", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
wtr := envelopeWriter{
sender: writeSender{writer: dst},
}
env := &envelope{Data: bytes.NewBuffer(payload)}
err := wtr.Write(env)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
t.Run("partial", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst, env, 2)
assert.Nil(t, err)
_, err = env.WriteTo(dst)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
})
t.Run("seek", func(t *testing.T) {
t.Parallel()
t.Run("start", func(t *testing.T) {
t.Parallel()
dst1 := &bytes.Buffer{}
dst2 := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst1, env, 2)
assert.Nil(t, err)
assert.Equal(t, env.Len(), len(payload)+3)
_, err = env.Seek(0, io.SeekStart)
assert.Nil(t, err)
assert.Equal(t, env.Len(), len(payload)+5)
_, err = io.CopyN(dst2, env, 2)
assert.Nil(t, err)
assert.Equal(t, dst1.Bytes(), dst2.Bytes())
_, err = env.WriteTo(dst2)
assert.Nil(t, err)
assert.Equal(t, dst2.Bytes(), buf.Bytes())
assert.Equal(t, env.Len(), 0)
})
})
}

Expand Down
2 changes: 1 addition & 1 deletion error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er
response.WriteHeader(http.StatusOK)
marshaler := &connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: response,
sender: writeSender{writer: response},
bufferPool: w.bufferPool,
},
}
Expand Down
13 changes: 7 additions & 6 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (h *connectHandler) NewConn(
request: request,
responseWriter: responseWriter,
marshaler: connectUnaryMarshaler{
writer: responseWriter,
sender: writeSender{writer: responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
compressionName: responseCompression,
Expand All @@ -280,7 +280,7 @@ func (h *connectHandler) NewConn(
responseWriter: responseWriter,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: responseWriter,
sender: writeSender{responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
compressionPool: h.CompressionPools.Get(responseCompression),
Expand Down Expand Up @@ -375,7 +375,7 @@ func (c *connectClient) NewConn(
bufferPool: c.BufferPool,
marshaler: connectUnaryRequestMarshaler{
connectUnaryMarshaler: connectUnaryMarshaler{
writer: duplexCall,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
compressionName: c.CompressionName,
Expand Down Expand Up @@ -415,7 +415,7 @@ func (c *connectClient) NewConn(
codec: c.Codec,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: duplexCall,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
compressionPool: c.CompressionPools.Get(c.CompressionName),
Expand Down Expand Up @@ -892,7 +892,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error {
}

type connectUnaryMarshaler struct {
writer io.Writer
sender messageSender
codec Codec
compressMinBytes int
compressionName string
Expand Down Expand Up @@ -938,7 +938,8 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error {
}

func (m *connectUnaryMarshaler) write(data []byte) *Error {
if _, err := m.writer.Write(data); err != nil {
payload := bytes.NewReader(data)
if _, err := m.sender.Send(payload); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down
2 changes: 1 addition & 1 deletion protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) {
assert.Nil(t, err)

writer := envelopeWriter{
writer: &buffer,
sender: writeSender{writer: &buffer},
bufferPool: bufferPool,
}
err = writer.Write(&envelope{
Expand Down
Loading