diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 3c63c706986d..f63fae5b8b5a 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1009,7 +1009,12 @@ func (t *http2Client) Close(err error) { // Per HTTP/2 spec, a GOAWAY frame must be sent before closing the // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err}) - <-t.writerDone + timer := time.NewTimer(5 * time.Second) + select { + case <-t.writerDone: + case <-timer.C: + t.logger.Warningf("timeout waiting for the loopy writer to be closed.") + } t.cancel() t.conn.Close() channelz.RemoveEntry(t.channelz.ID) @@ -1035,6 +1040,7 @@ func (t *http2Client) Close(err error) { } sh.HandleConn(t.ctx, connEnd) } + t.logger.Infof("Closed the client connection") } // GracefulClose sets the state to draining, which prevents new streams from diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 7887c8be8647..23b2fbd284c7 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -90,6 +90,7 @@ const ( invalidHeaderField delayRead pingpong + goAwayFrameSize = 42 ) func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { @@ -2656,7 +2657,7 @@ func TestConnectionError_Unwrap(t *testing.T) { } } -// Test that in the event of a graceful client transport shutdown, i.e., +// TestClientSendsAGoAwayFrame verifies that in the event of a graceful client transport shutdown, i.e., // clientTransport.Close(), client sends a goaway to the server with the correct // error code and debug data. func (s) TestClientSendsAGoAwayFrame(t *testing.T) { @@ -2736,7 +2737,6 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { // Wait until server receives the headers and settings frame as part of greet. <-greetDone ct.Close(errors.New("manually closed by client")) - t.Logf("Closed the client connection") select { case err := <-errorCh: if err != nil { @@ -2746,3 +2746,144 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { t.Errorf("Context timed out") } } + +var writeHangSignal chan struct{} + +type hangingConn struct { + net.Conn +} + +func (hc *hangingConn) Read(b []byte) (n int, err error) { + n, err = hc.Conn.Read(b) + return n, err +} + +func (hc *hangingConn) Write(b []byte) (n int, err error) { + n, err = hc.Conn.Write(b) + if n == goAwayFrameSize { // GOAWAY frame + writeHangSignal = make(chan struct{}) + time.Sleep(15 * time.Second) + } + return n, err +} + +func (hc *hangingConn) Close() error { + return hc.Conn.Close() +} + +func (hc *hangingConn) LocalAddr() net.Addr { + return hc.Conn.LocalAddr() +} + +func (hc *hangingConn) RemoteAddr() net.Addr { + return hc.Conn.RemoteAddr() +} + +func (hc *hangingConn) SetDeadline(t time.Time) error { + return hc.Conn.SetDeadline(t) +} + +func (hc *hangingConn) SetReadDeadline(t time.Time) error { + return hc.Conn.SetReadDeadline(t) +} + +func (hc *hangingConn) SetWriteDeadline(t time.Time) error { + return hc.Conn.SetWriteDeadline(t) +} + +func hangingDialer(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &hangingConn{Conn: conn}, nil +} + +// TestClientCloseTimeoutOnHang verifies that in the event of a graceful +// client transport shutdown, i.e., clientTransport.Close(), if the conn hung +// forever, client should still be close itself and do not wait for long. +func (s) TestClientCloseTimeoutOnHang(t *testing.T) { + // Create a server. + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening: %v", err) + } + defer lis.Close() + // greetDone is used to notify when server is done greeting the client. + greetDone := make(chan struct{}) + // errorCh verifies that desired GOAWAY not received by server + errorCh := make(chan error) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // Launch the server. + go func() { + sconn, err := lis.Accept() + if err != nil { + t.Errorf("Error while accepting: %v", err) + } + defer sconn.Close() + if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil { + t.Errorf("Error while writing settings ack: %v", err) + return + } + sfr := http2.NewFramer(sconn, sconn) + if err := sfr.WriteSettings(); err != nil { + t.Errorf("Error while writing settings %v", err) + return + } + fr, _ := sfr.ReadFrame() + if _, ok := fr.(*http2.SettingsFrame); !ok { + t.Errorf("Expected settings frame, got %v", fr) + } + fr, _ = sfr.ReadFrame() + if fr, ok := fr.(*http2.SettingsFrame); !ok || !fr.IsAck() { + t.Errorf("Expected settings ACK frame, got %v", fr) + } + fr, _ = sfr.ReadFrame() + if fr, ok := fr.(*http2.HeadersFrame); !ok || !fr.Flags.Has(http2.FlagHeadersEndHeaders) { + t.Errorf("Expected Headers frame with END_HEADERS frame, got %v", fr) + } + close(greetDone) + + frame, err := sfr.ReadFrame() + if err != nil { + return + } + switch fr := frame.(type) { + case *http2.GoAwayFrame: + // Records that the server successfully received a GOAWAY frame. + goAwayFrame := fr + if goAwayFrame.ErrCode == http2.ErrCodeNo { + t.Logf("Received goAway frame from client") + close(errorCh) + } else { + errorCh <- fmt.Errorf("received unexpected goAway frame: %v", err) + close(errorCh) + } + return + default: + errorCh <- fmt.Errorf("server received a frame other than GOAWAY: %v", err) + close(errorCh) + return + } + }() + + ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{Dialer: hangingDialer}, func(GoAwayReason) {}) + if err != nil { + t.Fatalf("Error while creating client transport: %v", err) + } + _, err = ct.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("failed to open stream: %v", err) + } + // Wait until server receives the headers and settings frame as part of greet. + <-greetDone + ct.Close(errors.New("manually closed by client")) + defer close(writeHangSignal) + select { + case <-writeHangSignal: + t.Errorf("error: channel closed too early.") + case <-ctx.Done(): + } + +}