Only check connection health if the connection read loop has been idle
Chao Xu committed Oct 19, 2019
1 parent bc0d6c6 commit b2e6f87
Showing 3 changed files with 224 additions and 63 deletions.
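
The user-facing knobs for this behavior move from the private connection pool onto the exported Transport fields added below (PingPeriod, PingTimeout, ReadIdleTimeout). A rough usage sketch against this commit — the client construction is illustrative, and these field names may change in later revisions:

	package main

	import (
		"net/http"
		"time"

		"golang.org/x/net/http2"
	)

	func main() {
		// Health checking starts only after the read loop has been idle
		// for ReadIdleTimeout; pings then go out every PingPeriod, and the
		// connection is torn down if one goes unanswered for PingTimeout.
		tr := &http2.Transport{
			ReadIdleTimeout: 30 * time.Second,
			PingPeriod:      5 * time.Second,
			PingTimeout:     1 * time.Second,
		}
		client := &http.Client{Transport: tr}
		_ = client // e.g. client.Get("https://example.com/")
	}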
53 changes: 0 additions & 53 deletions http2/client_conn_pool.go

@@ -7,11 +7,9 @@
 package http2
 
 import (
-	"context"
 	"crypto/tls"
 	"net/http"
 	"sync"
-	"time"
 )
 
 // ClientConnPool manages a pool of HTTP/2 client connections.
@@ -43,16 +41,6 @@ type clientConnPool struct {
 	dialing      map[string]*dialCall // currently in-flight dials
 	keys         map[*ClientConn][]string
 	addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls
-
-	// TODO: figure out a way to allow user to configure pingPeriod and
-	// pingTimeout.
-	pingPeriod time.Duration // how often pings are sent on idle
-	// connections. The connection will be closed if a response is not
-	// received within pingTimeout. 0 means no periodic pings.
-	pingTimeout time.Duration // connection will be force closed if a Ping
-	// response is not received within pingTimeout.
-	pingStops map[*ClientConn]chan struct{} // channels to stop the
-	// periodic Pings.
 }
 
 func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@@ -231,54 +219,13 @@ func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
 	if p.keys == nil {
 		p.keys = make(map[*ClientConn][]string)
 	}
-	if p.pingStops == nil {
-		p.pingStops = make(map[*ClientConn]chan struct{})
-	}
 	p.conns[key] = append(p.conns[key], cc)
 	p.keys[cc] = append(p.keys[cc], key)
-	if p.pingPeriod != 0 {
-		p.pingStops[cc] = p.pingConnection(key, cc)
-	}
 }
 
-// TODO: ping all connections at the same tick to save tickers?
-func (p *clientConnPool) pingConnection(key string, cc *ClientConn) chan struct{} {
-	done := make(chan struct{})
-	go func() {
-		ticker := time.NewTicker(p.pingPeriod)
-		defer ticker.Stop()
-		for {
-			select {
-			case <-done:
-				return
-			default:
-			}
-
-			select {
-			case <-done:
-				return
-			case <-ticker.C:
-				ctx, _ := context.WithTimeout(context.Background(), p.pingTimeout)
-				err := cc.Ping(ctx)
-				if err != nil {
-					cc.closeForLostPing()
-					p.MarkDead(cc)
-				}
-			}
-		}
-	}()
-	return done
-}
-
 func (p *clientConnPool) MarkDead(cc *ClientConn) {
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
-	if done, ok := p.pingStops[cc]; ok {
-		close(done)
-		delete(p.pingStops, cc)
-	}
-
 	for _, key := range p.keys[cc] {
 		vv, ok := p.conns[key]
 		if !ok {
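
The pool-level pinger deleted above is superseded by a per-connection health check in transport.go below, whose startHealthCheck/stopHealthCheck pair uses a nil-checked stop channel. A minimal standalone sketch of that idiom (hypothetical names; skipping a mutex is safe only because, as in the commit, start and stop are called from a single goroutine — the read loop):

	package main

	import (
		"fmt"
		"time"
	)

	type checker struct {
		stopCh chan struct{} // non-nil while the check loop is running
	}

	// start is a no-op if a check loop is already running.
	func (c *checker) start() {
		if c.stopCh != nil {
			return
		}
		c.stopCh = make(chan struct{})
		go func(stop chan struct{}) {
			for {
				select {
				case <-stop:
					return
				case <-time.After(10 * time.Millisecond):
					fmt.Println("check")
				}
			}
		}(c.stopCh)
	}

	// stop signals the loop to exit and allows a later restart.
	func (c *checker) stop() {
		if c.stopCh == nil {
			return
		}
		close(c.stopCh)
		c.stopCh = nil
	}

	func main() {
		var c checker
		c.start()
		c.start() // idempotent
		time.Sleep(35 * time.Millisecond)
		c.stop()
	}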
97 changes: 92 additions & 5 deletions http2/transport.go

@@ -108,6 +108,23 @@ type Transport struct {
 	// waiting for their turn.
 	StrictMaxConcurrentStreams bool
 
+	// PingPeriod controls how often pings are sent on idle connections to
+	// check the liveness of the connection. The connection will be closed
+	// if a response is not received within PingTimeout.
+	// 0 means no periodic pings. Defaults to 0.
+	PingPeriod time.Duration
+	// PingTimeout is the timeout after which the connection will be closed
+	// if a response to a Ping is not received.
+	// 0 means no periodic pings. Defaults to 0.
+	PingTimeout time.Duration
+	// ReadIdleTimeout is the timeout after which the periodic pings for
+	// the connection health check begin if no frame is received on the
+	// connection.
+	// The health check stops once a frame is received on the
+	// connection.
+	// Defaults to 60s.
+	ReadIdleTimeout time.Duration
+
 	// t1, if non-nil, is the standard library Transport using
 	// this transport. Its settings are used (but not its
 	// RoundTrip method, etc).
@@ -140,10 +157,6 @@ func ConfigureTransport(t1 *http.Transport) error {
 
 func configureTransport(t1 *http.Transport) (*Transport, error) {
 	connPool := new(clientConnPool)
-	// TODO: figure out a way to allow user to configure pingPeriod and
-	// pingTimeout.
-	connPool.pingPeriod = 5 * time.Second
-	connPool.pingTimeout = 1 * time.Second
 	t2 := &Transport{
 		ConnPool: noDialClientConnPool{connPool},
 		t1:       t1,
@@ -243,6 +256,8 @@ type ClientConn struct {
 
 	wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
 	werr error      // first write error that has occurred
+
+	healthCheckStopCh chan struct{}
 }
 
 // clientStream is the state for a single HTTP/2 stream. One of these
@@ -678,6 +693,49 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
 	return cc, nil
 }
 
+func (cc *ClientConn) healthCheck(stop chan struct{}) {
+	pingPeriod := cc.t.PingPeriod
+	pingTimeout := cc.t.PingTimeout
+	if pingPeriod == 0 || pingTimeout == 0 {
+		return
+	}
+	ticker := time.NewTicker(pingPeriod)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-stop:
+			return
+		case <-ticker.C:
+			ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+			err := cc.Ping(ctx)
+			cancel()
+			if err != nil {
+				cc.closeForLostPing()
+				cc.t.connPool().MarkDead(cc)
+				return
+			}
+		}
+	}
+}
+
+func (cc *ClientConn) startHealthCheck() {
+	if cc.healthCheckStopCh != nil {
+		// a health check is already running
+		return
+	}
+	cc.healthCheckStopCh = make(chan struct{})
+	go cc.healthCheck(cc.healthCheckStopCh)
+}
+
+func (cc *ClientConn) stopHealthCheck() {
+	if cc.healthCheckStopCh == nil {
+		// no health check running
+		return
+	}
+	close(cc.healthCheckStopCh)
+	cc.healthCheckStopCh = nil
+}
+
 func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -1717,13 +1775,42 @@ func (rl *clientConnReadLoop) cleanup() {
 	cc.mu.Unlock()
 }
 
+type frameAndError struct {
+	f   Frame
+	err error
+}
+
+func nonBlockingReadFrame(fr *Framer) chan frameAndError {
+	feCh := make(chan frameAndError)
+	go func() {
+		f, err := fr.ReadFrame()
+		feCh <- frameAndError{f: f, err: err}
+	}()
+	return feCh
+}
+
 func (rl *clientConnReadLoop) run() error {
 	cc := rl.cc
 	rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
 	gotReply := false // ever saw a HEADERS reply
 	gotSettings := false
 	for {
-		f, err := cc.fr.ReadFrame()
+		var fe frameAndError
+		feCh := nonBlockingReadFrame(cc.fr)
+		to := cc.t.ReadIdleTimeout
+		if to == 0 {
+			to = 60 * time.Second
+		}
+		readIdleTimer := time.NewTimer(to)
+		select {
+		case fe = <-feCh:
+			cc.stopHealthCheck()
+			readIdleTimer.Stop()
+		case <-readIdleTimer.C:
+			cc.startHealthCheck()
+			fe = <-feCh
+		}
+		f, err := fe.f, fe.err
 		if err != nil {
 			cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
 		}
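
The rewritten run loop above converts the blocking ReadFrame into a channel receive (nonBlockingReadFrame) so it can be raced against a read-idle timer: a frame arriving first stops any running health check, while the timer firing first starts one and then keeps waiting for the frame. The same shape in isolation, with hypothetical names and a stubbed blocking read:

	package main

	import (
		"fmt"
		"time"
	)

	type result struct {
		v   string
		err error
	}

	// readOrIdle races a blocking read against an idle timer, invoking
	// onIdle once if no result arrives within idle, then still returning
	// the read's result.
	func readOrIdle(read func() (string, error), idle time.Duration, onIdle func()) (string, error) {
		ch := make(chan result, 1)
		go func() {
			v, err := read()
			ch <- result{v, err}
		}()
		timer := time.NewTimer(idle)
		defer timer.Stop()
		select {
		case r := <-ch:
			return r.v, r.err // data arrived before the idle deadline
		case <-timer.C:
			onIdle() // the read has been idle; begin health checking
			r := <-ch
			return r.v, r.err
		}
	}

	func main() {
		slowRead := func() (string, error) {
			time.Sleep(150 * time.Millisecond) // stand-in for fr.ReadFrame()
			return "frame", nil
		}
		v, err := readOrIdle(slowRead, 50*time.Millisecond, func() {
			fmt.Println("idle: start health check")
		})
		fmt.Println(v, err)
	}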
137 changes: 132 additions & 5 deletions http2/transport_test.go

@@ -3247,11 +3247,9 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
 func TestTransportCloseAfterLostPing(t *testing.T) {
 	clientDone := make(chan struct{})
 	ct := newClientTester(t)
-	connPool := new(clientConnPool)
-	connPool.pingPeriod = 1 * time.Second
-	connPool.pingTimeout = 100 * time.Millisecond
-	connPool.t = ct.tr
-	ct.tr.ConnPool = connPool
+	ct.tr.PingPeriod = 1 * time.Second
+	ct.tr.PingTimeout = 1 * time.Second
+	ct.tr.ReadIdleTimeout = 1 * time.Second
 	ct.client = func() error {
 		defer ct.cc.(*net.TCPConn).CloseWrite()
 		defer close(clientDone)
@@ -3270,6 +3268,135 @@ func TestTransportCloseAfterLostPing(t *testing.T) {
 	ct.run()
 }
 
+func TestTransportPingWhenReading(t *testing.T) {
+	testTransportPingWhenReading(t, 50*time.Millisecond, 100*time.Millisecond)
+	testTransportPingWhenReading(t, 100*time.Millisecond, 50*time.Millisecond)
+}
+
+func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration) {
+	var pinged bool
+	clientBodyBytes := []byte("hello, this is client")
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.tr.PingPeriod = 10 * time.Millisecond
+	ct.tr.PingTimeout = 10 * time.Millisecond
+	ct.tr.ReadIdleTimeout = readIdleTimeout
+	// wmu guards writes to ct.fr across goroutines.
+	var wmu sync.Mutex
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+
+		req, err := http.NewRequest("PUT", "https://dummy.tld/", bytes.NewReader(clientBodyBytes))
+		if err != nil {
+			return err
+		}
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		defer res.Body.Close()
+		if res.StatusCode != 200 {
+			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
+		}
+		_, err = ioutil.ReadAll(res.Body)
+		return err
+	}
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		var dataRecv int
+		var closed bool
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it
+					// will have reported any
+					// errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame, *HeadersFrame:
+			case *DataFrame:
+				dataLen := len(f.Data())
+				if dataLen > 0 {
+					err := func() error {
+						wmu.Lock()
+						defer wmu.Unlock()
+						if dataRecv == 0 {
+							enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
+							ct.fr.WriteHeaders(HeadersFrameParam{
+								StreamID:      f.StreamID,
+								EndHeaders:    true,
+								EndStream:     false,
+								BlockFragment: buf.Bytes(),
+							})
+						}
+						if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+							return err
+						}
+						if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+							return err
+						}
+						return nil
+					}()
+					if err != nil {
+						return err
+					}
+				}
+				dataRecv += dataLen
+
+				if !closed && dataRecv == len(clientBodyBytes) {
+					closed = true
+					go func() {
+						for i := 0; i < 10; i++ {
+							wmu.Lock()
+							if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
+								wmu.Unlock()
+								t.Error(err)
+								return
+							}
+							wmu.Unlock()
+							time.Sleep(serverResponseInterval)
+						}
+						wmu.Lock()
+						if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server frame")); err != nil {
+							wmu.Unlock()
+							t.Error(err)
+							return
+						}
+						wmu.Unlock()
+					}()
+				}
+			case *PingFrame:
+				pinged = true
+				if serverResponseInterval > readIdleTimeout {
+					wmu.Lock()
+					if err := ct.fr.WritePing(true, f.Data); err != nil {
+						wmu.Unlock()
+						return err
+					}
+					wmu.Unlock()
+				} else {
+					return fmt.Errorf("unexpected ping frame: %v", f)
+				}
+			default:
+				return fmt.Errorf("unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+	if serverResponseInterval > readIdleTimeout && !pinged {
+		t.Errorf("expected ping")
+	}
+}
+
 func TestTransportRetryAfterGOAWAY(t *testing.T) {
 	var dialer struct {
 		sync.Mutex
