diff --git a/client.go b/client.go index 68327940..d1bfd030 100644 --- a/client.go +++ b/client.go @@ -76,17 +76,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien // once at client creation. unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - - var responseMsg Res - response := &Response[Res]{ - Msg: &responseMsg, - } - if err := client.protocolClient.Invoke(ctx, unarySpec, request, response); err != nil { - return nil, err - } - return response, nil - - /*conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) + conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) conn.onRequestSend(func(r *http.Request) { request.setRequestMethod(r.Method) }) @@ -107,7 +97,7 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien _ = conn.CloseResponse() return nil, err } - return response, conn.CloseResponse()*/ + return response, conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { unaryFunc = interceptor.WrapUnary(unaryFunc) @@ -145,7 +135,7 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo if c.err != nil { return &ClientStreamForClient[Req, Res]{err: c.err} } - return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)} + return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil /* header */, nil /* onRequestSend */)} } // CallServerStream calls a server streaming procedure. @@ -153,12 +143,11 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { + conn := c.newConn(ctx, StreamTypeServer, request.header, func(r *http.Request) { request.method = r.Method }) request.spec = conn.Spec() request.peer = conn.Peer() - mergeHeaders(conn.RequestHeader(), request.header) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -178,12 +167,14 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli if c.err != nil { return &BidiStreamForClient[Req, Res]{err: c.err} } - return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)} + return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil /* header */, nil /* onRequestSend */)} } -func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { +func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, header http.Header, onRequestSend func(r *http.Request)) StreamingClientConn { + if header == nil { + header = make(http.Header, 8) // arbitrary power of two, prevent immediate resizing + } newConn := func(ctx context.Context, spec Spec) StreamingClientConn { - header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing c.protocolClient.WriteRequestHeader(streamType, header) conn := c.protocolClient.NewConn(ctx, spec, header) conn.onRequestSend(onRequestSend) diff --git a/client_ext_test.go b/client_ext_test.go index cb4dede4..739e8aa3 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -109,7 +109,7 @@ func TestClientPeer(t *testing.T) { err = clientStream.Send(&pingv1.SumRequest{}) assert.Nil(t, err) // server streaming - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) t.Cleanup(func() { assert.Nil(t, serverStream.Close()) }) diff --git a/http_call.go b/http_call.go index ed4d896c..b71024fa 100644 --- a/http_call.go +++ b/http_call.go @@ -22,93 +22,16 @@ import ( "io" "net/http" "net/url" - "sync" ) -type unaryHTTPCall struct { - httpClient HTTPClient - onRequestSend func(*http.Request) - validateResponse func(*http.Response) *Error - - request *http.Request - response *http.Response // nil until the request has been sent -} - -func newUnaryHTTPCall( - ctx context.Context, - httpClient HTTPClient, - url *url.URL, - header http.Header, -) *unaryHTTPCall { - request := makeRequest(ctx, url, header, nil) - return &unaryHTTPCall{ - httpClient: httpClient, - request: request, - } -} - -// Do sends the request and waits for the response. -func (u *unaryHTTPCall) Do(body *bytes.Buffer) *Error { - // Body can be nil for GET requests. - if body != nil { - u.request.Body = io.NopCloser(body) - u.request.ContentLength = int64(body.Len()) - - // We need to set the GetBody function so that net/http can re-send the - // request if required. - bodyBytes := body.Bytes() - u.request.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(bytes.NewReader(bodyBytes)), nil - } - } - // Promote the header Host to the request object. - if host := u.request.Header.Get(headerHost); len(host) > 0 { - u.request.Host = host - } - if u.onRequestSend != nil { - u.onRequestSend(u.request) - } - // Complete the request. - response, err := u.httpClient.Do(u.request) - if err != nil { - err = wrapIfContextError(err) - err = wrapIfLikelyH2CNotConfiguredError(u.request, err) - err = wrapIfLikelyWithGRPCNotUsedError(err) - err = wrapIfRSTError(err) - if cerr, ok := asError(err); ok { - return cerr - } - return NewError(CodeUnavailable, err) - } - u.response = response - return u.validateResponse(response) -} - -func (u *unaryHTTPCall) Header() http.Header { - return u.request.Header -} - -func (u *unaryHTTPCall) SetMethod(method string) { - u.request.Method = method -} -func (u *unaryHTTPCall) URL() *url.URL { - return u.request.URL -} - -// SetValidateResponse sets the response validation function. The function runs -// in a background goroutine. -func (u *unaryHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) { - u.validateResponse = validate -} - -// duplexHTTPCall is a full-duplex stream between the client and server. The +// httpCall is a full-duplex stream between the client and server. The // request body is the stream from client to server, and the response body is // the reverse. // // Be warned: we need to use some lesser-known APIs to do this with net/http. -type duplexHTTPCall struct { +type httpCall struct { ctx context.Context - httpClient HTTPClient + client HTTPClient streamType StreamType onRequestSend func(*http.Request) validateResponse func(*http.Response) *Error @@ -116,66 +39,75 @@ type duplexHTTPCall struct { // We'll use a pipe as the request body. We hand the read side of the pipe to // net/http, and we write to the write side (naturally). The two ends are // safe to use concurrently. - requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter - sendRequestOnce sync.Once - responseReady chan struct{} - request *http.Request - response *http.Response + requestSent bool + request *http.Request - errMu sync.Mutex - err error + responseReady chan struct{} + response *http.Response + responseErr error } -func newDuplexHTTPCall( +func newHTTPCall( ctx context.Context, - httpClient HTTPClient, + client HTTPClient, url *url.URL, spec Spec, header http.Header, -) *duplexHTTPCall { - pipeReader, pipeWriter := io.Pipe() - request := makeRequest(ctx, url, header, pipeReader) - return &duplexHTTPCall{ - ctx: ctx, - httpClient: httpClient, - streamType: spec.StreamType, - requestBodyReader: pipeReader, - requestBodyWriter: pipeWriter, - request: request, - responseReady: make(chan struct{}), +) *httpCall { + // ensure we make a copy of the url before we pass along to the + // Request. This ensures if a transport out of our control wants + // to mutate the req.URL, we don't feel the effects of it. + url = cloneURL(url) + // This is mirroring what http.NewRequestContext did, but + // using an already parsed url.URL object, rather than a string + // and parsing it again. This is a bit funny with HTTP/1.1 + // explicitly, but this is logic copied over from + // NewRequestContext and doesn't effect the actual version + // being transmitted. + request := (&http.Request{ + Method: http.MethodPost, + URL: url, + Header: header, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: http.NoBody, + Host: url.Host, + }).WithContext(ctx) + return &httpCall{ + ctx: ctx, + client: client, + streamType: spec.StreamType, + request: request, + responseReady: make(chan struct{}), } } -// Write to the request body. Returns an error wrapping io.EOF after SetError -// is called. -func (d *duplexHTTPCall) Write(data []byte) (int, error) { - d.ensureRequestMade() +// Send the request headers and body. If the streamType is not client streaming, +// this method blocks until the response headers are received. +func (c *httpCall) Send(buffer *bytes.Buffer) error { // Before we send any data, check if the context has been canceled. - if err := d.ctx.Err(); err != nil { - d.SetError(err) - return 0, wrapIfContextError(err) + if err := c.ctx.Err(); err != nil { + return wrapIfContextError(err) } - // It's safe to write to this side of the pipe while net/http concurrently - // reads from the other side. - bytesWritten, err := d.requestBodyWriter.Write(data) - if err != nil && errors.Is(err, io.ErrClosedPipe) { - // Signal that the stream is closed with the more-typical io.EOF instead of - // io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to - // match grpc-go's behavior. - return bytesWritten, io.EOF + if c.isClientStream() { + return c.sendStream(buffer) } - return bytesWritten, err + return c.sendUnary(buffer) } // Close the request body. Callers *must* call CloseWrite before Read when // using HTTP/1.x. -func (d *duplexHTTPCall) CloseWrite() error { - // Even if Write was never called, we need to make an HTTP request. This +func (c *httpCall) CloseWrite() error { + // Even if Send was never called, we need to make an HTTP request. This // ensures that we've sent any headers to the server and that we have an HTTP // response to read from. - d.ensureRequestMade() + if !c.requestSent { + c.requestSent = true + go c.makeRequest() + } // The user calls CloseWrite to indicate that they're done sending data. It's // safe to close the write side of the pipe while net/http is reading from // it. @@ -187,179 +119,193 @@ func (d *duplexHTTPCall) CloseWrite() error { // forever. To make sure users don't have to worry about this, the generated // code for unary, client streaming, and server streaming RPCs must call // CloseWrite automatically rather than requiring the user to do it. - return d.requestBodyWriter.Close() + if c.requestBodyWriter != nil { + return c.requestBodyWriter.Close() + } + return nil } // Header returns the HTTP request headers. -func (d *duplexHTTPCall) Header() http.Header { - return d.request.Header -} - -// Trailer returns the HTTP request trailers. -func (d *duplexHTTPCall) Trailer() http.Header { - return d.request.Trailer +func (c *httpCall) Header() http.Header { + return c.request.Header } // URL returns the URL for the request. -func (d *duplexHTTPCall) URL() *url.URL { - return d.request.URL +func (c *httpCall) URL() *url.URL { + return c.request.URL } // SetMethod changes the method of the request before it is sent. -func (d *duplexHTTPCall) SetMethod(method string) { - d.request.Method = method +func (c *httpCall) SetMethod(method string) { + c.request.Method = method } -// Read from the response body. Returns the first error passed to SetError. -func (d *duplexHTTPCall) Read(data []byte) (int, error) { +// Read from the response body. +func (c *httpCall) Read(data []byte) (int, error) { // First, we wait until we've gotten the response headers and established the // server-to-client side of the stream. - d.BlockUntilResponseReady() - if err := d.getError(); err != nil { - // The stream is already closed or corrupted. + if err := c.BlockUntilResponseReady(); err != nil { return 0, err } // Before we read, check if the context has been canceled. - if err := d.ctx.Err(); err != nil { - d.SetError(err) + if err := c.ctx.Err(); err != nil { return 0, wrapIfContextError(err) } - if d.response == nil { - return 0, fmt.Errorf("nil response from %v", d.request.URL) - } - n, err := d.response.Body.Read(data) + n, err := c.response.Body.Read(data) return n, wrapIfRSTError(err) } -func (d *duplexHTTPCall) CloseRead() error { - d.BlockUntilResponseReady() - if d.response == nil { +func (c *httpCall) CloseRead() error { + _ = c.BlockUntilResponseReady() + if c.response == nil { return nil } - if _, err := discard(d.response.Body); err != nil { - _ = d.response.Body.Close() + if _, err := discard(c.response.Body); err != nil { + _ = c.response.Body.Close() return wrapIfRSTError(err) } - return wrapIfRSTError(d.response.Body.Close()) + return wrapIfRSTError(c.response.Body.Close()) } // ResponseStatusCode is the response's HTTP status code. -func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { - d.BlockUntilResponseReady() - if d.response == nil { - return 0, fmt.Errorf("nil response from %v", d.request.URL) +func (c *httpCall) ResponseStatusCode() (int, error) { + if err := c.BlockUntilResponseReady(); err != nil { + return 0, err } - return d.response.StatusCode, nil + return c.response.StatusCode, nil } // ResponseHeader returns the response HTTP headers. -func (d *duplexHTTPCall) ResponseHeader() http.Header { - d.BlockUntilResponseReady() - if d.response != nil { - return d.response.Header +func (c *httpCall) ResponseHeader() http.Header { + _ = c.BlockUntilResponseReady() + if c.response != nil { + return c.response.Header } return make(http.Header) } // ResponseTrailer returns the response HTTP trailers. -func (d *duplexHTTPCall) ResponseTrailer() http.Header { - d.BlockUntilResponseReady() - if d.response != nil { - return d.response.Trailer +func (c *httpCall) ResponseTrailer() http.Header { + _ = c.BlockUntilResponseReady() + if c.response != nil { + return c.response.Trailer } return make(http.Header) } -// SetError stores any error encountered processing the response. All -// subsequent calls to Read return this error, and all subsequent calls to -// Write return an error wrapping io.EOF. It's safe to call concurrently with -// any other method. -func (d *duplexHTTPCall) SetError(err error) { - d.errMu.Lock() - if d.err == nil { - d.err = wrapIfContextError(err) - } - // Closing the read side of the request body pipe acquires an internal lock, - // so we want to scope errMu's usage narrowly and avoid defer. - d.errMu.Unlock() +// SetValidateResponse sets the response validation function. The function runs +// in a background goroutine. +func (c *httpCall) SetValidateResponse(validate func(*http.Response) *Error) { + c.validateResponse = validate +} - // We've already hit an error, so we should stop writing to the request body. - // It's safe to call Close more than once and/or concurrently (calls after - // the first are no-ops), so it's okay for us to call this even though - // net/http sometimes closes the reader too. - // - // It's safe to ignore the returned error here. Under the hood, Close calls - // CloseWithError, which is documented to always return nil. - _ = d.requestBodyReader.Close() +func (c *httpCall) BlockUntilResponseReady() error { + <-c.responseReady + return c.responseErr } -// SetValidateResponse sets the response validation function. The function runs -// in a background goroutine. -func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) { - d.validateResponse = validate +func (c *httpCall) isClientStream() bool { + return c.streamType&StreamTypeClient != 0 } -func (d *duplexHTTPCall) BlockUntilResponseReady() { - <-d.responseReady +func (c *httpCall) sendUnary(buffer *bytes.Buffer) error { + if c.requestSent { + return fmt.Errorf("unary call already sent") + } + c.requestSent = true + // Client request is unary, build the full request body and block on + // sending the request. + if c.request.Method != http.MethodGet && buffer != nil { + payload := buffer.Bytes() + c.request.Body = io.NopCloser(buffer) + c.request.ContentLength = int64(buffer.Len()) + c.request.GetBody = func() (io.ReadCloser, error) { + buffer = bytes.NewBuffer(payload) + return io.NopCloser(buffer), nil + } + } + c.makeRequest() // blocks until the response is ready + return nil // Only report response errors on Read } -func (d *duplexHTTPCall) ensureRequestMade() { - d.sendRequestOnce.Do(func() { - go d.makeRequest() - }) +func (c *httpCall) sendStream(buffer *bytes.Buffer) error { + if !c.requestSent { + c.requestSent = true + // Client request is streaming, so we need to start sending the request + // before we start writing to the request body. This ensures that we've + // sent any headers to the server. + pipeReader, pipeWriter := io.Pipe() + c.requestBodyWriter = pipeWriter + c.request.Body = pipeReader + c.request.ContentLength = -1 + go c.makeRequest() // concurrent request + } + // It's safe to write to this side of the pipe while net/http concurrently + // reads from the other side. + _, err := c.requestBodyWriter.Write(buffer.Bytes()) + if err != nil && errors.Is(err, io.ErrClosedPipe) { + // Signal that the stream is closed with the more-typical io.EOF instead of + // io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to + // match grpc-go's behavior. + return io.EOF + } + return err } -func (d *duplexHTTPCall) makeRequest() { +func (c *httpCall) makeRequest() { // This runs concurrently with Write and CloseWrite. Read and CloseRead wait // on d.responseReady, so we can't race with them. - defer close(d.responseReady) + defer close(c.responseReady) // Promote the header Host to the request object. - if host := d.request.Header.Get(headerHost); len(host) > 0 { - d.request.Host = host + if host := c.request.Header.Get(headerHost); len(host) > 0 { + c.request.Host = host } - if d.onRequestSend != nil { - d.onRequestSend(d.request) + if c.onRequestSend != nil { + c.onRequestSend(c.request) } // Once we send a message to the server, they send a message back and // establish the receive side of the stream. - response, err := d.httpClient.Do(d.request) + response, err := c.client.Do(c.request) if err != nil { err = wrapIfContextError(err) - err = wrapIfLikelyH2CNotConfiguredError(d.request, err) + err = wrapIfLikelyH2CNotConfiguredError(c.request, err) err = wrapIfLikelyWithGRPCNotUsedError(err) err = wrapIfRSTError(err) if _, ok := asError(err); !ok { err = NewError(CodeUnavailable, err) } - d.SetError(err) + c.responseErr = err + if c.requestBodyWriter != nil { + c.requestBodyWriter.Close() + } return } - d.response = response - if err := d.validateResponse(response); err != nil { - d.SetError(err) + c.response = response + if err := c.validateResponse(response); err != nil { + c.responseErr = err + if c.requestBodyWriter != nil { + c.requestBodyWriter.Close() + } return } - if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { + if (c.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { // If we somehow dialed an HTTP/1.x server, fail with an explicit message // rather than returning a more cryptic error later on. - d.SetError(errorf( + c.responseErr = errorf( CodeUnimplemented, "response from %v is HTTP/%d.%d: bidi streams require at least HTTP/2", - d.request.URL, + c.request.URL, response.ProtoMajor, response.ProtoMinor, - )) + ) + if c.requestBodyWriter != nil { + c.requestBodyWriter.Close() + } } } -func (d *duplexHTTPCall) getError() error { - d.errMu.Lock() - defer d.errMu.Unlock() - return d.err -} - // See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33 func cloneURL(oldURL *url.URL) *url.URL { if oldURL == nil { @@ -373,27 +319,3 @@ func cloneURL(oldURL *url.URL) *url.URL { } return newURL } - -func makeRequest(ctx context.Context, url *url.URL, header http.Header, body io.ReadCloser) *http.Request { - // ensure we make a copy of the url before we pass along to the - // Request. This ensures if a transport out of our control wants - // to mutate the req.URL, we don't feel the effects of it. - url = cloneURL(url) - // This is mirroring what http.NewRequestContext did, but - // using an already parsed url.URL object, rather than a string - // and parsing it again. This is a bit funny with HTTP/1.1 - // explicitly, but this is logic copied over from - // NewRequestContext and doesn't effect the actual version - // being transmitted. - return (&http.Request{ - Method: http.MethodPost, - URL: url, - Header: header, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Body: body, - Host: url.Host, - }).WithContext(ctx) - -} diff --git a/protocol.go b/protocol.go index a2aedfa5..cc6455cd 100644 --- a/protocol.go +++ b/protocol.go @@ -147,9 +147,6 @@ type protocolClient interface { // unary call, implementations may assume that the Sender's Send and Close // methods return before the Receiver's Receive or Close methods are called. NewConn(context.Context, Spec, http.Header) streamingClientConn - - // Invoke a unary RPC. - Invoke(context.Context, Spec, AnyRequest, AnyResponse) error } // streamingClientConn extends StreamingClientConn with a method for registering diff --git a/protocol_connect.go b/protocol_connect.go index 9db9151e..86ca14c8 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -348,7 +348,11 @@ func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.He } } -func (c *connectClient) encodeDeadlineFromContext(ctx context.Context, header http.Header) { +func (c *connectClient) NewConn( + ctx context.Context, + spec Spec, + header http.Header, +) streamingClientConn { if deadline, ok := ctx.Deadline(); ok { millis := int64(time.Until(deadline) / time.Millisecond) if millis > 0 { @@ -358,123 +362,82 @@ func (c *connectClient) encodeDeadlineFromContext(ctx context.Context, header ht } // else effectively unbounded } } -} - -func (c *connectClient) NewConn( - ctx context.Context, - spec Spec, - header http.Header, -) streamingClientConn { - c.encodeDeadlineFromContext(ctx, header) - duplexCall := newDuplexHTTPCall(ctx, c.HTTPClient, c.URL, spec, header) - streamingConn := &connectStreamingClientConn{ - spec: spec, - peer: c.Peer(), - duplexCall: duplexCall, - compressionPools: c.CompressionPools, - bufferPool: c.BufferPool, - codec: c.Codec, - marshaler: connectStreamingMarshaler{ - envelopeWriter: envelopeWriter{ - codec: c.Codec, - compressMinBytes: c.CompressMinBytes, - compressionPool: c.CompressionPools.Get(c.CompressionName), - bufferPool: c.BufferPool, - sendMaxBytes: c.SendMaxBytes, + call := newHTTPCall(ctx, c.HTTPClient, c.URL, spec, header) + var conn streamingClientConn + if spec.StreamType == StreamTypeUnary { + unaryConn := &connectUnaryClientConn{ + spec: spec, + peer: c.Peer(), + call: call, + compressionPools: c.CompressionPools, + bufferPool: c.BufferPool, + marshaler: connectUnaryRequestMarshaler{ + connectUnaryMarshaler: connectUnaryMarshaler{ + codec: c.Codec, + compressMinBytes: c.CompressMinBytes, + compressionName: c.CompressionName, + compressionPool: c.CompressionPools.Get(c.CompressionName), + bufferPool: c.BufferPool, + header: call.Header(), + sendMaxBytes: c.SendMaxBytes, + }, }, - }, - unmarshaler: connectStreamingUnmarshaler{ - envelopeReader: envelopeReader{ + unmarshaler: connectUnaryUnmarshaler{ codec: c.Codec, bufferPool: c.BufferPool, readMaxBytes: c.ReadMaxBytes, }, - }, - responseHeader: make(http.Header), - responseTrailer: make(http.Header), - } - duplexCall.SetValidateResponse(streamingConn.validateResponse) - return wrapClientConnWithCodedErrors(streamingConn) -} - -func (c *connectClient) Invoke( - ctx context.Context, spec Spec, request AnyRequest, response AnyResponse, -) error { - header := request.Header() - c.encodeDeadlineFromContext(ctx, header) - - unaryCall := newUnaryHTTPCall(ctx, c.HTTPClient, c.URL, header) - conn := &connectUnaryClientConn{ - spec: spec, - peer: c.Peer(), - unaryCall: unaryCall, - compressionPools: c.CompressionPools, - bufferPool: c.BufferPool, - marshaler: connectUnaryRequestMarshaler{ - connectUnaryMarshaler: connectUnaryMarshaler{ - codec: c.Codec, - compressMinBytes: c.CompressMinBytes, - compressionName: c.CompressionName, - compressionPool: c.CompressionPools.Get(c.CompressionName), - bufferPool: c.BufferPool, - header: header, - sendMaxBytes: c.SendMaxBytes, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), + } + if spec.IdempotencyLevel == IdempotencyNoSideEffects { + unaryConn.marshaler.enableGet = c.EnableGet + unaryConn.marshaler.getURLMaxBytes = c.GetURLMaxBytes + unaryConn.marshaler.getUseFallback = c.GetUseFallback + unaryConn.marshaler.call = call + if stableCodec, ok := c.Codec.(stableCodec); ok { + unaryConn.marshaler.stableCodec = stableCodec + } + } + conn = unaryConn + call.SetValidateResponse(unaryConn.validateResponse) + } else { + streamingConn := &connectStreamingClientConn{ + spec: spec, + peer: c.Peer(), + call: call, + compressionPools: c.CompressionPools, + bufferPool: c.BufferPool, + codec: c.Codec, + marshaler: connectStreamingMarshaler{ + envelopeWriter: envelopeWriter{ + codec: c.Codec, + compressMinBytes: c.CompressMinBytes, + compressionPool: c.CompressionPools.Get(c.CompressionName), + bufferPool: c.BufferPool, + sendMaxBytes: c.SendMaxBytes, + }, }, - unaryCall: unaryCall, - }, - unmarshaler: connectUnaryUnmarshaler{ - codec: c.Codec, - bufferPool: c.BufferPool, - readMaxBytes: c.ReadMaxBytes, - }, - responseHeader: make(http.Header), - responseTrailer: make(http.Header), - } - if spec.IdempotencyLevel == IdempotencyNoSideEffects { - conn.marshaler.enableGet = c.EnableGet - conn.marshaler.getURLMaxBytes = c.GetURLMaxBytes - conn.marshaler.getUseFallback = c.GetUseFallback - if stableCodec, ok := c.Codec.(stableCodec); ok { - conn.marshaler.stableCodec = stableCodec + unmarshaler: connectStreamingUnmarshaler{ + envelopeReader: envelopeReader{ + codec: c.Codec, + bufferPool: c.BufferPool, + readMaxBytes: c.ReadMaxBytes, + }, + }, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), } + conn = streamingConn + call.SetValidateResponse(streamingConn.validateResponse) } - unaryCall.SetValidateResponse(conn.validateResponse) - - conn.onRequestSend(func(r *http.Request) { - request.setRequestMethod(r.Method) - }) - - // Send always returns an io.EOF unless the error is from the client-side. - // We want the user to continue to call Receive in those cases to get the - // full error from the server-side. - if err := conn.Send(request.Any()); err != nil && !errors.Is(err, io.EOF) { - _ = conn.CloseRequest() - _ = conn.CloseResponse() - return err - } - if err := conn.CloseRequest(); err != nil { - _ = conn.CloseResponse() - return err - } - - if err := conn.Receive(response.Any()); err != nil { - return err - } - // In a well-formed stream, the response message may be followed by a block - // of in-stream trailers or HTTP trailers. To ensure that we receive the - // trailers, try to read another message from the stream. - if err := conn.Receive(nil); err == nil { - return NewError(CodeUnknown, errors.New("unary stream has multiple messages")) - } else if err != nil && !errors.Is(err, io.EOF) { - return NewError(CodeUnknown, err) - } - return conn.CloseResponse() + return wrapClientConnWithCodedErrors(conn) } type connectUnaryClientConn struct { spec Spec peer Peer - unaryCall *unaryHTTPCall + call *httpCall compressionPools readOnlyCompressionPools bufferPool *bufferPool marshaler connectUnaryRequestMarshaler @@ -492,45 +455,57 @@ func (cc *connectUnaryClientConn) Peer() Peer { } func (cc *connectUnaryClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + buffer := cc.bufferPool.Get() + defer cc.bufferPool.Put(buffer) + if err := cc.marshaler.Marshal(buffer, msg); err != nil { return err } + if err := cc.call.Send(buffer); err != nil { + if cerr, ok := asError(err); ok { + return cerr + } + return errorf(CodeUnknown, "send message: %w", err) + } return nil // must be a literal nil: nil *Error is a non-nil error } func (cc *connectUnaryClientConn) RequestHeader() http.Header { - return cc.unaryCall.Header() + return cc.call.Header() } func (cc *connectUnaryClientConn) CloseRequest() error { - return nil // nop for unary requests. + return cc.call.CloseWrite() // nop for unary } func (cc *connectUnaryClientConn) Receive(msg any) error { - defer cc.unaryCall.response.Body.Close() - if err := cc.unmarshaler.Unmarshal(msg, cc.unaryCall.response.Body); err != nil { + if err := cc.call.BlockUntilResponseReady(); err != nil { + return err + } + if err := cc.unmarshaler.Unmarshal(msg, cc.call.response.Body); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error } func (cc *connectUnaryClientConn) ResponseHeader() http.Header { + _ = cc.call.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectUnaryClientConn) ResponseTrailer() http.Header { + _ = cc.call.BlockUntilResponseReady() return cc.responseTrailer } func (cc *connectUnaryClientConn) CloseResponse() error { - if response := cc.unaryCall.response; response != nil { + if response := cc.call.response; response != nil { return response.Body.Close() } return nil } func (cc *connectUnaryClientConn) onRequestSend(fn func(*http.Request)) { - cc.unaryCall.onRequestSend = fn + cc.call.onRequestSend = fn } func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Error { @@ -584,7 +559,7 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err type connectStreamingClientConn struct { spec Spec peer Peer - duplexCall *duplexHTTPCall + call *httpCall compressionPools readOnlyCompressionPools bufferPool *bufferPool codec Codec @@ -603,24 +578,27 @@ func (cc *connectStreamingClientConn) Peer() Peer { } func (cc *connectStreamingClientConn) Send(msg any) error { - dst := cc.duplexCall - if err := cc.marshaler.Marshal(dst, msg); err != nil { + buffer := cc.bufferPool.Get() + defer cc.bufferPool.Put(buffer) + if err := cc.marshaler.Marshal(buffer, msg); err != nil { return err } - return nil // must be a literal nil: nil *Error is a non-nil error + return cc.call.Send(buffer) } func (cc *connectStreamingClientConn) RequestHeader() http.Header { - return cc.duplexCall.Header() + return cc.call.Header() } func (cc *connectStreamingClientConn) CloseRequest() error { - return cc.duplexCall.CloseWrite() + return cc.call.CloseWrite() } func (cc *connectStreamingClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() - src := cc.duplexCall + if err := cc.call.BlockUntilResponseReady(); err != nil { + return err + } + src := cc.call err := cc.unmarshaler.Unmarshal(msg, src) if err == nil { return nil @@ -634,7 +612,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) + _ = cc.call.CloseWrite() return serverErr } // If the error is EOF but not from a last message, we want to return @@ -645,27 +623,27 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // There's no error in the trailers, so this was probably an error // converting the bytes to a message, an error reading from the network, or // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // close the writer so Send errors out. + _ = cc.call.CloseWrite() return err } func (cc *connectStreamingClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.call.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectStreamingClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.call.BlockUntilResponseReady() return cc.responseTrailer } func (cc *connectStreamingClientConn) CloseResponse() error { - return cc.duplexCall.CloseRead() + return cc.call.CloseRead() } func (cc *connectStreamingClientConn) onRequestSend(fn func(*http.Request)) { - cc.duplexCall.onRequestSend = fn + cc.call.onRequestSend = fn } func (cc *connectStreamingClientConn) validateResponse(response *http.Response) *Error { @@ -988,12 +966,10 @@ type connectUnaryRequestMarshaler struct { getURLMaxBytes int getUseFallback bool stableCodec stableCodec - unaryCall *unaryHTTPCall + call *httpCall } -func (m *connectUnaryRequestMarshaler) Marshal(message any) *Error { - buffer := m.bufferPool.Get() - defer m.bufferPool.Put(buffer) +func (m *connectUnaryRequestMarshaler) Marshal(buffer *bytes.Buffer, message any) *Error { if m.enableGet { if m.stableCodec == nil && !m.getUseFallback { return errorf(CodeInternal, "codec %s doesn't support stable marshal; can't use get", m.codec.Name()) @@ -1001,15 +977,8 @@ func (m *connectUnaryRequestMarshaler) Marshal(message any) *Error { if err := m.marshalWithGet(buffer, message); err != nil { return err } - } else { - if err := m.connectUnaryMarshaler.Marshal(buffer, message); err != nil { - return err - } } - if err := m.unaryCall.Do(buffer); err != nil { - return err - } - return nil + return m.connectUnaryMarshaler.Marshal(buffer, message) } func (m *connectUnaryRequestMarshaler) marshalWithGet(dst io.Writer, message any) *Error { @@ -1070,7 +1039,7 @@ func (m *connectUnaryRequestMarshaler) marshalWithGet(dst io.Writer, message any } func (m *connectUnaryRequestMarshaler) buildGetURL(data []byte, compressed bool) *url.URL { - url := *m.unaryCall.URL() + url := *m.call.URL() query := url.Query() query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) query.Set(connectUnaryEncodingQueryParameter, m.codec.Name()) @@ -1089,8 +1058,8 @@ func (m *connectUnaryRequestMarshaler) buildGetURL(data []byte, compressed bool) func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { delete(m.header, connectHeaderProtocolVersion) - m.unaryCall.SetMethod(http.MethodGet) - *m.unaryCall.URL() = *url + m.call.SetMethod(http.MethodGet) + *m.call.URL() = *url return nil } diff --git a/protocol_grpc.go b/protocol_grpc.go index 25a43d64..c64fd23c 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -266,7 +266,7 @@ func (g *grpcClient) NewConn( encodedDeadline := grpcEncodeTimeout(time.Until(deadline)) header[grpcHeaderTimeout] = []string{encodedDeadline} } - duplexCall := newDuplexHTTPCall( + call := newHTTPCall( ctx, g.HTTPClient, g.URL, @@ -276,7 +276,7 @@ func (g *grpcClient) NewConn( conn := &grpcClientConn{ spec: spec, peer: g.Peer(), - duplexCall: duplexCall, + call: call, compressionPools: g.CompressionPools, bufferPool: g.BufferPool, protobuf: g.Protobuf, @@ -299,14 +299,14 @@ func (g *grpcClient) NewConn( responseHeader: make(http.Header), responseTrailer: make(http.Header), } - duplexCall.SetValidateResponse(conn.validateResponse) + call.SetValidateResponse(conn.validateResponse) if g.web { conn.unmarshaler.web = true - conn.readTrailers = func(unmarshaler *grpcUnmarshaler, _ *duplexHTTPCall) http.Header { + conn.readTrailers = func(unmarshaler *grpcUnmarshaler, _ *httpCall) http.Header { return unmarshaler.WebTrailer() } } else { - conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { + conn.readTrailers = func(_ *grpcUnmarshaler, call *httpCall) http.Header { // To access HTTP trailers, we need to read the body to EOF. _, _ = discard(call) return call.ResponseTrailer() @@ -315,56 +315,11 @@ func (g *grpcClient) NewConn( return wrapClientConnWithCodedErrors(conn) } -func (g *grpcClient) Invoke( - ctx context.Context, spec Spec, request AnyRequest, response AnyResponse, -) error { - // TODO: should we have a grpcUnaryClient struct? - unaryCall := newUnaryHTTPCall( - ctx, g.HTTPClient, g.URL, request.Header(), - ) - marshaler := grpcMarshaler{ - envelopeWriter: envelopeWriter{ - compressionPool: g.CompressionPools.Get(g.CompressionName), - codec: g.Codec, - compressMinBytes: g.CompressMinBytes, - bufferPool: g.BufferPool, - sendMaxBytes: g.SendMaxBytes, - }, - } - unmarshaler := grpcUnmarshaler{ - envelopeReader: envelopeReader{ - codec: g.Codec, - bufferPool: g.BufferPool, - readMaxBytes: g.ReadMaxBytes, - }, - } - var ( - responseHeader http.Header - responseTrailer http.Header - ) - unaryCall.SetValidateResponse(func(response *http.Response) *Error { - if err := grpcValidateResponse( - response, - responseHeader, - responseTrailer, - g.CompressionPools, - g.Protobuf, - ); err != nil { - return err - } - compression := getHeaderCanonical(response.Header, grpcHeaderCompression) - unmarshaler.envelopeReader.compressionPool = g.CompressionPools.Get(compression) - return nil - }) - - return nil -} - // grpcClientConn works for both gRPC and gRPC-Web. type grpcClientConn struct { spec Spec peer Peer - duplexCall *duplexHTTPCall + call *httpCall compressionPools readOnlyCompressionPools bufferPool *bufferPool protobuf Codec // for errors @@ -372,7 +327,7 @@ type grpcClientConn struct { unmarshaler grpcUnmarshaler responseHeader http.Header responseTrailer http.Header - readTrailers func(*grpcUnmarshaler, *duplexHTTPCall) http.Header + readTrailers func(*grpcUnmarshaler, *httpCall) http.Header } func (cc *grpcClientConn) Spec() Spec { @@ -384,24 +339,27 @@ func (cc *grpcClientConn) Peer() Peer { } func (cc *grpcClientConn) Send(msg any) error { - dst := cc.duplexCall - if err := cc.marshaler.Marshal(dst, msg); err != nil { + buffer := cc.bufferPool.Get() + defer cc.bufferPool.Put(buffer) + if err := cc.marshaler.Marshal(buffer, msg); err != nil { return err } - return nil // must be a literal nil: nil *Error is a non-nil error + return cc.call.Send(buffer) } func (cc *grpcClientConn) RequestHeader() http.Header { - return cc.duplexCall.Header() + return cc.call.Header() } func (cc *grpcClientConn) CloseRequest() error { - return cc.duplexCall.CloseWrite() + return cc.call.CloseWrite() } func (cc *grpcClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() - src := cc.duplexCall + if err := cc.call.BlockUntilResponseReady(); err != nil { + return err + } + src := cc.call err := cc.unmarshaler.Unmarshal(msg, src) if err == nil { return nil @@ -415,7 +373,7 @@ func (cc *grpcClientConn) Receive(msg any) error { // See if the server sent an explicit error in the HTTP or gRPC-Web trailers. mergeHeaders( cc.responseTrailer, - cc.readTrailers(&cc.unmarshaler, cc.duplexCall), + cc.readTrailers(&cc.unmarshaler, cc.call), ) serverErr := grpcErrorFromTrailer(cc.protobuf, cc.responseTrailer) if serverErr != nil && (errors.Is(err, io.EOF) || !errors.Is(serverErr, errTrailersWithoutGRPCStatus)) { @@ -429,32 +387,32 @@ func (cc *grpcClientConn) Receive(msg any) error { // the stream has ended, Receive must return an error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) + _ = cc.call.CloseWrite() return serverErr } // This was probably an error converting the bytes to a message or an error // reading from the network. We're going to return it to the - // user, but we also want to setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // user, but we also want to close writes so Send errors out. + _ = cc.call.CloseWrite() return err } func (cc *grpcClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.call.BlockUntilResponseReady() return cc.responseHeader } func (cc *grpcClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.call.BlockUntilResponseReady() return cc.responseTrailer } func (cc *grpcClientConn) CloseResponse() error { - return cc.duplexCall.CloseRead() + return cc.call.CloseRead() } func (cc *grpcClientConn) onRequestSend(fn func(*http.Request)) { - cc.duplexCall.onRequestSend = fn + cc.call.onRequestSend = fn } func (cc *grpcClientConn) validateResponse(response *http.Response) *Error {