diff --git a/http2/transport.go b/http2/transport.go index c51a73c06..65eb4c142 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -108,6 +108,19 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration + + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // If zero, defaults to 15s. + PingTimeout time.Duration + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -131,6 +144,14 @@ func (t *Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } +func (t *Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second + } + return t.PingTimeout + +} + // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. // It returns an error if t1 has already been HTTP/2-enabled. func ConfigureTransport(t1 *http.Transport) error { @@ -674,6 +695,20 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return cc, nil } +func (cc *ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. 
+ ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } +} + func (cc *ClientConn) setGoAway(f *GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -834,14 +869,12 @@ func (cc *ClientConn) sendGoAway() error { return nil } -// Close closes the client connection immediately. -// -// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. -func (cc *ClientConn) Close() error { +// closeForError closes the client connection immediately. In-flight requests are interrupted. +// err is sent on the resc channel of the streams in cc.streams. +func (cc *ClientConn) closeForError(err error) error { cc.mu.Lock() defer cc.cond.Broadcast() defer cc.mu.Unlock() - err := errors.New("http2: client connection force closed via ClientConn.Close") for id, cs := range cc.streams { select { case cs.resc <- resAndError{err: err}: @@ -854,6 +887,20 @@ func (cc *ClientConn) Close() error { return cc.tconn.Close() } +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} + +// closeForLostPing closes the client connection immediately after a failed health-check ping. In-flight requests are interrupted. +func (cc *ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + return cc.closeForError(err) +} + const maxAllocFrameSize = 512 << 10 // frameBuffer returns a scratch buffer suitable for writing DATA frames. 
@@ -1706,8 +1753,17 @@ func (rl *clientConnReadLoop) run() error { rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse gotReply := false // ever saw a HEADERS reply gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } for { f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } if err != nil { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } diff --git a/http2/transport_test.go b/http2/transport_test.go index 89ce5e5ad..397740ed3 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3244,6 +3244,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { req.Header = http.Header{} } +func TestTransportCloseAfterLostPing(t *testing.T) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.tr.PingTimeout = 1 * time.Second + ct.tr.ReadIdleTimeout = 1 * time.Second + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + _, err := ct.tr.RoundTrip(req) + if err == nil || !strings.Contains(err.Error(), "client connection lost") { + return fmt.Errorf("expected to get error about \"connection lost\", got %v", err) + } + return nil + } + ct.server = func() error { + ct.greet() + <-clientDone + return nil + } + ct.run() +} + +func TestTransportPingWhenReading(t *testing.T) { + testCases := []struct { + name string + readIdleTimeout time.Duration + serverResponseInterval time.Duration + expectedPingCount int + }{ + { + name: "two pings in each serverResponseInterval", + readIdleTimeout: 400 * time.Millisecond, + serverResponseInterval: 1000 * time.Millisecond, + expectedPingCount: 4, + }, + { + name: "one ping in each serverResponseInterval", + readIdleTimeout: 700 * time.Millisecond, + serverResponseInterval: 1000 * 
time.Millisecond, + expectedPingCount: 2, + }, + { + name: "zero ping in each serverResponseInterval", + readIdleTimeout: 1000 * time.Millisecond, + serverResponseInterval: 500 * time.Millisecond, + expectedPingCount: 0, + }, + { + name: "0 readIdleTimeout means no ping", + readIdleTimeout: 0 * time.Millisecond, + serverResponseInterval: 500 * time.Millisecond, + expectedPingCount: 0, + }, + } + + for _, tc := range testCases { + tc := tc // capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount) + }) + } +} + +func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) { + var pingCount int + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.tr.PingTimeout = 10 * time.Millisecond + ct.tr.ReadIdleTimeout = readIdleTimeout + // wmu serializes the ct.fr.WriteData calls in the data goroutine with the ct.fr.WritePing calls in the server loop + var wmu sync.Mutex + + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) + } + _, err = ioutil.ReadAll(res.Body) + return err + } + + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. 
+ return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + if !f.HeadersEnded() { + return fmt.Errorf("headers should have END_HEADERS be ended: %v", f) + } + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + go func() { + for i := 0; i < 2; i++ { + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + time.Sleep(serverResponseInterval) + } + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + }() + case *PingFrame: + pingCount++ + wmu.Lock() + if err := ct.fr.WritePing(true, f.Data); err != nil { + wmu.Unlock() + return err + } + wmu.Unlock() + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() + if e, a := expectedPingCount, pingCount; e != a { + t.Errorf("expected receiving %d pings, got %d pings", e, a) + + } +} + func TestTransportRetryAfterGOAWAY(t *testing.T) { var dialer struct { sync.Mutex