Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Read and WriterTo methods to envelopes #622

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 109 additions & 52 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,69 +42,105 @@ var errSpecialEnvelope = errorf(
type envelope struct {
Data *bytes.Buffer
Flags uint8

offset int
}

func (e *envelope) IsSet(flag uint8) bool {
return e.Flags&flag == flag
}

// Read implements io.Reader.
func (e *envelope) Read(data []byte) (readN int, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
readN = copy(data, prefix[e.offset:])
e.offset += readN
if e.offset < 5 {
return readN, nil
}
data = data[readN:]
}
n := copy(data, e.Data.Bytes()[e.offset-5:])
e.offset += n
readN += n
if readN == 0 && e.offset == e.Data.Len()+5 {
err = io.EOF
}
return readN, err
}

// WriteTo implements io.WriterTo.
func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
prefixN, err := dst.Write(prefix[e.offset:])
e.offset += prefixN
wroteN += int64(prefixN)
if e.offset < 5 {
return wroteN, err
}
}
n, err := dst.Write(e.Data.Bytes()[e.offset-5:])
e.offset += n
wroteN += int64(n)
return wroteN, err
}
Copy link
Contributor Author

@emcfarlane emcfarlane Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For #611 will also need to add func (e *envelope) Rewind() { b.offset = 0}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be way cleaner if envelope were immutable once sent. So that means we'd instead want a NewReader() io.Reader method, so that thing can be mutable/rewindable. That way, if any synchronization becomes necessary (as exists in the payload thingie in #611), it can live in that reader impl and not leak into the envelope struct itself.


type envelopeWriter struct {
writer io.Writer
codec Codec
compressMinBytes int
compressionPool *compressionPool
bufferPool *bufferPool
sendMaxBytes int
}

