Skip to content

Commit

Permalink
Add Read an WriterTo methods to envelopes
Browse files Browse the repository at this point in the history
Allow src and dst to be set for envelope read/writers. Envelopes also
implement io.Reader and io.WriterTo methods to allow passing envelopes
as a message payload.
  • Loading branch information
emcfarlane committed Nov 2, 2023
1 parent 5176a6c commit 4c44379
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 128 deletions.
161 changes: 109 additions & 52 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,69 +42,105 @@ var errSpecialEnvelope = errorf(
type envelope struct {
Data *bytes.Buffer
Flags uint8

offset int
}

func (e *envelope) IsSet(flag uint8) bool {
return e.Flags&flag == flag
}

// Read implements io.Reader.
func (e *envelope) Read(data []byte) (readN int, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
readN = copy(data, prefix[e.offset:])
e.offset += readN
if e.offset < 5 {
return readN, nil
}
data = data[readN:]
}
n := copy(data, e.Data.Bytes()[e.offset-5:])
e.offset += n
readN += n
if readN == 0 && e.offset == e.Data.Len()+5 {
err = io.EOF
}
return readN, err
}

// WriteTo implements io.WriterTo.
func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
prefixN, err := dst.Write(prefix[e.offset:])
e.offset += prefixN
wroteN += int64(prefixN)
if e.offset < 5 {
return wroteN, err
}
}
n, err := dst.Write(e.Data.Bytes()[e.offset-5:])
e.offset += n
wroteN += int64(n)
return wroteN, err
}

type envelopeWriter struct {
writer io.Writer
codec Codec
compressMinBytes int
compressionPool *compressionPool
bufferPool *bufferPool
sendMaxBytes int
}

