diff --git a/.golangci.yml b/.golangci.yml index 0ad344b1..ce0de60a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -123,3 +123,10 @@ issues: - linters: [revive] text: "^if-return: " path: error_writer.go + # We want to set http.Server's logger + - linters: [forbidigo] + path: internal/memhttp + text: "use of `log.(New|Logger|Lshortfile)` forbidden by pattern .*" + # We want to show examples with http.Get + - linters: [noctx] + path: internal/memhttp/memhttp_test.go diff --git a/client_example_test.go b/client_example_test.go index 56359017..c85e8d44 100644 --- a/client_example_test.go +++ b/client_example_test.go @@ -27,9 +27,8 @@ import ( func Example_client() { logger := log.New(os.Stdout, "" /* prefix */, 0 /* flags */) - // Unfortunately, pkg.go.dev can't run examples that actually use the - // network. To keep this example runnable, we'll use an HTTP server and - // client that communicate over in-memory pipes. The client is still a plain + // To keep this example runnable, we'll use an HTTP server and client + // that communicate over in-memory pipes. The client is still a plain // *http.Client! var httpClient *http.Client = examplePingServer.Client() diff --git a/client_ext_test.go b/client_ext_test.go index cb4dede4..ce799958 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -18,7 +18,6 @@ import ( "context" "errors" "net/http" - "net/http/httptest" "strings" "testing" @@ -26,6 +25,7 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestNewClient_InitFailure(t *testing.T) { @@ -75,55 +75,56 @@ func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { t.Helper() client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithClientOptions(opts...), connect.WithInterceptors(&assertPeerInterceptor{t}), ) ctx := context.Background() - // unary - unaryReq := connect.NewRequest[pingv1.PingRequest](nil) - _, err := client.Ping(ctx, unaryReq) - assert.Nil(t, err) - assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod()) - text := strings.Repeat(".", 256) - r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) - assert.Nil(t, err) - assert.Equal(t, r.Msg.Text, text) - // client streaming - clientStream := client.Sum(ctx) - t.Cleanup(func() { - _, closeErr := clientStream.CloseAndReceive() - assert.Nil(t, closeErr) + t.Run("unary", func(t *testing.T) { + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.Nil(t, err) + assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod()) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) }) - assert.NotZero(t, clientStream.Peer().Addr) - assert.NotZero(t, clientStream.Peer().Protocol) - err = clientStream.Send(&pingv1.SumRequest{}) - assert.Nil(t, err) - // server streaming - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) - t.Cleanup(func() { - assert.Nil(t, serverStream.Close()) + t.Run("client_stream", func(t *testing.T) { + clientStream := client.Sum(ctx) + t.Cleanup(func() { + _, closeErr := clientStream.CloseAndReceive() + assert.Nil(t, closeErr) + }) + assert.NotZero(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Protocol) + err := clientStream.Send(&pingv1.SumRequest{}) + assert.Nil(t, err) }) - assert.Nil(t, err) - // bidi streaming - bidiStream := client.CumSum(ctx) - t.Cleanup(func() { - assert.Nil(t, bidiStream.CloseRequest()) - assert.Nil(t, bidiStream.CloseResponse()) + t.Run("server_stream", func(t *testing.T) { + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + t.Cleanup(func() { + assert.Nil(t, serverStream.Close()) + }) + assert.Nil(t, err) + }) + t.Run("bidi_stream", func(t *testing.T) { + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotZero(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Protocol) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) }) - assert.NotZero(t, bidiStream.Peer().Addr) - assert.NotZero(t, bidiStream.Peer().Protocol) - err = bidiStream.Send(&pingv1.CumSumRequest{}) - assert.Nil(t, err) } t.Run("connect", func(t *testing.T) { @@ -157,14 +158,10 @@ func TestGetNotModified(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(¬ModifiedPingServer{etag: etag})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithHTTPGet(), ) ctx := context.Background() diff --git a/client_get_fallback_test.go b/client_get_fallback_test.go index a8076ffd..c9444ef6 100644 --- a/client_get_fallback_test.go +++ b/client_get_fallback_test.go @@ -17,12 +17,12 @@ package connect import ( "context" "net/http" - "net/http/httptest" "strings" "testing" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestClientUnaryGetFallback(t *testing.T) { @@ -38,14 +38,11 @@ func TestClientUnaryGetFallback(t *testing.T) { }, WithIdempotency(IdempotencyNoSideEffects), )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) client := NewClient[pingv1.PingRequest, pingv1.PingResponse]( server.Client(), - server.URL+"/connect.ping.v1.PingService/Ping", + server.URL()+"/connect.ping.v1.PingService/Ping", WithHTTPGet(), WithHTTPGetMaxURLSize(1, true), WithSendGzip(), diff --git a/compression_test.go b/compression_test.go index 5ae53e6b..7db457ad 100644 --- a/compression_test.go +++ b/compression_test.go @@ -17,10 +17,10 @@ package connect import ( "context" "net/http" - "net/http/httptest" "testing" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/types/known/emptypb" ) @@ -42,12 +42,10 @@ func TestAcceptEncodingOrdering(t *testing.T) { w.WriteHeader(http.StatusOK) called = true }) - server := httptest.NewServer(verify) - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, verify) client := NewClient[emptypb.Empty, emptypb.Empty]( server.Client(), - server.URL, + server.URL(), withFakeBrotli, withGzip(), ) diff --git a/connect_ext_test.go b/connect_ext_test.go index 9e75253e..af541259 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -26,7 +26,6 @@ import ( "math" "math/rand" "net/http" - "net/http/httptest" "strings" "sync" "testing" @@ -37,6 +36,8 @@ import ( "connectrpc.com/connect/internal/gen/connect/import/v1/importv1connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoregistry" ) @@ -167,8 +168,10 @@ func TestServer(t *testing.T) { assert.Equal(t, got, expect) }) t.Run("count_up_error", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) stream, err := client.CountUp( - context.Background(), + ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1}), ) assert.Nil(t, err) @@ -180,14 +183,28 @@ func TestServer(t *testing.T) { connect.CodeOf(stream.Err()), connect.CodeInvalidArgument, ) + assert.Nil(t, stream.Close()) }) t.Run("count_up_timeout", func(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) - defer cancel() + t.Cleanup(cancel) _, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) + t.Run("count_up_cancel_after_first_response", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + request := connect.NewRequest(&pingv1.CountUpRequest{Number: 5}) + request.Header().Set(clientHeader, headerValue) + stream, err := client.CountUp(ctx, request) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + cancel() + assert.False(t, stream.Receive()) + assert.NotNil(t, stream.Err()) + assert.Equal(t, connect.CodeOf(stream.Err()), connect.CodeCanceled) + assert.Nil(t, stream.Close()) + }) } testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { @@ -282,10 +299,16 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled) assert.Equal(t, got, expect) assert.False(t, connect.IsWireError(err)) + assert.Nil(t, stream.CloseResponse()) }) t.Run("cumsum_cancel_before_send", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) stream := client.CumSum(ctx) + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + cancel() + return + } stream.RequestHeader().Set(clientHeader, headerValue) assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8})) cancel() @@ -294,6 +317,8 @@ func TestServer(t *testing.T) { err := stream.Send(&pingv1.CumSumRequest{Number: 19}) assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled, assert.Sprintf("%v", err)) assert.False(t, connect.IsWireError(err)) + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) }) } testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper @@ -344,10 +369,10 @@ func TestServer(t *testing.T) { assertIsHTTPMiddlewareError(t, stream.Err()) }) } - testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { //nolint:thelper + testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) + client := pingv1connect.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -428,17 +453,15 @@ func TestServer(t *testing.T) { t.Run("http1", func(t *testing.T) { t.Parallel() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - testMatrix(t, server, false /* bidi */) + server := memhttptest.NewServer(t, mux) + client := &http.Client{Transport: server.TransportHTTP1()} + testMatrix(t, client, server.URL(), false /* bidi */) }) t.Run("http2", func(t *testing.T) { t.Parallel() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - testMatrix(t, server, true /* bidi */) + server := memhttptest.NewServer(t, mux) + client := server.Client() + testMatrix(t, client, server.URL(), true /* bidi */) }) } @@ -449,17 +472,14 @@ func TestConcurrentStreams(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) for i := 0; i < 100; i++ { done.Add(1) go func() { defer done.Done() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -510,10 +530,9 @@ func TestHeaderBasic(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -536,14 +555,11 @@ func TestHeaderHost(t *testing.T) { }, } - newHTTP2Server := func(t *testing.T) *httptest.Server { + newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) return server } @@ -560,21 +576,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -594,12 +610,11 @@ func TestTimeoutParsing(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -607,11 +622,10 @@ func TestTimeoutParsing(t *testing.T) { func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithCodec(failCodec{}), ) stream := client.CumSum(context.Background()) @@ -625,11 +639,10 @@ func TestFailCodec(t *testing.T) { func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -650,14 +663,11 @@ func TestGRPCMarshalStatusError(t *testing.T) { pingServer{}, connect.WithCodec(failCodec{}), )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -692,11 +702,8 @@ func TestGRPCMissingTrailersError(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler( pingServer{checkMetadata: true}, )) - server := httptest.NewUnstartedServer(trimTrailers(mux)) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + server := memhttptest.NewServer(t, trimTrailers(mux)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -778,14 +785,16 @@ func TestBidiRequiresHTTP2(t *testing.T) { _, err := io.WriteString(w, "hello world") assert.Nil(t, err) }) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL, + &http.Client{Transport: server.TransportHTTP1()}, + server.URL(), ) stream := client.CumSum(context.Background()) - assert.Nil(t, stream.Send(&pingv1.CumSumRequest{})) + // Stream creates an async request, can error on Send or Receive. + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(t, err, io.EOF) + } assert.Nil(t, stream.CloseRequest()) _, err := stream.Receive() assert.NotNil(t, err) @@ -806,11 +815,10 @@ func TestCompressMinBytesClient(t *testing.T) { mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) - server := httptest.NewServer(mux) - tb.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) _, err := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithSendGzip(), connect.WithCompressMinBytes(8), ).Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Text: text})) @@ -842,10 +850,7 @@ func TestCompressMinBytes(t *testing.T) { pingServer{}, connect.WithCompressMinBytes(8), )) - server := httptest.NewServer(mux) - t.Cleanup(func() { - server.Close() - }) + server := memhttptest.NewServer(t, mux) client := server.Client() getPingResponse := func(t *testing.T, pingText string) *http.Response { @@ -856,7 +861,7 @@ func TestCompressMinBytes(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", bytes.NewReader(requestBytes), ) assert.Nil(t, err) @@ -899,11 +904,9 @@ func TestCustomCompression(t *testing.T) { pingServer{}, connect.WithCompression(compressionName, decompressor, compressor), )) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), ) @@ -920,11 +923,9 @@ func TestClientWithoutGzipSupport(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), ) @@ -939,16 +940,13 @@ func TestInvalidHeaderTimeout(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(func() { - server.Close() - }) + server := memhttptest.NewServer(t, mux) getPingResponseWithTimeout := func(t *testing.T, timeout string) *http.Response { t.Helper() request, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -975,9 +973,8 @@ func TestInterceptorReturnsWrongType(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1052,48 +1049,45 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } - newHTTP2Server := func(t *testing.T) *httptest.Server { + newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) return server } t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1141,55 +1135,47 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) }) } - newHTTP2Server := func(t *testing.T) *httptest.Server { - t.Helper() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - return server - } t.Run("connect", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } func TestClientWithReadMaxBytes(t *testing.T) { t.Parallel() - createServer := func(tb testing.TB, enableCompression bool) *httptest.Server { + createServer := func(tb testing.TB, enableCompression bool) *memhttp.Server { tb.Helper() mux := http.NewServeMux() var compressionOption connect.HandlerOption @@ -1199,10 +1185,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - tb.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) @@ -1250,32 +1233,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1334,7 +1317,7 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { } }) } - newHTTP2Server := func(t *testing.T, compressed bool, sendMaxBytes int) *httptest.Server { + newHTTP2Server := func(t *testing.T, compressed bool, sendMaxBytes int) *memhttp.Server { t.Helper() mux := http.NewServeMux() options := []connect.HandlerOption{connect.WithSendMaxBytes(sendMaxBytes)} @@ -1347,46 +1330,43 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { pingServer{}, options..., )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) return server } t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1395,10 +1375,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { @@ -1450,37 +1427,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1498,14 +1475,10 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithClientOptions(opts...), connect.WithInterceptors(&assertPeerInterceptor{t}), ) @@ -1537,26 +1510,23 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingServer := func(pingServer pingv1connect.PingServiceHandler) (pingv1connect.PingServiceClient, *httptest.Server) { + newPingClient := func(pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) - return client, server + return client } t.Run("not-proto-message", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Conn().Send("foobar") }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) _, err := stream.Receive() @@ -1566,12 +1536,11 @@ func TestStreamForServer(t *testing.T) { }) t.Run("nil-message", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Send(nil) }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) _, err := stream.Receive() @@ -1581,7 +1550,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("get-spec", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) @@ -1589,14 +1558,13 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) assert.Nil(t, stream.CloseRequest()) }) t.Run("server-stream", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) @@ -1605,7 +1573,6 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.NotNil(t, stream) @@ -1613,13 +1580,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.True(t, stream.Receive()) @@ -1630,7 +1596,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send-nil", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { stream.ResponseHeader().Set("foo", "bar") stream.ResponseTrailer().Set("bas", "blah") @@ -1638,7 +1604,6 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.False(t, stream.Receive()) @@ -1652,7 +1617,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) @@ -1664,7 +1629,6 @@ func TestStreamForServer(t *testing.T) { return connect.NewResponse(&pingv1.SumResponse{Sum: 1}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1674,13 +1638,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-conn", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.NotNil(t, stream.Conn().Send("not-proto")) return connect.NewResponse(&pingv1.SumResponse{}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1689,13 +1652,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-send-msg", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1716,12 +1678,11 @@ func TestConnectHTTPErrorCodes(t *testing.T) { }, } mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -1730,7 +1691,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -1821,13 +1782,10 @@ func TestFailCompression(t *testing.T) { connect.WithCompression(compressorName, decompressor, compressor), ), ) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) pingclient := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), connect.WithSendCompression(compressorName), ) @@ -1859,10 +1817,7 @@ func TestUnflushableResponseWriter(t *testing.T) { handler.ServeHTTP(&unflushableWriter{w}, r) }) mux.Handle(path, wrapped) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) tests := []struct { name string @@ -1876,7 +1831,7 @@ func TestUnflushableResponseWriter(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, tt.options...) + pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -1895,10 +1850,7 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) protoBytes, err := proto.Marshal(&pingv1.FailRequest{Code: int32(connect.CodeInternal)}) assert.Nil(t, err) @@ -1911,7 +1863,7 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingv1connect.PingServiceFailProcedure, + server.URL()+pingv1connect.PingServiceFailProcedure, bytes.NewReader(body), ) assert.Nil(t, err) @@ -1935,12 +1887,9 @@ func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -1959,8 +1908,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { pingServer{}, connect.WithRequireConnectProtocolHeader(), )) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) tests := []struct { headers http.Header @@ -1972,7 +1920,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -1999,8 +1947,7 @@ func TestAllowCustomUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) // If the user has set a User-Agent, we shouldn't clobber it. tests := []struct { @@ -2012,7 +1959,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, testCase.opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -2036,10 +1983,9 @@ func TestWebXUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -2049,14 +1995,17 @@ func TestBidiOverHTTP1(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient( + &http.Client{Transport: server.TransportHTTP1()}, + server.URL(), + ) stream := client.CumSum(context.Background()) + // Stream creates an async request, can error on Send or Receive. if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil { assert.ErrorIs(t, err, io.EOF) } @@ -2095,9 +2044,8 @@ func TestHandlerReturnsNilResponse(t *testing.T) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2126,10 +2074,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, _ = io.Copy(io.Discard, request.Body) testcase(responseWriter, request) }) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) head := [5]byte{} payload := []byte(`{"number": 42}`) @@ -2309,7 +2254,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { t.Parallel() client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), testcase.options..., ) const upTo = 2 @@ -2415,6 +2360,7 @@ func (p *pluggablePingServer) CumSum( func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { assert.ErrorIs(tb, err, io.EOF) assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) diff --git a/duplex_http_call.go b/duplex_http_call.go index ab1a6db4..7181dd65 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -42,13 +42,16 @@ type duplexHTTPCall struct { requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter + // sendRequestOnce ensures we only send the request once. sendRequestOnce sync.Once - responseReady chan struct{} request *http.Request - response *http.Response - errMu sync.Mutex - err error + // responseReady is closed when the response is ready or when the request + // fails. Any error on request initialisation will be set on the + // responseErr. There's always a response if responseErr is nil. + responseReady chan struct{} + response *http.Response + responseErr error } func newDuplexHTTPCall( @@ -91,13 +94,11 @@ func newDuplexHTTPCall( } } -// Write to the request body. Returns an error wrapping io.EOF after SetError -// is called. +// Write to the request body. func (d *duplexHTTPCall) Write(data []byte) (int, error) { d.ensureRequestMade() // Before we send any data, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } // It's safe to write to this side of the pipe while net/http concurrently @@ -157,14 +158,12 @@ func (d *duplexHTTPCall) SetMethod(method string) { func (d *duplexHTTPCall) Read(data []byte) (int, error) { // First, we wait until we've gotten the response headers and established the // server-to-client side of the stream. - d.BlockUntilResponseReady() - if err := d.getError(); err != nil { + if err := d.BlockUntilResponseReady(); err != nil { // The stream is already closed or corrupted. return 0, err } // Before we read, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } if d.response == nil { @@ -175,11 +174,13 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) { } func (d *duplexHTTPCall) CloseRead() error { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response == nil { return nil } - if _, err := discard(d.response.Body); err != nil { + if _, err := discard(d.response.Body); err != nil && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { _ = d.response.Body.Close() return wrapIfRSTError(err) } @@ -188,16 +189,15 @@ func (d *duplexHTTPCall) CloseRead() error { // ResponseStatusCode is the response's HTTP status code. func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { - d.BlockUntilResponseReady() - if d.response == nil { - return 0, fmt.Errorf("nil response from %v", d.request.URL) + if err := d.BlockUntilResponseReady(); err != nil { + return 0, err } return d.response.StatusCode, nil } // ResponseHeader returns the response HTTP headers. func (d *duplexHTTPCall) ResponseHeader() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Header } @@ -206,46 +206,29 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header { // ResponseTrailer returns the response HTTP trailers. func (d *duplexHTTPCall) ResponseTrailer() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Trailer } return make(http.Header) } -// SetError stores any error encountered processing the response. All -// subsequent calls to Read return this error, and all subsequent calls to -// Write return an error wrapping io.EOF. It's safe to call concurrently with -// any other method. -func (d *duplexHTTPCall) SetError(err error) { - d.errMu.Lock() - if d.err == nil { - d.err = wrapIfContextError(err) - } - // Closing the read side of the request body pipe acquires an internal lock, - // so we want to scope errMu's usage narrowly and avoid defer. - d.errMu.Unlock() - - // We've already hit an error, so we should stop writing to the request body. - // It's safe to call Close more than once and/or concurrently (calls after - // the first are no-ops), so it's okay for us to call this even though - // net/http sometimes closes the reader too. - // - // It's safe to ignore the returned error here. Under the hood, Close calls - // CloseWithError, which is documented to always return nil. - _ = d.requestBodyReader.Close() -} - // SetValidateResponse sets the response validation function. The function runs // in a background goroutine. func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) { d.validateResponse = validate } -func (d *duplexHTTPCall) BlockUntilResponseReady() { +// BlockUntilResponseReady returns when the response is ready or reports an +// error from initializing the request. +func (d *duplexHTTPCall) BlockUntilResponseReady() error { <-d.responseReady + return d.responseErr } +// ensureRequestMade sends the request headers and starts the response stream. +// It is not safe to call this concurrently. Write and CloseWrite call this but +// ensure that they're not called concurrently. func (d *duplexHTTPCall) ensureRequestMade() { d.sendRequestOnce.Do(func() { go d.makeRequest() @@ -267,6 +250,9 @@ func (d *duplexHTTPCall) makeRequest() { } // Once we send a message to the server, they send a message back and // establish the receive side of the stream. + // On error, we close the request body using the Write side of the pipe. + // This ensures HTTP2 streams receive an io.EOF from the Read side of the + // pipe. Write's check for io.ErrClosedPipe and will convert this to io.EOF. response, err := d.httpClient.Do(d.request) //nolint:bodyclose if err != nil { err = wrapIfContextError(err) @@ -276,33 +262,32 @@ func (d *duplexHTTPCall) makeRequest() { if _, ok := asError(err); !ok { err = NewError(CodeUnavailable, err) } - d.SetError(err) + d.responseErr = err + d.requestBodyWriter.Close() return } + // We've got a response. We can now read from the response body. + // Closing the response body is delegated to the caller even on error. d.response = response if err := d.validateResponse(response); err != nil { - d.SetError(err) + d.responseErr = err + d.requestBodyWriter.Close() return } if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { // If we somehow dialed an HTTP/1.x server, fail with an explicit message // rather than returning a more cryptic error later on. - d.SetError(errorf( + d.responseErr = errorf( CodeUnimplemented, "response from %v is HTTP/%d.%d: bidi streams require at least HTTP/2", d.request.URL, response.ProtoMajor, response.ProtoMinor, - )) + ) + d.requestBodyWriter.Close() } } -func (d *duplexHTTPCall) getError() error { - d.errMu.Lock() - defer d.errMu.Unlock() - return d.err -} - // See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33 func cloneURL(oldURL *url.URL) *url.URL { if oldURL == nil { diff --git a/example_init_test.go b/example_init_test.go index ab7e1563..9edf5114 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -15,130 +15,23 @@ package connect_test import ( - "context" - "errors" - "net" "net/http" - "net/http/httptest" - "sync" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" ) -var examplePingServer *inMemoryServer +var examplePingServer *memhttp.Server func init() { - // Generally, init functions are bad. + // Generally, init functions are bad. However, we need to set up the server + // before the examples run. // // To write testable examples that users can grok *and* can execute in the - // playground, where networking is disabled, we need an HTTP server that uses - // in-memory pipes instead of TCP. We don't want to pollute every example - // with this setup code. - // - // The least-awful option is to set up the server in init(). + // playground we use an in memory pipe as network based playgrounds can + // deadlock, see: + // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - examplePingServer = newInMemoryServer(mux) -} - -// inMemoryServer is an HTTP server that uses in-memory pipes instead of TCP. -// It supports HTTP/2 and has TLS enabled. -// -// The Go Playground panics if we try to start a TCP-backed server. If you're -// not familiar with the Playground's behavior, it looks like our examples are -// broken. This server lets us write examples that work in the playground -// without abstracting over HTTP. -type inMemoryServer struct { - server *httptest.Server - listener *memoryListener + examplePingServer = memhttp.NewServer(mux) } - -// newInMemoryServer constructs and starts an inMemoryServer. -func newInMemoryServer(handler http.Handler) *inMemoryServer { - lis := &memoryListener{ - conns: make(chan net.Conn), - closed: make(chan struct{}), - } - server := httptest.NewUnstartedServer(handler) - server.Listener = lis - server.EnableHTTP2 = true - server.StartTLS() - return &inMemoryServer{ - server: server, - listener: lis, - } -} - -// Client returns an HTTP client configured to trust the server's TLS -// certificate and use HTTP/2 over an in-memory pipe. Automatic HTTP-level gzip -// compression is disabled. It closes its idle connections when the server is -// closed. -func (s *inMemoryServer) Client() *http.Client { - client := s.server.Client() - if transport, ok := client.Transport.(*http.Transport); ok { - transport.DialContext = s.listener.DialContext - transport.DisableCompression = true - } - return client -} - -// URL is the server's URL. -func (s *inMemoryServer) URL() string { - return s.server.URL -} - -// Close shuts down the server, blocking until all outstanding requests have -// completed. -func (s *inMemoryServer) Close() { - s.server.Close() -} - -type memoryListener struct { - conns chan net.Conn - once sync.Once - closed chan struct{} -} - -// Accept implements net.Listener. -func (l *memoryListener) Accept() (net.Conn, error) { - select { - case conn := <-l.conns: - return conn, nil - case <-l.closed: - return nil, errors.New("listener closed") - } -} - -// Close implements net.Listener. -func (l *memoryListener) Close() error { - l.once.Do(func() { - close(l.closed) - }) - return nil -} - -// Addr implements net.Listener. -func (l *memoryListener) Addr() net.Addr { - return &memoryAddr{} -} - -// DialContext is the type expected by http.Transport.DialContext. -func (l *memoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-l.closed: - return nil, errors.New("listener closed") - default: - } - server, client := net.Pipe() - l.conns <- server - return client, nil -} - -type memoryAddr struct{} - -// Network implements net.Addr. -func (*memoryAddr) Network() string { return "memory" } - -// String implements io.Stringer, returning a value that matches the -// certificates used by net/http/httptest. -func (*memoryAddr) String() string { return "example.com" } diff --git a/go.mod b/go.mod index 0bb3ca93..a1da35ba 100644 --- a/go.mod +++ b/go.mod @@ -11,3 +11,8 @@ require ( github.com/google/go-cmp v0.5.9 google.golang.org/protobuf v1.31.0 ) + +require ( + golang.org/x/net v0.16.0 // indirect + golang.org/x/text v0.13.0 // indirect +) diff --git a/go.sum b/go.sum index 4d0bc04e..8d2bee48 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= +golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= diff --git a/handler_ext_test.go b/handler_ext_test.go index 4aeb78f0..25cde595 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -21,7 +21,6 @@ import ( "encoding/json" "io" "net/http" - "net/http/httptest" "strings" "sync" "testing" @@ -30,6 +29,7 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestHandler_ServeHTTP(t *testing.T) { @@ -42,18 +42,15 @@ func TestHandler_ServeHTTP(t *testing.T) { mux.Handle("/prefixed/", http.StripPrefix("/prefixed", prefixed)) const pingProcedure = pingv1connect.PingServicePingProcedure const sumProcedure = pingv1connect.PingServiceSumProcedure - server := httptest.NewServer(mux) + server := memhttptest.NewServer(t, mux) client := server.Client() - t.Cleanup(func() { - server.Close() - }) t.Run("get_method_no_encoding", func(t *testing.T) { t.Parallel() request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader(""), ) assert.Nil(t, err) @@ -68,7 +65,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure+`?encoding=unk&message={}`, + server.URL()+pingProcedure+`?encoding=unk&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -83,7 +80,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure+`?encoding=json&message={}`, + server.URL()+pingProcedure+`?encoding=json&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -98,7 +95,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+"/prefixed"+pingProcedure+`?encoding=json&message={}`, + server.URL()+"/prefixed"+pingProcedure+`?encoding=json&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -113,7 +110,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+sumProcedure, + server.URL()+sumProcedure, strings.NewReader(""), ) assert.Nil(t, err) @@ -129,7 +126,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -158,7 +155,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -174,7 +171,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -190,7 +187,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -217,8 +214,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(successPingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) const ( concurrency = 256 @@ -234,7 +230,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingv1connect.PingServicePingProcedure, + server.URL()+pingv1connect.PingServicePingProcedure, bytes.NewReader(body), ) assert.Nil(t, err) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 1e87ac82..a671904b 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "net/http" - "net/http/httptest" "sync/atomic" "testing" @@ -26,6 +25,7 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestOnionOrderingEndToEnd(t *testing.T) { @@ -127,12 +127,10 @@ func TestOnionOrderingEndToEnd(t *testing.T) { handlerOnion, ), ) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), clientOnion, ) @@ -174,9 +172,8 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { } }) mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(interceptor)) + server := memhttptest.NewServer(t, mux) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) sumStream := connectClient.Sum(context.Background()) @@ -204,12 +201,10 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { connect.WithInterceptors(handlerChecker), ), ) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithInterceptors(clientChecker), ) diff --git a/internal/memhttp/listener.go b/internal/memhttp/listener.go new file mode 100644 index 00000000..adec7519 --- /dev/null +++ b/internal/memhttp/listener.go @@ -0,0 +1,94 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp + +import ( + "context" + "errors" + "net" + "sync" +) + +var ( + errListenerClosed = errors.New("listener closed") +) + +// memoryListener is a net.Listener that listens on an in memory network. +type memoryListener struct { + addr memoryAddr + + conns chan net.Conn + once sync.Once + closed chan struct{} +} + +// newMemoryListener returns a new in-memory listener. +func newMemoryListener(addr string) *memoryListener { + return &memoryListener{ + addr: memoryAddr(addr), + conns: make(chan net.Conn), + closed: make(chan struct{}), + } +} + +// Accept implements net.Listener. +func (l *memoryListener) Accept() (net.Conn, error) { + select { + case <-l.closed: + return nil, &net.OpError{ + Op: "accept", + Net: l.addr.Network(), + Addr: l.addr, + Err: errListenerClosed, + } + case server := <-l.conns: + return server, nil + } +} + +// Close implements net.Listener. +func (l *memoryListener) Close() error { + l.once.Do(func() { + close(l.closed) + }) + return nil +} + +// Addr implements net.Listener. +func (l *memoryListener) Addr() net.Addr { + return l.addr +} + +// DialContext is the type expected by http.Transport.DialContext. +func (l *memoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + server, client := net.Pipe() + select { + case <-ctx.Done(): + return nil, &net.OpError{Op: "dial", Net: l.addr.Network(), Err: ctx.Err()} + case l.conns <- server: + return client, nil + case <-l.closed: + return nil, &net.OpError{Op: "dial", Net: l.addr.Network(), Err: errListenerClosed} + } +} + +type memoryAddr string + +// Network implements net.Addr. +func (memoryAddr) Network() string { return "memory" } + +// String implements io.Stringer, returning a value that matches the +// certificates used by net/http/httptest. +func (a memoryAddr) String() string { return string(a) } diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go new file mode 100644 index 00000000..67de9935 --- /dev/null +++ b/internal/memhttp/memhttp.go @@ -0,0 +1,150 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "sync" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// Server is a net/http server that uses in-memory pipes instead of TCP. By +// default, it supports http/2 via h2c. It otherwise uses the same configuration +// as the zero value of [http.Server]. +type Server struct { + server http.Server + listener *memoryListener + url string + cleanupTimeout time.Duration + + serverWG sync.WaitGroup + serverErr error +} + +// NewServer creates a new Server that uses the given handler. Configuration +// options may be provided via [Option]s. +func NewServer(handler http.Handler, opts ...Option) *Server { + var cfg config + WithCleanupTimeout(5 * time.Second).apply(&cfg) + for _, opt := range opts { + opt.apply(&cfg) + } + + h2s := &http2.Server{} + handler = h2c.NewHandler(handler, h2s) + listener := newMemoryListener("1.2.3.4") // httptest.DefaultRemoteAddr + server := &Server{ + server: http.Server{ + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + }, + listener: listener, + url: "http://" + listener.Addr().String(), + cleanupTimeout: cfg.CleanupTimeout, + } + server.serverWG.Add(1) + go func() { + defer server.serverWG.Done() + server.serverErr = server.server.Serve(server.listener) + }() + return server +} + +// Transport returns a [http2.Transport] configured to use in-memory pipes +// rather than TCP and speak both HTTP/1.1 and HTTP/2. +// +// Callers may reconfigure the returned transport without affecting other transports. +func (s *Server) Transport() *http2.Transport { + return &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return s.listener.DialContext(ctx, network, addr) + }, + AllowHTTP: true, + } +} + +// TransportHTTP1 returns a [http.Transport] configured to use in-memory pipes +// rather than TCP and speak HTTP/1.1. +// +// Callers may reconfigure the returned transport without affecting other transports. +func (s *Server) TransportHTTP1() *http.Transport { + return &http.Transport{ + DialContext: s.listener.DialContext, + // TODO(emcfarlane): DisableKeepAlives false can causes tests + // to hang on shutdown. + DisableKeepAlives: true, + } +} + +// Client returns an [http.Client] configured to use in-memory pipes rather +// than TCP and speak HTTP/2. It is configured to use the same +// [http2.Transport] as [Transport]. +// +// Callers may reconfigure the returned client without affecting other clients. +func (s *Server) Client() *http.Client { + return &http.Client{Transport: s.Transport()} +} + +// URL returns the server's URL. +func (s *Server) URL() string { + return s.url +} + +// Shutdown gracefully shuts down the server, without interrupting any active +// connections. See [http.Server.Shutdown] for details. +func (s *Server) Shutdown(ctx context.Context) error { + if err := s.server.Shutdown(ctx); err != nil { + return err + } + return s.Wait() +} + +// Cleanup calls shutdown with a background context set with the cleanup timeout. +// The default timeout duration is 5 seconds. +func (s *Server) Cleanup() error { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, s.cleanupTimeout) + defer cancel() + return s.Shutdown(ctx) +} + +// Close closes the server's listener. It does not wait for connections to +// finish. +func (s *Server) Close() error { + return s.server.Close() +} + +// RegisterOnShutdown registers a function to call on Shutdown. See +// [http.Server.RegisterOnShutdown] for details. +func (s *Server) RegisterOnShutdown(f func()) { + s.server.RegisterOnShutdown(f) +} + +// Wait blocks until the server exits, then returns an error if not +// a [http.ErrServerClosed] error. +func (s *Server) Wait() error { + s.serverWG.Wait() + if !errors.Is(s.serverErr, http.ErrServerClosed) { + return s.serverErr + } + return nil +} diff --git a/internal/memhttp/memhttp_test.go b/internal/memhttp/memhttp_test.go new file mode 100644 index 00000000..06e11ee4 --- /dev/null +++ b/internal/memhttp/memhttp_test.go @@ -0,0 +1,140 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp_test + +import ( + "context" + "fmt" + "io" + "net/http" + "sync" + "testing" + "time" + + "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" +) + +func TestServerTransport(t *testing.T) { + t.Parallel() + const concurrency = 100 + const greeting = "Hello, world!" + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(greeting)) + }) + server := memhttptest.NewServer(t, handler) + + for _, transport := range []http.RoundTripper{ + server.Transport(), + server.TransportHTTP1(), + } { + client := &http.Client{Transport: transport} + t.Run(fmt.Sprintf("%T", transport), func(t *testing.T) { + t.Parallel() + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodGet, + server.URL(), + nil, + ) + assert.Nil(t, err) + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, res.StatusCode, http.StatusOK) + body, err := io.ReadAll(res.Body) + assert.Nil(t, err) + assert.Nil(t, res.Body.Close()) + assert.Equal(t, string(body), greeting) + }() + } + wg.Wait() + }) + } +} + +func TestRegisterOnShutdown(t *testing.T) { + t.Parallel() + okay := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + server := memhttp.NewServer(okay) + done := make(chan struct{}) + server.RegisterOnShutdown(func() { + close(done) + }) + assert.Nil(t, server.Shutdown(context.Background())) + select { + case <-done: + case <-time.After(5 * time.Second): + t.Error("OnShutdown hook didn't fire") + } +} + +func Example() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + defer srv.Close() + res, err := srv.Client().Get(srv.URL()) + if err != nil { + panic(err) + } + defer res.Body.Close() + fmt.Println(res.Status) + // Output: + // 200 OK +} + +func ExampleServer_Client() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + defer srv.Close() + client := srv.Client() + client.Timeout = 10 * time.Second + res, err := client.Get(srv.URL()) + if err != nil { + panic(err) + } + defer res.Body.Close() + fmt.Println(res.Status) + // Output: + // 200 OK +} + +func ExampleServer_Shutdown() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + panic(err) + } + fmt.Println("Server has shut down") + // Output: + // Server has shut down +} diff --git a/internal/memhttp/memhttptest/http.go b/internal/memhttp/memhttptest/http.go new file mode 100644 index 00000000..f6d90c55 --- /dev/null +++ b/internal/memhttp/memhttptest/http.go @@ -0,0 +1,54 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttptest + +import ( + "log" + "net/http" + "testing" + + "connectrpc.com/connect/internal/memhttp" +) + +// NewServer constructs a [memhttp.Server] with defaults suitable for tests: +// it logs runtime errors to the provided testing.TB, and it automatically shuts +// down the server when the test completes. Startup and shutdown errors fail the +// test. +// +// To customize the server, use any [memhttp.Option]. In particular, it may be +// necessary to customize the shutdown timeout with +// [memhttp.WithCleanupTimeout]. +func NewServer(tb testing.TB, handler http.Handler, opts ...memhttp.Option) *memhttp.Server { + tb.Helper() + logger := log.New(&testWriter{tb}, "" /* prefix */, log.Lshortfile) + opts = append([]memhttp.Option{memhttp.WithErrorLog(logger)}, opts...) + server := memhttp.NewServer(handler, opts...) + tb.Cleanup(func() { + if err := server.Cleanup(); err != nil { + tb.Error(err) + } + }) + return server +} + +// testWriter is an io.Writer that logs to the testing.TB. +type testWriter struct { + tb testing.TB +} + +func (l *testWriter) Write(p []byte) (int, error) { + l.tb.Log(string(p)) + return len(p), nil +} diff --git a/internal/memhttp/option.go b/internal/memhttp/option.go new file mode 100644 index 00000000..b3e972d5 --- /dev/null +++ b/internal/memhttp/option.go @@ -0,0 +1,59 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp + +import ( + "log" + "time" +) + +// config is the configuration for a Server. +type config struct { + CleanupTimeout time.Duration + ErrorLog *log.Logger +} + +// An Option configures a Server. +type Option interface { + apply(*config) +} + +type optionFunc func(*config) + +func (f optionFunc) apply(cfg *config) { f(cfg) } + +// WithOptions composes multiple Options into one. +func WithOptions(opts ...Option) Option { + return optionFunc(func(cfg *config) { + for _, opt := range opts { + opt.apply(cfg) + } + }) +} + +// WithErrorLog sets [http.Server.ErrorLog]. +func WithErrorLog(l *log.Logger) Option { + return optionFunc(func(cfg *config) { + cfg.ErrorLog = l + }) +} + +// WithCleanupTimeout customizes the default five-second timeout for the +// server's Cleanup method. +func WithCleanupTimeout(d time.Duration) Option { + return optionFunc(func(cfg *config) { + cfg.CleanupTimeout = d + }) +} diff --git a/protocol_connect.go b/protocol_connect.go index b14eb4db..bd5500e4 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -476,7 +476,9 @@ func (cc *connectUnaryClientConn) CloseRequest() error { } func (cc *connectUnaryClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } if err := cc.unmarshaler.Unmarshal(msg); err != nil { return err } @@ -484,12 +486,12 @@ func (cc *connectUnaryClientConn) Receive(msg any) error { } func (cc *connectUnaryClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectUnaryClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } @@ -587,7 +589,9 @@ func (cc *connectStreamingClientConn) CloseRequest() error { } func (cc *connectStreamingClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -601,7 +605,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) + _ = cc.duplexCall.CloseWrite() return serverErr } // If the error is EOF but not from a last message, we want to return @@ -612,18 +616,18 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // There's no error in the trailers, so this was probably an error // converting the bytes to a message, an error reading from the network, or // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // close the writer so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } func (cc *connectStreamingClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectStreamingClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } diff --git a/protocol_grpc.go b/protocol_grpc.go index 22d1eb78..177e31f2 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -358,7 +358,9 @@ func (cc *grpcClientConn) CloseRequest() error { } func (cc *grpcClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -386,23 +388,23 @@ func (cc *grpcClientConn) Receive(msg any) error { // the stream has ended, Receive must return an error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) + _ = cc.duplexCall.CloseWrite() return serverErr } // This was probably an error converting the bytes to a message or an error // reading from the network. We're going to return it to the - // user, but we also want to setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // user, but we also want to close writes so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } func (cc *grpcClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *grpcClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } diff --git a/recover_ext_test.go b/recover_ext_test.go index e8cb991b..a385b1d1 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -18,13 +18,13 @@ import ( "context" "fmt" "net/http" - "net/http/httptest" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) type panicPingServer struct { @@ -77,13 +77,10 @@ func TestWithRecover(t *testing.T) { pinger := &panicPingServer{} mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pinger, connect.WithRecover(handle))) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) for _, panicWith := range []any{42, nil} {