diff --git a/envelope.go b/envelope.go index 7452e1a6..e8d13c2e 100644 --- a/envelope.go +++ b/envelope.go @@ -42,14 +42,52 @@ var errSpecialEnvelope = errorf( type envelope struct { Data *bytes.Buffer Flags uint8 + + offset int } func (e *envelope) IsSet(flag uint8) bool { return e.Flags&flag == flag } +// Read implements io.Reader. +func (e *envelope) Read(data []byte) (readN int, err error) { + if e.offset < 5 { + prefix := makeEnvelopePrefix(e.Flags, e.Data.Len()) + readN = copy(data, prefix[e.offset:]) + e.offset += readN + if e.offset < 5 { + return readN, nil + } + data = data[readN:] + } + n := copy(data, e.Data.Bytes()[e.offset-5:]) + e.offset += n + readN += n + if readN == 0 && e.offset == e.Data.Len()+5 { + err = io.EOF + } + return readN, err +} + +// WriteTo implements io.WriterTo. +func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) { + if e.offset < 5 { + prefix := makeEnvelopePrefix(e.Flags, e.Data.Len()) + prefixN, err := dst.Write(prefix[e.offset:]) + e.offset += prefixN + wroteN += int64(prefixN) + if e.offset < 5 { + return wroteN, err + } + } + n, err := dst.Write(e.Data.Bytes()[e.offset-5:]) + e.offset += n + wroteN += int64(n) + return wroteN, err +} + type envelopeWriter struct { - writer io.Writer codec Codec compressMinBytes int compressionPool *compressionPool @@ -57,9 +95,9 @@ type envelopeWriter struct { sendMaxBytes int } -func (w *envelopeWriter) Marshal(message any) *Error { +func (w *envelopeWriter) Marshal(dst io.Writer, message any) *Error { if message == nil { - if _, err := w.writer.Write(nil); err != nil { + if _, err := dst.Write(nil); err != nil { if connectErr, ok := asError(err); ok { return connectErr } @@ -67,44 +105,42 @@ func (w *envelopeWriter) Marshal(message any) *Error { } return nil } - if appender, ok := w.codec.(marshalAppender); ok { - return w.marshalAppend(message, appender) + buffer, err := w.marshal(message) + if err != nil { + return err } - return w.marshal(message) + defer w.bufferPool.Put(buffer) + return w.Write(dst, &envelope{Data: buffer, Flags: 0}) } // Write writes the enveloped message, compressing as necessary. It doesn't // retain any references to the supplied envelope or its underlying data. -func (w *envelopeWriter) Write(env *envelope) *Error { - if env.IsSet(flagEnvelopeCompressed) || - w.compressionPool == nil || - env.Data.Len() < w.compressMinBytes { - if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes { - return errorf(CodeResourceExhausted, "message size %d exceeds sendMaxBytes %d", env.Data.Len(), w.sendMaxBytes) +func (w *envelopeWriter) Write(dst io.Writer, env *envelope) *Error { + if !env.IsSet(flagEnvelopeCompressed) && + w.compressionPool != nil && + env.Data.Len() > w.compressMinBytes { + if err := w.compress(env.Data); err != nil { + return err } - return w.write(env) + env.Flags |= flagEnvelopeCompressed } - data := w.bufferPool.Get() - defer w.bufferPool.Put(data) - if err := w.compressionPool.Compress(data, env.Data); err != nil { - return err - } - if w.sendMaxBytes > 0 && data.Len() > w.sendMaxBytes { - return errorf(CodeResourceExhausted, "compressed message size %d exceeds sendMaxBytes %d", data.Len(), w.sendMaxBytes) + return w.write(dst, env) +} + +func (w *envelopeWriter) marshal(message any) (*bytes.Buffer, *Error) { + if appender, ok := w.codec.(marshalAppender); ok { + return w.marshalAppend(message, appender) } - return w.write(&envelope{ - Data: data, - Flags: env.Flags | flagEnvelopeCompressed, - }) + return w.marshalBase(message) } -func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Error { +func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) (*bytes.Buffer, *Error) { // Codec supports MarshalAppend; try to re-use a []byte from the pool. buffer := w.bufferPool.Get() - defer w.bufferPool.Put(buffer) raw, err := codec.MarshalAppend(buffer.Bytes(), message) if err != nil { - return errorf(CodeInternal, "marshal message: %w", err) + w.bufferPool.Put(buffer) + return nil, errorf(CodeInternal, "marshal message: %w", err) } if cap(raw) > buffer.Cap() { // The buffer from the pool was too small, so MarshalAppend grew the slice. @@ -119,41 +155,55 @@ func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Erro // copies but avoids allocating. buffer.Write(raw) } - envelope := &envelope{Data: buffer} - return w.Write(envelope) + return buffer, nil } -func (w *envelopeWriter) marshal(message any) *Error { +func (w *envelopeWriter) marshalBase(message any) (*bytes.Buffer, *Error) { // Codec doesn't support MarshalAppend; let Marshal allocate a []byte. raw, err := w.codec.Marshal(message) if err != nil { - return errorf(CodeInternal, "marshal message: %w", err) + return nil, errorf(CodeInternal, "marshal message: %w", err) } - buffer := bytes.NewBuffer(raw) - // Put our new []byte into the pool for later reuse. - defer w.bufferPool.Put(buffer) - envelope := &envelope{Data: buffer} - return w.Write(envelope) + return bytes.NewBuffer(raw), nil } -func (w *envelopeWriter) write(env *envelope) *Error { - prefix := [5]byte{} - prefix[0] = env.Flags - binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len())) - if _, err := w.writer.Write(prefix[:]); err != nil { +func (w *envelopeWriter) compress(buffer *bytes.Buffer) *Error { + compressed := w.bufferPool.Get() + defer w.bufferPool.Put(compressed) + if err := w.compressionPool.Compress(compressed, buffer); err != nil { + return err + } + *buffer, *compressed = *compressed, *buffer // Swap buffer contents. + return nil +} + +func (w *envelopeWriter) checkSize(env *envelope) *Error { + if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes { + str := "message" + if env.IsSet(flagEnvelopeCompressed) { + str = "compressed message" + } + return errorf(CodeResourceExhausted, + "%s size %d exceeds sendMaxBytes %d", + str, env.Data.Len(), w.sendMaxBytes) + } + return nil +} + +func (w *envelopeWriter) write(dst io.Writer, env *envelope) *Error { + if err := w.checkSize(env); err != nil { + return err + } + if _, err := env.WriteTo(dst); err != nil { if connectErr, ok := asError(err); ok { return connectErr } - return errorf(CodeUnknown, "write envelope: %w", err) - } - if _, err := io.Copy(w.writer, env.Data); err != nil { return errorf(CodeUnknown, "write message: %w", err) } return nil } type envelopeReader struct { - reader io.Reader codec Codec last envelope compressionPool *compressionPool @@ -161,12 +211,12 @@ type envelopeReader struct { readMaxBytes int } -func (r *envelopeReader) Unmarshal(message any) *Error { +func (r *envelopeReader) Unmarshal(message any, src io.Reader) *Error { buffer := r.bufferPool.Get() defer r.bufferPool.Put(buffer) env := &envelope{Data: buffer} - err := r.Read(env) + err := r.Read(env, src) switch { case err == nil && (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && @@ -200,7 +250,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error { if env.Flags != 0 && env.Flags != flagEnvelopeCompressed { // Drain the rest of the stream to ensure there is no extra data. - if n, err := discard(r.reader); err != nil { + if n, err := discard(src); err != nil { return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err) } else if n > 0 { return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n) @@ -224,11 +274,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error { return nil } -func (r *envelopeReader) Read(env *envelope) *Error { +func (r *envelopeReader) Read(env *envelope, src io.Reader) *Error { prefixes := [5]byte{} // io.ReadFull reads the number of bytes requested, or returns an error. // io.EOF will only be returned if no bytes were read. - if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil { + if _, err := io.ReadFull(src, prefixes[:]); err != nil { if errors.Is(err, io.EOF) { // The stream ended cleanly. That's expected, but we need to propagate an EOF // to the user so that they know that the stream has ended. We shouldn't @@ -250,7 +300,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { } size := int64(binary.BigEndian.Uint32(prefixes[1:5])) if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) { - _, err := io.CopyN(io.Discard, r.reader, size) + _, err := io.CopyN(io.Discard, src, size) if err != nil && !errors.Is(err, io.EOF) { return errorf(CodeUnknown, "read enveloped message: %w", err) } @@ -259,7 +309,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { // We've read the prefix, so we know how many bytes to expect. // CopyN will return an error if it doesn't read the requested // number of bytes. - if readN, err := io.CopyN(env.Data, r.reader, size); err != nil { + if readN, err := io.CopyN(env.Data, src, size); err != nil { if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. return maxBytesErr @@ -279,3 +329,10 @@ func (r *envelopeReader) Read(env *envelope) *Error { env.Flags = prefixes[0] return nil } + +func makeEnvelopePrefix(flags uint8, size int) [5]byte { + prefix := [5]byte{} + prefix[0] = flags + binary.BigEndian.PutUint32(prefix[1:5], uint32(size)) + return prefix +} diff --git a/envelope_test.go b/envelope_test.go index bf187934..68846e36 100644 --- a/envelope_test.go +++ b/envelope_test.go @@ -29,7 +29,6 @@ func TestEnvelope_read(t *testing.T) { head := [5]byte{} payload := []byte(`{"number": 42}`) binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) - buf := &bytes.Buffer{} buf.Write(head[:]) buf.Write(payload) @@ -37,25 +36,54 @@ func TestEnvelope_read(t *testing.T) { t.Run("full", func(t *testing.T) { t.Parallel() env := &envelope{Data: &bytes.Buffer{}} - rdr := envelopeReader{ - reader: bytes.NewReader(buf.Bytes()), - } - assert.Nil(t, rdr.Read(env)) + rdr := envelopeReader{} + src := bytes.NewReader(buf.Bytes()) + assert.Nil(t, rdr.Read(env, src)) assert.Equal(t, payload, env.Data.Bytes()) }) t.Run("byteByByte", func(t *testing.T) { t.Parallel() env := &envelope{Data: &bytes.Buffer{}} - rdr := envelopeReader{ - reader: byteByByteReader{ - reader: bytes.NewReader(buf.Bytes()), - }, + rdr := envelopeReader{} + src := byteByByteReader{ + reader: bytes.NewReader(buf.Bytes()), } - assert.Nil(t, rdr.Read(env)) + assert.Nil(t, rdr.Read(env, src)) assert.Equal(t, payload, env.Data.Bytes()) }) } +func TestEnvelope_write(t *testing.T) { + t.Parallel() + + head := [5]byte{} + payload := []byte(`{"number": 42}`) + binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) + buf := &bytes.Buffer{} + buf.Write(head[:]) + buf.Write(payload) + + t.Run("match", func(t *testing.T) { + t.Parallel() + dst := &bytes.Buffer{} + wtr := envelopeWriter{} + env := &envelope{Data: bytes.NewBuffer(payload)} + err := wtr.Write(dst, env) + assert.Nil(t, err) + assert.Equal(t, buf.Bytes(), dst.Bytes()) + }) + t.Run("partial", func(t *testing.T) { + t.Parallel() + dst := &bytes.Buffer{} + env := &envelope{Data: bytes.NewBuffer(payload)} + _, err := io.CopyN(dst, env, 2) + assert.Nil(t, err) + _, err = env.WriteTo(dst) + assert.Nil(t, err) + assert.Equal(t, buf.Bytes(), dst.Bytes()) + }) +} + // byteByByteReader is test reader that reads a single byte at a time. type byteByByteReader struct { reader io.ByteReader diff --git a/error_writer.go b/error_writer.go index 81d5bb9f..1945e95a 100644 --- a/error_writer.go +++ b/error_writer.go @@ -133,12 +133,11 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er response.WriteHeader(http.StatusOK) marshaler := &connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: response, bufferPool: w.bufferPool, }, } // MarshalEndStream returns *Error: check return value to avoid typed nils. - if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil { + if marshalErr := marshaler.MarshalEndStream(response, err, make(http.Header)); marshalErr != nil { return marshalErr } return nil diff --git a/protocol_connect.go b/protocol_connect.go index bd5500e4..efd597fe 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -252,9 +252,9 @@ func (h *connectHandler) NewConn( spec: h.Spec, peer: peer, request: request, + requestBody: requestBody, responseWriter: responseWriter, marshaler: connectUnaryMarshaler{ - writer: responseWriter, codec: codec, compressMinBytes: h.CompressMinBytes, compressionName: responseCompression, @@ -264,7 +264,6 @@ func (h *connectHandler) NewConn( sendMaxBytes: h.SendMaxBytes, }, unmarshaler: connectUnaryUnmarshaler{ - reader: requestBody, codec: codec, compressionPool: h.CompressionPools.Get(requestCompression), bufferPool: h.BufferPool, @@ -277,10 +276,10 @@ func (h *connectHandler) NewConn( spec: h.Spec, peer: peer, request: request, + requestBody: requestBody, responseWriter: responseWriter, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, codec: codec, compressMinBytes: h.CompressMinBytes, compressionPool: h.CompressionPools.Get(responseCompression), @@ -290,7 +289,6 @@ func (h *connectHandler) NewConn( }, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ - reader: requestBody, codec: codec, compressionPool: h.CompressionPools.Get(requestCompression), bufferPool: h.BufferPool, @@ -375,7 +373,6 @@ func (c *connectClient) NewConn( bufferPool: c.BufferPool, marshaler: connectUnaryRequestMarshaler{ connectUnaryMarshaler: connectUnaryMarshaler{ - writer: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, compressionName: c.CompressionName, @@ -386,7 +383,6 @@ func (c *connectClient) NewConn( }, }, unmarshaler: connectUnaryUnmarshaler{ - reader: duplexCall, codec: c.Codec, bufferPool: c.BufferPool, readMaxBytes: c.ReadMaxBytes, @@ -415,7 +411,6 @@ func (c *connectClient) NewConn( codec: c.Codec, marshaler: connectStreamingMarshaler{ envelopeWriter: envelopeWriter{ - writer: duplexCall, codec: c.Codec, compressMinBytes: c.CompressMinBytes, compressionPool: c.CompressionPools.Get(c.CompressionName), @@ -425,7 +420,6 @@ func (c *connectClient) NewConn( }, unmarshaler: connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ - reader: duplexCall, codec: c.Codec, bufferPool: c.BufferPool, readMaxBytes: c.ReadMaxBytes, @@ -461,7 +455,7 @@ func (cc *connectUnaryClientConn) Peer() Peer { } func (cc *connectUnaryClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + if err := cc.marshaler.Marshal(cc.duplexCall, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -479,7 +473,7 @@ func (cc *connectUnaryClientConn) Receive(msg any) error { if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { return err } - if err := cc.unmarshaler.Unmarshal(msg); err != nil { + if err := cc.unmarshaler.Unmarshal(msg, cc.duplexCall); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -529,12 +523,11 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err return serverErr } else if response.StatusCode != http.StatusOK { unmarshaler := connectUnaryUnmarshaler{ - reader: response.Body, compressionPool: cc.compressionPools.Get(compression), bufferPool: cc.bufferPool, } var wireErr connectWireError - if err := unmarshaler.UnmarshalFunc(&wireErr, json.Unmarshal); err != nil { + if err := unmarshaler.UnmarshalFunc(&wireErr, response.Body, json.Unmarshal); err != nil { return NewError( connectHTTPToCode(response.StatusCode), errors.New(response.Status), @@ -574,7 +567,7 @@ func (cc *connectStreamingClientConn) Peer() Peer { } func (cc *connectStreamingClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + if err := cc.marshaler.Marshal(cc.duplexCall, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -592,7 +585,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { return err } - err := cc.unmarshaler.Unmarshal(msg) + err := cc.unmarshaler.Unmarshal(msg, cc.duplexCall) if err == nil { return nil } @@ -663,6 +656,7 @@ type connectUnaryHandlerConn struct { spec Spec peer Peer request *http.Request + requestBody io.ReadCloser responseWriter http.ResponseWriter marshaler connectUnaryMarshaler unmarshaler connectUnaryUnmarshaler @@ -679,7 +673,7 @@ func (hc *connectUnaryHandlerConn) Peer() Peer { } func (hc *connectUnaryHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { + if err := hc.unmarshaler.Unmarshal(msg, hc.requestBody); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -692,7 +686,7 @@ func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { func (hc *connectUnaryHandlerConn) Send(msg any) error { hc.wroteBody = true hc.writeResponseHeader(nil /* error */) - if err := hc.marshaler.Marshal(msg); err != nil { + if err := hc.marshaler.Marshal(hc.responseWriter, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -760,6 +754,7 @@ type connectStreamingHandlerConn struct { spec Spec peer Peer request *http.Request + requestBody io.ReadCloser responseWriter http.ResponseWriter marshaler connectStreamingMarshaler unmarshaler connectStreamingUnmarshaler @@ -775,7 +770,7 @@ func (hc *connectStreamingHandlerConn) Peer() Peer { } func (hc *connectStreamingHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { + if err := hc.unmarshaler.Unmarshal(msg, hc.requestBody); err != nil { // Clients may not send end-of-stream metadata, so we don't need to handle // errSpecialEnvelope. return err @@ -789,7 +784,7 @@ func (hc *connectStreamingHandlerConn) RequestHeader() http.Header { func (hc *connectStreamingHandlerConn) Send(msg any) error { defer flushResponseWriter(hc.responseWriter) - if err := hc.marshaler.Marshal(msg); err != nil { + if err := hc.marshaler.Marshal(hc.responseWriter, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -805,7 +800,7 @@ func (hc *connectStreamingHandlerConn) ResponseTrailer() http.Header { func (hc *connectStreamingHandlerConn) Close(err error) error { defer flushResponseWriter(hc.responseWriter) - if err := hc.marshaler.MarshalEndStream(err, hc.responseTrailer); err != nil { + if err := hc.marshaler.MarshalEndStream(hc.responseWriter, err, hc.responseTrailer); err != nil { _ = hc.request.Body.Close() return err } @@ -828,7 +823,7 @@ type connectStreamingMarshaler struct { envelopeWriter } -func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Header) *Error { +func (m *connectStreamingMarshaler) MarshalEndStream(dst io.Writer, err error, trailer http.Header) *Error { end := &connectEndStreamMessage{Trailer: trailer} if err != nil { end.Error = newConnectWireError(err) @@ -842,7 +837,7 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea } raw := bytes.NewBuffer(data) defer m.envelopeWriter.bufferPool.Put(raw) - return m.Write(&envelope{ + return m.Write(dst, &envelope{ Data: raw, Flags: connectFlagEnvelopeEndStream, }) @@ -855,8 +850,8 @@ type connectStreamingUnmarshaler struct { trailer http.Header } -func (u *connectStreamingUnmarshaler) Unmarshal(message any) *Error { - err := u.envelopeReader.Unmarshal(message) +func (u *connectStreamingUnmarshaler) Unmarshal(message any, src io.Reader) *Error { + err := u.envelopeReader.Unmarshal(message, src) if err == nil { return nil } @@ -892,7 +887,6 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error { } type connectUnaryMarshaler struct { - writer io.Writer codec Codec compressMinBytes int compressionName string @@ -902,9 +896,9 @@ type connectUnaryMarshaler struct { sendMaxBytes int } -func (m *connectUnaryMarshaler) Marshal(message any) *Error { +func (m *connectUnaryMarshaler) Marshal(dst io.Writer, message any) *Error { if message == nil { - return m.write(nil) + return m.write(dst, nil) } var data []byte var err error @@ -923,7 +917,7 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { if m.sendMaxBytes > 0 && len(data) > m.sendMaxBytes { return NewError(CodeResourceExhausted, fmt.Errorf("message size %d exceeds sendMaxBytes %d", len(data), m.sendMaxBytes)) } - return m.write(data) + return m.write(dst, data) } compressed := m.bufferPool.Get() defer m.bufferPool.Put(compressed) @@ -934,11 +928,11 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { return NewError(CodeResourceExhausted, fmt.Errorf("compressed message size %d exceeds sendMaxBytes %d", compressed.Len(), m.sendMaxBytes)) } setHeaderCanonical(m.header, connectUnaryHeaderCompression, m.compressionName) - return m.write(compressed.Bytes()) + return m.write(dst, compressed.Bytes()) } -func (m *connectUnaryMarshaler) write(data []byte) *Error { - if _, err := m.writer.Write(data); err != nil { +func (m *connectUnaryMarshaler) write(dst io.Writer, data []byte) *Error { + if _, err := dst.Write(data); err != nil { if connectErr, ok := asError(err); ok { return connectErr } @@ -957,19 +951,19 @@ type connectUnaryRequestMarshaler struct { duplexCall *duplexHTTPCall } -func (m *connectUnaryRequestMarshaler) Marshal(message any) *Error { +func (m *connectUnaryRequestMarshaler) Marshal(dst io.Writer, message any) *Error { if m.enableGet { if m.stableCodec == nil && !m.getUseFallback { return errorf(CodeInternal, "codec %s doesn't support stable marshal; can't use get", m.codec.Name()) } if m.stableCodec != nil { - return m.marshalWithGet(message) + return m.marshalWithGet(dst, message) } } - return m.connectUnaryMarshaler.Marshal(message) + return m.connectUnaryMarshaler.Marshal(dst, message) } -func (m *connectUnaryRequestMarshaler) marshalWithGet(message any) *Error { +func (m *connectUnaryRequestMarshaler) marshalWithGet(dst io.Writer, message any) *Error { // TODO(jchadwick-buf): This function is mostly a superset of // connectUnaryMarshaler.Marshal. This should be reconciled at some point. var data []byte @@ -995,7 +989,7 @@ func (m *connectUnaryRequestMarshaler) marshalWithGet(message any) *Error { } if m.compressionPool == nil { if m.getUseFallback { - return m.write(data) + return m.write(dst, data) } return NewError(CodeResourceExhausted, fmt.Errorf( "url size %d exceeds getURLMaxBytes %d: enabling request compression may help", @@ -1021,7 +1015,7 @@ func (m *connectUnaryRequestMarshaler) marshalWithGet(message any) *Error { } if m.getUseFallback { setHeaderCanonical(m.header, connectUnaryHeaderCompression, m.compressionName) - return m.write(compressed.Bytes()) + return m.write(dst, compressed.Bytes()) } return NewError(CodeResourceExhausted, fmt.Errorf("compressed url size %d exceeds getURLMaxBytes %d", len(url.String()), m.getURLMaxBytes)) } @@ -1052,7 +1046,6 @@ func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { } type connectUnaryUnmarshaler struct { - reader io.Reader codec Codec compressionPool *compressionPool bufferPool *bufferPool @@ -1060,20 +1053,20 @@ type connectUnaryUnmarshaler struct { readMaxBytes int } -func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { - return u.UnmarshalFunc(message, u.codec.Unmarshal) +func (u *connectUnaryUnmarshaler) Unmarshal(message any, src io.Reader) *Error { + return u.UnmarshalFunc(message, src, u.codec.Unmarshal) } -func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error { +func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, src io.Reader, unmarshal func([]byte, any) error) *Error { if u.alreadyRead { return NewError(CodeInternal, io.EOF) } u.alreadyRead = true data := u.bufferPool.Get() defer u.bufferPool.Put(data) - reader := u.reader + reader := src if u.readMaxBytes > 0 && int64(u.readMaxBytes) < math.MaxInt64 { - reader = io.LimitReader(u.reader, int64(u.readMaxBytes)+1) + reader = io.LimitReader(reader, int64(u.readMaxBytes)+1) } // ReadFrom ignores io.EOF, so any error here is real. bytesRead, err := data.ReadFrom(reader) @@ -1088,7 +1081,7 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { // Attempt to read to end in order to allow connection re-use - discardedBytes, err := io.Copy(io.Discard, u.reader) + discardedBytes, err := io.Copy(io.Discard, src) if err != nil { return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err) } diff --git a/protocol_connect_test.go b/protocol_connect_test.go index eb5751a4..78cad08d 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -72,10 +72,10 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { assert.Nil(t, err) writer := envelopeWriter{ - writer: &buffer, bufferPool: bufferPool, } - err = writer.Write(&envelope{ + dst := &buffer + err = writer.Write(dst, &envelope{ Flags: connectFlagEnvelopeEndStream, Data: bytes.NewBuffer(endStreamData), }) @@ -83,11 +83,11 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { unmarshaler := connectStreamingUnmarshaler{ envelopeReader: envelopeReader{ - reader: &buffer, bufferPool: bufferPool, }, } - err = unmarshaler.Unmarshal(nil) // parameter won't be used + src := &buffer + err = unmarshaler.Unmarshal(nil /* message unused */, src) assert.ErrorIs(t, err, errSpecialEnvelope) assert.Equal(t, unmarshaler.Trailer().Values("Not-Canonical-Header"), []string{"a"}) assert.Equal(t, unmarshaler.Trailer().Values("Mixed-Canonical"), []string{"b", "b"}) diff --git a/protocol_grpc.go b/protocol_grpc.go index 177e31f2..7076ffd8 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -186,7 +186,6 @@ func (g *grpcHandler) NewConn( protobuf: g.Codecs.Protobuf(), // for errors marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, compressionPool: g.CompressionPools.Get(responseCompression), codec: codec, compressMinBytes: g.CompressMinBytes, @@ -200,7 +199,6 @@ func (g *grpcHandler) NewConn( request: request, unmarshaler: grpcUnmarshaler{ envelopeReader: envelopeReader{ - reader: request.Body, codec: codec, compressionPool: g.CompressionPools.Get(requestCompression), bufferPool: g.BufferPool, @@ -284,7 +282,6 @@ func (g *grpcClient) NewConn( protobuf: g.Protobuf, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: duplexCall, compressionPool: g.CompressionPools.Get(g.CompressionName), codec: g.Codec, compressMinBytes: g.CompressMinBytes, @@ -294,7 +291,6 @@ func (g *grpcClient) NewConn( }, unmarshaler: grpcUnmarshaler{ envelopeReader: envelopeReader{ - reader: duplexCall, codec: g.Codec, bufferPool: g.BufferPool, readMaxBytes: g.ReadMaxBytes, @@ -343,7 +339,7 @@ func (cc *grpcClientConn) Peer() Peer { } func (cc *grpcClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + if err := cc.marshaler.Marshal(cc.duplexCall, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -361,7 +357,7 @@ func (cc *grpcClientConn) Receive(msg any) error { if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { return err } - err := cc.unmarshaler.Unmarshal(msg) + err := cc.unmarshaler.Unmarshal(msg, cc.duplexCall) if err == nil { return nil } @@ -455,7 +451,7 @@ func (hc *grpcHandlerConn) Peer() Peer { } func (hc *grpcHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { + if err := hc.unmarshaler.Unmarshal(msg, hc.request.Body); err != nil { return err // already coded } return nil // must be a literal nil: nil *Error is a non-nil error @@ -471,7 +467,7 @@ func (hc *grpcHandlerConn) Send(msg any) error { mergeHeaders(hc.responseWriter.Header(), hc.responseHeader) hc.wroteToBody = true } - if err := hc.marshaler.Marshal(msg); err != nil { + if err := hc.marshaler.Marshal(hc.responseWriter, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -526,7 +522,7 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) { if hc.web { // We're using gRPC-Web and we've already sent the headers, so we write // trailing metadata to the HTTP body. - if err := hc.marshaler.MarshalWebTrailers(mergedTrailers); err != nil { + if err := hc.marshaler.MarshalWebTrailers(hc.responseWriter, mergedTrailers); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -578,7 +574,7 @@ type grpcMarshaler struct { envelopeWriter } -func (m *grpcMarshaler) MarshalWebTrailers(trailer http.Header) *Error { +func (m *grpcMarshaler) MarshalWebTrailers(dst io.Writer, trailer http.Header) *Error { raw := m.envelopeWriter.bufferPool.Get() defer m.envelopeWriter.bufferPool.Put(raw) for key, values := range trailer { @@ -595,7 +591,7 @@ func (m *grpcMarshaler) MarshalWebTrailers(trailer http.Header) *Error { if err := trailer.Write(raw); err != nil { return errorf(CodeInternal, "format trailers: %w", err) } - return m.Write(&envelope{ + return m.Write(dst, &envelope{ Data: raw, Flags: grpcFlagEnvelopeTrailer, }) @@ -607,8 +603,8 @@ type grpcUnmarshaler struct { webTrailer http.Header } -func (u *grpcUnmarshaler) Unmarshal(message any) *Error { - err := u.envelopeReader.Unmarshal(message) +func (u *grpcUnmarshaler) Unmarshal(message any, src io.Reader) *Error { + err := u.envelopeReader.Unmarshal(message, src) if err == nil { return nil } diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 7dd8e587..1447977f 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -48,7 +48,6 @@ func TestGRPCHandlerSender(t *testing.T) { protobuf: protobufCodec, marshaler: grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, codec: protobufCodec, bufferPool: bufferPool, }, @@ -59,7 +58,6 @@ func TestGRPCHandlerSender(t *testing.T) { request: request, unmarshaler: grpcUnmarshaler{ envelopeReader: envelopeReader{ - reader: request.Body, codec: protobufCodec, bufferPool: bufferPool, }, @@ -181,7 +179,6 @@ func TestGRPCWebTrailerMarshalling(t *testing.T) { responseWriter := httptest.NewRecorder() marshaler := grpcMarshaler{ envelopeWriter: envelopeWriter{ - writer: responseWriter, bufferPool: newBufferPool(), }, } @@ -189,7 +186,7 @@ func TestGRPCWebTrailerMarshalling(t *testing.T) { trailer.Add("grpc-status", "0") trailer.Add("Grpc-Message", "Foo") trailer.Add("User-Provided", "bar") - err := marshaler.MarshalWebTrailers(trailer) + err := marshaler.MarshalWebTrailers(responseWriter, trailer) assert.Nil(t, err) responseWriter.Body.Next(5) // skip flags and message length marshalled := responseWriter.Body.String()