func (w *envelopeWriter) Marshal(message any) *Error {
func (w *envelopeWriter) Marshal(dst io.Writer, message any) *Error {
if message == nil {
if _, err := w.writer.Write(nil); err != nil {
if _, err := dst.Write(nil); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return NewError(CodeUnknown, err)
}
return nil
}
if appender, ok := w.codec.(marshalAppender); ok {
return w.marshalAppend(message, appender)
buffer, err := w.marshal(message)
if err != nil {
return err
}
return w.marshal(message)
defer w.bufferPool.Put(buffer)
return w.Write(dst, &envelope{Data: buffer, Flags: 0})
}

// Write writes the enveloped message, compressing as necessary. It doesn't
// retain any references to the supplied envelope or its underlying data.
func (w *envelopeWriter) Write(env *envelope) *Error {
if env.IsSet(flagEnvelopeCompressed) ||
w.compressionPool == nil ||
env.Data.Len() < w.compressMinBytes {
if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes {
return errorf(CodeResourceExhausted, "message size %d exceeds sendMaxBytes %d", env.Data.Len(), w.sendMaxBytes)
func (w *envelopeWriter) Write(dst io.Writer, env *envelope) *Error {
if !env.IsSet(flagEnvelopeCompressed) &&
w.compressionPool != nil &&
env.Data.Len() > w.compressMinBytes {
if err := w.compress(env.Data); err != nil {
return err
}
return w.write(env)
env.Flags |= flagEnvelopeCompressed
}
data := w.bufferPool.Get()
defer w.bufferPool.Put(data)
if err := w.compressionPool.Compress(data, env.Data); err != nil {
return err
}
if w.sendMaxBytes > 0 && data.Len() > w.sendMaxBytes {
return errorf(CodeResourceExhausted, "compressed message size %d exceeds sendMaxBytes %d", data.Len(), w.sendMaxBytes)
return w.write(dst, env)
}

func (w *envelopeWriter) marshal(message any) (*bytes.Buffer, *Error) {
if appender, ok := w.codec.(marshalAppender); ok {
return w.marshalAppend(message, appender)
}
return w.write(&envelope{
Data: data,
Flags: env.Flags | flagEnvelopeCompressed,
})
return w.marshalBase(message)
}

func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Error {
func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) (*bytes.Buffer, *Error) {
// Codec supports MarshalAppend; try to re-use a []byte from the pool.
buffer := w.bufferPool.Get()
defer w.bufferPool.Put(buffer)
raw, err := codec.MarshalAppend(buffer.Bytes(), message)
if err != nil {
return errorf(CodeInternal, "marshal message: %w", err)
w.bufferPool.Put(buffer)
return nil, errorf(CodeInternal, "marshal message: %w", err)
}
if cap(raw) > buffer.Cap() {
// The buffer from the pool was too small, so MarshalAppend grew the slice.
Expand All @@ -119,54 +155,68 @@ func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Erro
// copies but avoids allocating.
buffer.Write(raw)
}
envelope := &envelope{Data: buffer}
return w.Write(envelope)
return buffer, nil
}

func (w *envelopeWriter) marshal(message any) *Error {
func (w *envelopeWriter) marshalBase(message any) (*bytes.Buffer, *Error) {
// Codec doesn't support MarshalAppend; let Marshal allocate a []byte.
raw, err := w.codec.Marshal(message)
if err != nil {
return errorf(CodeInternal, "marshal message: %w", err)
return nil, errorf(CodeInternal, "marshal message: %w", err)
}
buffer := bytes.NewBuffer(raw)
// Put our new []byte into the pool for later reuse.
defer w.bufferPool.Put(buffer)
envelope := &envelope{Data: buffer}
return w.Write(envelope)
return bytes.NewBuffer(raw), nil
}

func (w *envelopeWriter) write(env *envelope) *Error {
prefix := [5]byte{}
prefix[0] = env.Flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len()))
if _, err := w.writer.Write(prefix[:]); err != nil {
func (w *envelopeWriter) compress(buffer *bytes.Buffer) *Error {
compressed := w.bufferPool.Get()
defer w.bufferPool.Put(compressed)
if err := w.compressionPool.Compress(compressed, buffer); err != nil {
return err
}
*buffer, *compressed = *compressed, *buffer // Swap buffer contents.
return nil
}

func (w *envelopeWriter) checkSize(env *envelope) *Error {
if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes {
str := "message"
if env.IsSet(flagEnvelopeCompressed) {
str = "compressed message"
}
return errorf(CodeResourceExhausted,
"%s size %d exceeds sendMaxBytes %d",
str, env.Data.Len(), w.sendMaxBytes)
}
return nil
}

func (w *envelopeWriter) write(dst io.Writer, env *envelope) *Error {
if err := w.checkSize(env); err != nil {
return err
}
if _, err := env.WriteTo(dst); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return errorf(CodeUnknown, "write envelope: %w", err)
}
if _, err := io.Copy(w.writer, env.Data); err != nil {
return errorf(CodeUnknown, "write message: %w", err)
}
return nil
}

type envelopeReader struct {
reader io.Reader
codec Codec
last envelope
compressionPool *compressionPool
bufferPool *bufferPool
readMaxBytes int
}

func (r *envelopeReader) Unmarshal(message any) *Error {
func (r *envelopeReader) Unmarshal(message any, src io.Reader) *Error {
buffer := r.bufferPool.Get()
defer r.bufferPool.Put(buffer)

env := &envelope{Data: buffer}
err := r.Read(env)
err := r.Read(env, src)
switch {
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
Expand Down Expand Up @@ -200,7 +250,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if n, err := discard(r.reader); err != nil {
if n, err := discard(src); err != nil {
return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err)
} else if n > 0 {
return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n)
Expand All @@ -224,11 +274,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
return nil
}

func (r *envelopeReader) Read(env *envelope) *Error {
func (r *envelopeReader) Read(env *envelope, src io.Reader) *Error {
prefixes := [5]byte{}
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
if _, err := io.ReadFull(src, prefixes[:]); err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
Expand All @@ -250,7 +300,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
_, err := io.CopyN(io.Discard, src, size)
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeUnknown, "read enveloped message: %w", err)
}
Expand All @@ -259,7 +309,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
if readN, err := io.CopyN(env.Data, src, size); err != nil {
if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil {
// We're reading from an http.MaxBytesHandler, and we've exceeded the read limit.
return maxBytesErr
Expand All @@ -279,3 +329,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
env.Flags = prefixes[0]
return nil
}

func makeEnvelopePrefix(flags uint8, size int) [5]byte {
prefix := [5]byte{}
prefix[0] = flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(size))
return prefix
}
48 changes: 38 additions & 10 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,61 @@ func TestEnvelope_read(t *testing.T) {
head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))

buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("full", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
rdr := envelopeReader{}
src := bytes.NewReader(buf.Bytes())
assert.Nil(t, rdr.Read(env, src))
assert.Equal(t, payload, env.Data.Bytes())
})
t.Run("byteByByte", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
rdr := envelopeReader{}
src := byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Nil(t, rdr.Read(env, src))
assert.Equal(t, payload, env.Data.Bytes())
})
}

func TestEnvelope_write(t *testing.T) {
t.Parallel()

head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))
buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("match", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
wtr := envelopeWriter{}
env := &envelope{Data: bytes.NewBuffer(payload)}
err := wtr.Write(dst, env)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
t.Run("partial", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst, env, 2)
assert.Nil(t, err)
_, err = env.WriteTo(dst)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
}

// byteByByteReader is test reader that reads a single byte at a time.
type byteByByteReader struct {
reader io.ByteReader
Expand Down
3 changes: 1 addition & 2 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,11 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er
response.WriteHeader(http.StatusOK)
marshaler := &connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: response,
bufferPool: w.bufferPool,
},
}
// MarshalEndStream returns *Error: check return value to avoid typed nils.
if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil {
if marshalErr := marshaler.MarshalEndStream(response, err, make(http.Header)); marshalErr != nil {
return marshalErr
}
return nil
Expand Down
Loading

0 comments on commit 4c44379

Please sign in to comment.