Skip to content

Commit

Permalink
Revert the atomic boolean for hangingConn.Write()
Browse files Browse the repository at this point in the history
  • Loading branch information
aranjans committed Aug 16, 2024
1 parent aebe562 commit fad99ba
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -58,8 +59,6 @@ func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

const goAwayFrameSize = 42

var (
expectedRequest = []byte("ping")
expectedResponse = []byte("pong")
Expand Down Expand Up @@ -2754,25 +2753,18 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
// signaled or a timeout occurs.
type hangingConn struct {
net.Conn
hangConn chan struct{}
hangConn chan struct{}
startHanging *atomic.Bool
}

func (hc *hangingConn) Write(b []byte) (n int, err error) {
n, err = hc.Conn.Write(b)
if n == goAwayFrameSize { // hang the conn after the goAway is received
if hc.startHanging.Load() {
<-hc.hangConn
}
return n, err
}

func hangingDialer(_ context.Context, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &hangingConn{Conn: conn, hangConn: make(chan struct{})}, nil
}

// Tests the scenario where a client transport is closed and writing of the
// GOAWAY frame as part of the close does not complete because of a network
// hang. The test verifies that the client transport is closed without waiting
Expand All @@ -2788,14 +2780,35 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
}()

// Create the server set up.
server, ct, cancel := setUpWithOptions(t, 0, &ServerConfig{}, normal, ConnectOptions{Dialer: hangingDialer})
connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
defer server.stop()
addr := resolver.Address{Addr: "localhost:" + server.port}
isGreetingDone := &atomic.Bool{}
hangConn := make(chan struct{})
defer close(hangConn)
dialer := func(_ context.Context, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil
}
copts := ConnectOptions{Dialer: dialer}
copts.ChannelzParent = channelzSubChannel(t)
// Create client transport with custom dialer
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
t.Fatalf("Failed to open stream: %v", err)
}

isGreetingDone.Store(true)
ct.Close(errors.New("manually closed by client"))
}

0 comments on commit fad99ba

Please sign in to comment.