func (w *envelopeWriter) Marshal(message any) *Error {
func (w *envelopeWriter) Marshal(dst io.Writer, message any) *Error {
if message == nil {
if _, err := w.writer.Write(nil); err != nil {
if _, err := dst.Write(nil); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return NewError(CodeUnknown, err)
}
return nil
}
if appender, ok := w.codec.(marshalAppender); ok {
return w.marshalAppend(message, appender)
buffer, err := w.marshal(message)
if err != nil {
return err
}
return w.marshal(message)
defer w.bufferPool.Put(buffer)
return w.Write(dst, &envelope{Data: buffer, Flags: 0})
}

// Write writes the enveloped message, compressing as necessary. It doesn't
// retain any references to the supplied envelope or its underlying data.
func (w *envelopeWriter) Write(env *envelope) *Error {
if env.IsSet(flagEnvelopeCompressed) ||
w.compressionPool == nil ||
env.Data.Len() < w.compressMinBytes {
if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes {
return errorf(CodeResourceExhausted, "message size %d exceeds sendMaxBytes %d", env.Data.Len(), w.sendMaxBytes)
func (w *envelopeWriter) Write(dst io.Writer, env *envelope) *Error {
if !env.IsSet(flagEnvelopeCompressed) &&
w.compressionPool != nil &&
env.Data.Len() > w.compressMinBytes {
if err := w.compress(env.Data); err != nil {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behaviour change on Write, now mutates the *env to compress, if needed. Nothing currently depended on the *env being unchanged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These kinds of changes (mostly code motion) are sooo much easier to review if you don't change behavior at all. Was this change actually necessary? If not, please just omit and maybe do in a follow-up.

return err
}
return w.write(env)
env.Flags |= flagEnvelopeCompressed
}
data := w.bufferPool.Get()
defer w.bufferPool.Put(data)
if err := w.compressionPool.Compress(data, env.Data); err != nil {
return err
}
if w.sendMaxBytes > 0 && data.Len() > w.sendMaxBytes {
return errorf(CodeResourceExhausted, "compressed message size %d exceeds sendMaxBytes %d", data.Len(), w.sendMaxBytes)
return w.write(dst, env)
}

func (w *envelopeWriter) marshal(message any) (*bytes.Buffer, *Error) {
if appender, ok := w.codec.(marshalAppender); ok {
return w.marshalAppend(message, appender)
}
return w.write(&envelope{
Data: data,
Flags: env.Flags | flagEnvelopeCompressed,
})
return w.marshalBase(message)
}

func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Error {
func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) (*bytes.Buffer, *Error) {
// Codec supports MarshalAppend; try to re-use a []byte from the pool.
buffer := w.bufferPool.Get()
defer w.bufferPool.Put(buffer)
raw, err := codec.MarshalAppend(buffer.Bytes(), message)
if err != nil {
return errorf(CodeInternal, "marshal message: %w", err)
w.bufferPool.Put(buffer)
return nil, errorf(CodeInternal, "marshal message: %w", err)
}
if cap(raw) > buffer.Cap() {
// The buffer from the pool was too small, so MarshalAppend grew the slice.
Expand All @@ -119,54 +155,68 @@ func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Erro
// copies but avoids allocating.
buffer.Write(raw)
}
envelope := &envelope{Data: buffer}
return w.Write(envelope)
return buffer, nil
}

func (w *envelopeWriter) marshal(message any) *Error {
func (w *envelopeWriter) marshalBase(message any) (*bytes.Buffer, *Error) {
// Codec doesn't support MarshalAppend; let Marshal allocate a []byte.
raw, err := w.codec.Marshal(message)
if err != nil {
return errorf(CodeInternal, "marshal message: %w", err)
return nil, errorf(CodeInternal, "marshal message: %w", err)
}
buffer := bytes.NewBuffer(raw)
// Put our new []byte into the pool for later reuse.
defer w.bufferPool.Put(buffer)
envelope := &envelope{Data: buffer}
return w.Write(envelope)
return bytes.NewBuffer(raw), nil
}

func (w *envelopeWriter) write(env *envelope) *Error {
prefix := [5]byte{}
prefix[0] = env.Flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len()))
if _, err := w.writer.Write(prefix[:]); err != nil {
func (w *envelopeWriter) compress(buffer *bytes.Buffer) *Error {
compressed := w.bufferPool.Get()
defer w.bufferPool.Put(compressed)
if err := w.compressionPool.Compress(compressed, buffer); err != nil {
return err
}
*buffer, *compressed = *compressed, *buffer // Swap buffer contents.
return nil
}

func (w *envelopeWriter) checkSize(env *envelope) *Error {
if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes {
str := "message"
if env.IsSet(flagEnvelopeCompressed) {
str = "compressed message"
}
return errorf(CodeResourceExhausted,
"%s size %d exceeds sendMaxBytes %d",
str, env.Data.Len(), w.sendMaxBytes)
}
return nil
}

func (w *envelopeWriter) write(dst io.Writer, env *envelope) *Error {
if err := w.checkSize(env); err != nil {
return err
}
if _, err := env.WriteTo(dst); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return errorf(CodeUnknown, "write envelope: %w", err)
}
if _, err := io.Copy(w.writer, env.Data); err != nil {
return errorf(CodeUnknown, "write message: %w", err)
}
return nil
}

type envelopeReader struct {
reader io.Reader
codec Codec
last envelope
compressionPool *compressionPool
bufferPool *bufferPool
readMaxBytes int
}

func (r *envelopeReader) Unmarshal(message any) *Error {
func (r *envelopeReader) Unmarshal(message any, src io.Reader) *Error {
buffer := r.bufferPool.Get()
defer r.bufferPool.Put(buffer)

env := &envelope{Data: buffer}
err := r.Read(env)
err := r.Read(env, src)
switch {
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
Expand Down Expand Up @@ -200,7 +250,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if n, err := discard(r.reader); err != nil {
if n, err := discard(src); err != nil {
return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err)
} else if n > 0 {
return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n)
Expand All @@ -224,11 +274,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
return nil
}

func (r *envelopeReader) Read(env *envelope) *Error {
func (r *envelopeReader) Read(env *envelope, src io.Reader) *Error {
prefixes := [5]byte{}
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
if _, err := io.ReadFull(src, prefixes[:]); err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
Expand All @@ -250,7 +300,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
_, err := io.CopyN(io.Discard, src, size)
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeUnknown, "read enveloped message: %w", err)
}
Expand All @@ -259,7 +309,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
if readN, err := io.CopyN(env.Data, src, size); err != nil {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
Expand All @@ -279,3 +329,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
env.Flags = prefixes[0]
return nil
}

func makeEnvelopePrefix(flags uint8, size int) [5]byte {
prefix := [5]byte{}
prefix[0] = flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(size))
return prefix
}
48 changes: 38 additions & 10 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,61 @@ func TestEnvelope_read(t *testing.T) {
head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))

buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("full", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
rdr := envelopeReader{}
src := bytes.NewReader(buf.Bytes())
assert.Nil(t, rdr.Read(env, src))
assert.Equal(t, payload, env.Data.Bytes())
})
t.Run("byteByByte", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
rdr := envelopeReader{}
src := byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Nil(t, rdr.Read(env, src))
assert.Equal(t, payload, env.Data.Bytes())
})
}

func TestEnvelope_write(t *testing.T) {
t.Parallel()

head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))
buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("match", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
wtr := envelopeWriter{}
env := &envelope{Data: bytes.NewBuffer(payload)}
err := wtr.Write(dst, env)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
t.Run("partial", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst, env, 2)
assert.Nil(t, err)
_, err = env.WriteTo(dst)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
}

// byteByByteReader is test reader that reads a single byte at a time.
type byteByByteReader struct {
reader io.ByteReader
Expand Down
3 changes: 1 addition & 2 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,11 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er
response.WriteHeader(http.StatusOK)
marshaler := &connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: response,
bufferPool: w.bufferPool,
},
}
// MarshalEndStream returns *Error: check return value to avoid typed nils.
if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil {
if marshalErr := marshaler.MarshalEndStream(response, err, make(http.Header)); marshalErr != nil {
return marshalErr
}
return nil
Expand Down
Loading