From fad99bae49149d10b77447829a0afedbfee73af2 Mon Sep 17 00:00:00 2001 From: Abhishek Ranjan Date: Fri, 16 Aug 2024 23:04:56 +0530 Subject: [PATCH] Revert the atomic boolean for hangingConn.Write() --- internal/transport/transport_test.go | 39 ++++++++++++++++++---------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 6fcdd60d5594..36739490ed30 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -32,6 +32,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -58,8 +59,6 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -const goAwayFrameSize = 42 - var ( expectedRequest = []byte("ping") expectedResponse = []byte("pong") @@ -2754,25 +2753,18 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { // signaled or a timeout occurs. type hangingConn struct { net.Conn - hangConn chan struct{} + hangConn chan struct{} + startHanging *atomic.Bool } func (hc *hangingConn) Write(b []byte) (n int, err error) { n, err = hc.Conn.Write(b) - if n == goAwayFrameSize { // hang the conn after the goAway is received + if hc.startHanging.Load() { <-hc.hangConn } return n, err } -func hangingDialer(_ context.Context, addr string) (net.Conn, error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - return &hangingConn{Conn: conn, hangConn: make(chan struct{})}, nil -} - // Tests the scenario where a client transport is closed and writing of the // GOAWAY frame as part of the close does not complete because of a network // hang. The test verifies that the client transport is closed without waiting @@ -2788,14 +2780,35 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { }() // Create the server set up. - server, ct, cancel := setUpWithOptions(t, 0, &ServerConfig{}, normal, ConnectOptions{Dialer: hangingDialer}) + connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + server := setUpServerOnly(t, 0, &ServerConfig{}, normal) defer server.stop() + addr := resolver.Address{Addr: "localhost:" + server.port} + isGreetingDone := &atomic.Bool{} + hangConn := make(chan struct{}) + defer close(hangConn) + dialer := func(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil + } + copts := ConnectOptions{Dialer: dialer} + copts.ChannelzParent = channelzSubChannel(t) + // Create client transport with custom dialer + ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}) + if connErr != nil { + t.Fatalf("failed to create transport: %v", connErr) + } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { t.Fatalf("Failed to open stream: %v", err) } + + isGreetingDone.Store(true) ct.Close(errors.New("manually closed by client")) }