Skip to content

Commit

Permalink
net/http: make Transport.MaxConnsPerHost work for HTTP/2
Browse files Browse the repository at this point in the history
Treat HTTP/2 connections as an ongoing persistent connection. When we
are told there is no cached connections, cleanup the associated
connection and host connection count.

Fixes #27753

Change-Id: I6b7bd915fc7819617cb5d3b35e46e225c75eda29
Reviewed-on: https://go-review.googlesource.com/c/go/+/140357
Reviewed-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
fraenkel authored and bradfitz committed Apr 30, 2019
1 parent c706d42 commit 43b9fcf
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,16 +506,19 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
t.putOrCloseIdleConn(pconn)
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
}
if err == nil {
return resp, nil
}
if !pconn.shouldRetryRequest(req, err) {
if http2isNoCachedConnError(err) {
t.removeIdleConn(pconn)
t.decHostConnCount(cm.key()) // clean up the persistent connection
} else if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek,
// as we've historically done.
if e, ok := err.(transportReadFromServerError); ok {
Expand Down Expand Up @@ -778,9 +781,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.isBroken() {
return errConnBroken
}
if pconn.alt != nil {
return errNotCachingH2Conn
}
pconn.markReused()
key := pconn.cacheKey

Expand Down Expand Up @@ -829,7 +829,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.idleTimer != nil {
pconn.idleTimer.Reset(t.IdleConnTimeout)
} else {
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
// idleTimer does not apply to HTTP/2
if pconn.alt == nil {
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
}
}
}
pconn.idleAt = time.Now()
Expand Down Expand Up @@ -1377,7 +1380,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon

if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
}
}

Expand Down
101 changes: 101 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
<-reqComplete
}

func TestTransportMaxConnsPerHost(t *testing.T) {
defer afterTest(t)
if runtime.GOOS == "js" {
t.Skipf("skipping test on js/wasm")
}
h := HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("foo"))
if err != nil {
t.Fatalf("Write: %v", err)
}
})

testMaxConns := func(scheme string, ts *httptest.Server) {
defer ts.Close()

c := ts.Client()
tr := c.Transport.(*Transport)
tr.MaxConnsPerHost = 1
if err := ExportHttp2ConfigureTransport(tr); err != nil {
t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
}

connCh := make(chan net.Conn, 1)
var dialCnt, gotConnCnt, tlsHandshakeCnt int32
tr.Dial = func(network, addr string) (net.Conn, error) {
atomic.AddInt32(&dialCnt, 1)
c, err := net.Dial(network, addr)
connCh <- c
return c, err
}

doReq := func() {
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
if !connInfo.Reused {
atomic.AddInt32(&gotConnCnt, 1)
}
},
TLSHandshakeStart: func() {
atomic.AddInt32(&tlsHandshakeCnt, 1)
},
}
req, _ := NewRequest("GET", ts.URL, nil)
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

resp, err := c.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("read body failed: %v", err)
}
}

wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
doReq()
}()
}
wg.Wait()

expected := int32(tr.MaxConnsPerHost)
if dialCnt != expected {
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
}
if gotConnCnt != expected {
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
}
if ts.TLS != nil && tlsHandshakeCnt != expected {
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
}

(<-connCh).Close()

doReq()
expected++
if dialCnt != expected {
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
}
if gotConnCnt != expected {
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
}
if ts.TLS != nil && tlsHandshakeCnt != expected {
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
}
}

testMaxConns("http", httptest.NewServer(h))
testMaxConns("https", httptest.NewTLSServer(h))

ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
ts.StartTLS()
testMaxConns("http2", ts)
}

func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)
Expand Down

0 comments on commit 43b9fcf

Please sign in to comment.