diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 3c63c706986d..48e70fb555a5 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1009,7 +1009,12 @@ func (t *http2Client) Close(err error) { // Per HTTP/2 spec, a GOAWAY frame must be sent before closing the // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err}) - <-t.writerDone + timer := time.NewTimer(5 * time.Second) + select { + case <-t.writerDone: + case <-timer.C: + t.logger.Warningf("timeout waiting for the loopy writer to be closed.") + } t.cancel() t.conn.Close() channelz.RemoveEntry(t.channelz.ID) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 7887c8be8647..22ebbf41b3e2 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -2656,10 +2656,89 @@ func TestConnectionError_Unwrap(t *testing.T) { } } -// Test that in the event of a graceful client transport shutdown, i.e., +// TestClientSendsAGoAwayFrame verifies that in the event of a graceful client transport shutdown, i.e., // clientTransport.Close(), client sends a goaway to the server with the correct // error code and debug data. func (s) TestClientSendsAGoAwayFrame(t *testing.T) { + ctx, errorCh := createClientServerConn(t) + select { + case err := <-errorCh: + if err != nil { + t.Errorf("Error receiving the GOAWAY frame: %v", err) + } + case <-ctx.Done(): + t.Errorf("Context timed out") + } +} + +var connWriterHung = make(chan struct{}) + +type hangingConn struct { + net.Conn +} + +func (hc *hangingConn) Read(b []byte) (n int, err error) { + n, err = hc.Conn.Read(b) + return n, err +} + +func (hc *hangingConn) Write(b []byte) (n int, err error) { + n, err = hc.Conn.Write(b) + if n == 42 { // GOAWAY frame + close(connWriterHung) + } + return n, err +} + +func (hc *hangingConn) Close() error { + fmt.Printf("hangingConn Close %v\n", time.Now()) + return hc.Conn.Close() +} + +func (hc *hangingConn) LocalAddr() net.Addr { + return hc.Conn.LocalAddr() +} + +func (hc *hangingConn) RemoteAddr() net.Addr { + return hc.Conn.RemoteAddr() +} + +func (hc *hangingConn) SetDeadline(t time.Time) error { + return hc.Conn.SetDeadline(t) +} + +func (hc *hangingConn) SetReadDeadline(t time.Time) error { + return hc.Conn.SetReadDeadline(t) +} + +func (hc *hangingConn) SetWriteDeadline(t time.Time) error { + return hc.Conn.SetWriteDeadline(t) +} + +func hangingDialer(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &hangingConn{Conn: conn}, nil +} + +// TestClientCloseTimeoutOnHang verifies that in the event of a graceful +// client transport shutdown, i.e., clientTransport.Close(), if the conn hung +// forever, client should still be close itself and do not wait for long. +func (s) TestClientCloseTimeoutOnHang(t *testing.T) { + ctx, _ := createClientServerConn(t) + select { + case <-connWriterHung: + case <-ctx.Done(): + t.Errorf("Context timed out") + } +} + +// createClientServerConn sets up a listener, and a client transport +// which sends a GOAWAY frame to server, returns test context and errorCh +// (which signals errors in the server's GOAWAY handling) +func createClientServerConn(t *testing.T) (context.Context, chan error) { // Create a server. lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -2725,7 +2804,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { } }() - ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) + ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{Dialer: hangingDialer}, func(GoAwayReason) {}) if err != nil { t.Fatalf("Error while creating client transport: %v", err) } @@ -2737,12 +2816,5 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { <-greetDone ct.Close(errors.New("manually closed by client")) t.Logf("Closed the client connection") - select { - case err := <-errorCh: - if err != nil { - t.Errorf("Error receiving the GOAWAY frame: %v", err) - } - case <-ctx.Done(): - t.Errorf("Context timed out") - } + return ctx, errorCh }