diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go index 2bb701897..f4d9b5ece 100644 --- a/http2/client_conn_pool.go +++ b/http2/client_conn_pool.go @@ -7,11 +7,9 @@ package http2 import ( - "context" "crypto/tls" "net/http" "sync" - "time" ) // ClientConnPool manages a pool of HTTP/2 client connections. @@ -43,16 +41,6 @@ type clientConnPool struct { dialing map[string]*dialCall // currently in-flight dials keys map[*ClientConn][]string addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls - - // TODO: figure out a way to allow user to configure pingPeriod and - // pingTimeout. - pingPeriod time.Duration // how often pings are sent on idle - // connections. The connection will be closed if response is not - // received within pingTimeout. 0 means no periodic pings. - pingTimeout time.Duration // connection will be force closed if a Ping - // response is not received within pingTimeout. - pingStops map[*ClientConn]chan struct{} // channels to stop the - // periodic Pings. } func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { @@ -231,54 +219,13 @@ func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) { if p.keys == nil { p.keys = make(map[*ClientConn][]string) } - if p.pingStops == nil { - p.pingStops = make(map[*ClientConn]chan struct{}) - } p.conns[key] = append(p.conns[key], cc) p.keys[cc] = append(p.keys[cc], key) - if p.pingPeriod != 0 { - p.pingStops[cc] = p.pingConnection(key, cc) - } -} - -// TODO: ping all connections at the same tick to save tickers? -func (p *clientConnPool) pingConnection(key string, cc *ClientConn) chan struct{} { - done := make(chan struct{}) - go func() { - ticker := time.NewTicker(p.pingPeriod) - defer ticker.Stop() - for { - select { - case <-done: - return - default: - } - - select { - case <-done: - return - case <-ticker.C: - ctx, _ := context.WithTimeout(context.Background(), p.pingTimeout) - err := cc.Ping(ctx) - if err != nil { - cc.closeForLostPing() - p.MarkDead(cc) - } - } - } - }() - return done } func (p *clientConnPool) MarkDead(cc *ClientConn) { p.mu.Lock() defer p.mu.Unlock() - - if done, ok := p.pingStops[cc]; ok { - close(done) - delete(p.pingStops, cc) - } - for _, key := range p.keys[cc] { vv, ok := p.conns[key] if !ok { diff --git a/http2/transport.go b/http2/transport.go index a47ab780b..8b1e6fb62 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -108,6 +108,23 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // PingPeriod controls how often pings are sent on idle connections to + // check the liveness of the connection. The connection will be closed + // if response is not received within PingTimeout. + // 0 means no periodic pings. Defaults to 0. + PingPeriod time.Duration + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // 0 means no periodic pings. Defaults to 0. + PingTimeout time.Duration + // ReadIdleTimeout is the timeout after which the periodic ping for + // connection health check will begin if no frame is received on the + // connection. + // The health check will stop once there is frame received on the + // connection. + // Defaults to 60s. + ReadIdleTimeout time.Duration + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -140,10 +157,6 @@ func ConfigureTransport(t1 *http.Transport) error { func configureTransport(t1 *http.Transport) (*Transport, error) { connPool := new(clientConnPool) - // TODO: figure out a way to allow user to configure pingPeriod and - // pingTimeout. - connPool.pingPeriod = 5 * time.Second - connPool.pingTimeout = 1 * time.Second t2 := &Transport{ ConnPool: noDialClientConnPool{connPool}, t1: t1, @@ -243,6 +256,8 @@ type ClientConn struct { wmu sync.Mutex // held while writing; acquire AFTER mu if holding both werr error // first write error that has occurred + + healthCheckStopCh chan struct{} } // clientStream is the state for a single HTTP/2 stream. One of these @@ -678,6 +693,49 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return cc, nil } +func (cc *ClientConn) healthCheck(stop chan struct{}) { + pingPeriod := cc.t.PingPeriod + pingTimeout := cc.t.PingTimeout + if pingPeriod == 0 || pingTimeout == 0 { + return + } + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + err := cc.Ping(ctx) + cancel() + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } + } + } +} + +func (cc *ClientConn) startHealthCheck() { + if cc.healthCheckStopCh != nil { + // a health check is already running + return + } + cc.healthCheckStopCh = make(chan struct{}) + go cc.healthCheck(cc.healthCheckStopCh) +} + +func (cc *ClientConn) stopHealthCheck() { + if cc.healthCheckStopCh == nil { + // no health check running + return + } + close(cc.healthCheckStopCh) + cc.healthCheckStopCh = nil +} + func (cc *ClientConn) setGoAway(f *GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -1717,13 +1775,42 @@ func (rl *clientConnReadLoop) cleanup() { cc.mu.Unlock() } +type frameAndError struct { + f Frame + err error +} + +func nonBlockingReadFrame(fr *Framer) chan frameAndError { + feCh := make(chan frameAndError) + go func() { + f, err := fr.ReadFrame() + feCh <- frameAndError{f: f, err: err} + }() + return feCh +} + func (rl *clientConnReadLoop) run() error { cc := rl.cc rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse gotReply := false // ever saw a HEADERS reply gotSettings := false for { - f, err := cc.fr.ReadFrame() + var fe frameAndError + feCh := nonBlockingReadFrame(cc.fr) + to := cc.t.ReadIdleTimeout + if to == 0 { + to = 60 * time.Second + } + readIdleTimer := time.NewTimer(to) + select { + case fe = <-feCh: + cc.stopHealthCheck() + readIdleTimer.Stop() + case <-readIdleTimer.C: + cc.startHealthCheck() + fe = <-feCh + } + f, err := fe.f, fe.err if err != nil { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } diff --git a/http2/transport_test.go b/http2/transport_test.go index de0c58272..7032a0e06 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3247,11 +3247,9 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { func TestTransportCloseAfterLostPing(t *testing.T) { clientDone := make(chan struct{}) ct := newClientTester(t) - connPool := new(clientConnPool) - connPool.pingPeriod = 1 * time.Second - connPool.pingTimeout = 100 * time.Millisecond - connPool.t = ct.tr - ct.tr.ConnPool = connPool + ct.tr.PingPeriod = 1 * time.Second + ct.tr.PingTimeout = 1 * time.Second + ct.tr.ReadIdleTimeout = 1 * time.Second ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() defer close(clientDone) @@ -3270,6 +3268,135 @@ func TestTransportCloseAfterLostPing(t *testing.T) { ct.run() } +func TestTransportPingWhenReading(t *testing.T) { + testTransportPingWhenReading(t, 50*time.Millisecond, 100*time.Millisecond) + testTransportPingWhenReading(t, 100*time.Millisecond, 50*time.Millisecond) +} + +func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration) { + var pinged bool + clientBodyBytes := []byte("hello, this is client") + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.tr.PingPeriod = 10 * time.Millisecond + ct.tr.PingTimeout = 10 * time.Millisecond + ct.tr.ReadIdleTimeout = readIdleTimeout + // guards the ct.fr.Write + var wmu sync.Mutex + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + req, err := http.NewRequest("PUT", "https://dummy.tld/", bytes.NewReader(clientBodyBytes)) + if err != nil { + return err + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) + } + _, err = ioutil.ReadAll(res.Body) + return err + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var dataRecv int + var closed bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame, *HeadersFrame: + case *DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + err := func() error { + wmu.Lock() + defer wmu.Unlock() + if dataRecv == 0 { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + return nil + }() + if err != nil { + return err + } + } + dataRecv += dataLen + + if !closed && dataRecv == len(clientBodyBytes) { + closed = true + go func() { + for i := 0; i < 10; i++ { + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + time.Sleep(serverResponseInterval) + } + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server frame")); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + }() + } + case *PingFrame: + pinged = true + if serverResponseInterval > readIdleTimeout { + wmu.Lock() + if err := ct.fr.WritePing(true, f.Data); err != nil { + wmu.Unlock() + return err + } + wmu.Unlock() + } else { + return fmt.Errorf("Unexpected ping frame: %v", f) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() + if serverResponseInterval > readIdleTimeout && !pinged { + t.Errorf("expect ping") + } +} + func TestTransportRetryAfterGOAWAY(t *testing.T) { var dialer struct { sync.Mutex