diff --git a/http2/transport.go b/http2/transport.go index 46e169fe4..8b1e6fb62 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -108,6 +108,23 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // PingPeriod controls how often pings are sent on idle connections to + // check the liveness of the connection. The connection will be closed + // if response is not received within PingTimeout. + // 0 means no periodic pings. Defaults to 0. + PingPeriod time.Duration + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // 0 means no periodic pings. Defaults to 0. + PingTimeout time.Duration + // ReadIdleTimeout is the timeout after which the periodic ping for + // connection health check will begin if no frame is received on the + // connection. + // The health check will stop once there is frame received on the + // connection. + // Defaults to 60s. + ReadIdleTimeout time.Duration + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -140,10 +157,6 @@ func ConfigureTransport(t1 *http.Transport) error { func configureTransport(t1 *http.Transport) (*Transport, error) { connPool := new(clientConnPool) - // TODO: figure out a way to allow user to configure pingPeriod and - // pingTimeout. - connPool.pingPeriod = 5 * time.Second - connPool.pingTimeout = 1 * time.Second t2 := &Transport{ ConnPool: noDialClientConnPool{connPool}, t1: t1, @@ -244,7 +257,7 @@ type ClientConn struct { wmu sync.Mutex // held while writing; acquire AFTER mu if holding both werr error // first write error that has occurred - healthCheckCancel chan struct{} + healthCheckStopCh chan struct{} } // clientStream is the state for a single HTTP/2 stream. One of these @@ -680,42 +693,47 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return cc, nil } -func (cc *ClientConn) healthCheck(cancel chan struct{}) { - // TODO: CHAO: configurable - pingPeriod := 15 * time.Second +func (cc *ClientConn) healthCheck(stop chan struct{}) { + pingPeriod := cc.t.PingPeriod + pingTimeout := cc.t.PingTimeout + if pingPeriod == 0 || pingTimeout == 0 { + return + } ticker := time.NewTicker(pingPeriod) defer ticker.Stop() for { select { - case <-cancel: + case <-stop: return case <-ticker.C: - ctx, _ := context.WithTimeout(context.Background(), p.pingTimeout) + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) err := cc.Ping(ctx) + cancel() if err != nil { cc.closeForLostPing() cc.t.connPool().MarkDead(cc) + return } } } } func (cc *ClientConn) startHealthCheck() { - if cc.healthCheckCancel != nil { + if cc.healthCheckStopCh != nil { // a health check is already running return } - cc.healthCheckCancel = make(chan struct{}) - go cc.healthCheck(cc.healthCheckCancel) + cc.healthCheckStopCh = make(chan struct{}) + go cc.healthCheck(cc.healthCheckStopCh) } func (cc *ClientConn) stopHealthCheck() { - if cc.healthCheckCancel == nil { + if cc.healthCheckStopCh == nil { // no health check running return } - close(cc.healthCheckCancel) - cc.healthCheckCancel = nil + close(cc.healthCheckStopCh) + cc.healthCheckStopCh = nil } func (cc *ClientConn) setGoAway(f *GoAwayFrame) { @@ -1757,25 +1775,16 @@ func (rl *clientConnReadLoop) cleanup() { cc.mu.Unlock() } -func ReadFrameAndProbablyStartOrStopPingLoop() { - select { - case <-timer: - // start ping loop - case <-read: - // stop ping loop - } -} - type frameAndError struct { f Frame err error } -func nonBlockingReadFrame(f *Framer) chan frameAndError { - feCh := make(chan FrameAndError) +func nonBlockingReadFrame(fr *Framer) chan frameAndError { + feCh := make(chan frameAndError) go func() { f, err := fr.ReadFrame() - feCh <- frameAndError{frame: f, err: err} + feCh <- frameAndError{f: f, err: err} }() return feCh } @@ -1788,15 +1797,18 @@ func (rl *clientConnReadLoop) run() error { for { var fe frameAndError feCh := nonBlockingReadFrame(cc.fr) - // TODO: CHAO: make it configurable - readIdleTimer := time.NewTimer(15 * time.Second) + to := cc.t.ReadIdleTimeout + if to == 0 { + to = 60 * time.Second + } + readIdleTimer := time.NewTimer(to) select { - case fe <- feCh: + case fe = <-feCh: cc.stopHealthCheck() readIdleTimer.Stop() case <-readIdleTimer.C: cc.startHealthCheck() - fe <- feCh + fe = <-feCh } f, err := fe.f, fe.err if err != nil { diff --git a/http2/transport_test.go b/http2/transport_test.go index de0c58272..7032a0e06 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3247,11 +3247,9 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { func TestTransportCloseAfterLostPing(t *testing.T) { clientDone := make(chan struct{}) ct := newClientTester(t) - connPool := new(clientConnPool) - connPool.pingPeriod = 1 * time.Second - connPool.pingTimeout = 100 * time.Millisecond - connPool.t = ct.tr - ct.tr.ConnPool = connPool + ct.tr.PingPeriod = 1 * time.Second + ct.tr.PingTimeout = 1 * time.Second + ct.tr.ReadIdleTimeout = 1 * time.Second ct.client = func() error { defer ct.cc.(*net.TCPConn).CloseWrite() defer close(clientDone) @@ -3270,6 +3268,135 @@ func TestTransportCloseAfterLostPing(t *testing.T) { ct.run() } +func TestTransportPingWhenReading(t *testing.T) { + testTransportPingWhenReading(t, 50*time.Millisecond, 100*time.Millisecond) + testTransportPingWhenReading(t, 100*time.Millisecond, 50*time.Millisecond) +} + +func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration) { + var pinged bool + clientBodyBytes := []byte("hello, this is client") + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.tr.PingPeriod = 10 * time.Millisecond + ct.tr.PingTimeout = 10 * time.Millisecond + ct.tr.ReadIdleTimeout = readIdleTimeout + // guards the ct.fr.Write + var wmu sync.Mutex + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + req, err := http.NewRequest("PUT", "https://dummy.tld/", bytes.NewReader(clientBodyBytes)) + if err != nil { + return err + } + res, err := ct.tr.RoundTrip(req) + if err != nil { + return fmt.Errorf("RoundTrip: %v", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200) + } + _, err = ioutil.ReadAll(res.Body) + return err + } + ct.server = func() error { + ct.greet() + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + var dataRecv int + var closed bool + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame, *HeadersFrame: + case *DataFrame: + dataLen := len(f.Data()) + if dataLen > 0 { + err := func() error { + wmu.Lock() + defer wmu.Unlock() + if dataRecv == 0 { + enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: f.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + } + if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil { + return err + } + if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil { + return err + } + return nil + }() + if err != nil { + return err + } + } + dataRecv += dataLen + + if !closed && dataRecv == len(clientBodyBytes) { + closed = true + go func() { + for i := 0; i < 10; i++ { + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + time.Sleep(serverResponseInterval) + } + wmu.Lock() + if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server frame")); err != nil { + wmu.Unlock() + t.Error(err) + return + } + wmu.Unlock() + }() + } + case *PingFrame: + pinged = true + if serverResponseInterval > readIdleTimeout { + wmu.Lock() + if err := ct.fr.WritePing(true, f.Data); err != nil { + wmu.Unlock() + return err + } + wmu.Unlock() + } else { + return fmt.Errorf("Unexpected ping frame: %v", f) + } + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() + if serverResponseInterval > readIdleTimeout && !pinged { + t.Errorf("expect ping") + } +} + func TestTransportRetryAfterGOAWAY(t *testing.T) { var dialer struct { sync.Mutex