From 43b9fcf6fe9feb0ec67a7ddb01e5b542542b47c8 Mon Sep 17 00:00:00 2001 From: Michael Fraenkel Date: Sat, 6 Oct 2018 13:32:19 -0400 Subject: [PATCH] net/http: make Transport.MaxConnsPerHost work for HTTP/2 Treat HTTP/2 connections as an ongoing persistent connection. When we are told there is no cached connections, cleanup the associated connection and host connection count. Fixes #27753 Change-Id: I6b7bd915fc7819617cb5d3b35e46e225c75eda29 Reviewed-on: https://go-review.googlesource.com/c/go/+/140357 Reviewed-by: Brad Fitzpatrick --- src/net/http/transport.go | 19 ++++--- src/net/http/transport_test.go | 101 +++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 8 deletions(-) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index c94d2b50bddb0..88761909fd503 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -506,8 +506,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host - t.setReqCanceler(req, nil) // not cancelable with CancelRequest + t.putOrCloseIdleConn(pconn) + t.setReqCanceler(req, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) @@ -515,7 +515,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if err == nil { return resp, nil } - if !pconn.shouldRetryRequest(req, err) { + if http2isNoCachedConnError(err) { + t.removeIdleConn(pconn) + t.decHostConnCount(cm.key()) // clean up the persistent connection + } else if !pconn.shouldRetryRequest(req, err) { // Issue 16465: return underlying net.Conn.Read error from peek, // as we've historically done. if e, ok := err.(transportReadFromServerError); ok { @@ -778,9 +781,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.isBroken() { return errConnBroken } - if pconn.alt != nil { - return errNotCachingH2Conn - } pconn.markReused() key := pconn.cacheKey @@ -829,7 +829,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.idleTimer != nil { pconn.idleTimer.Reset(t.IdleConnTimeout) } else { - pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle) + // idleTimer does not apply to HTTP/2 + if pconn.alt == nil { + pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle) + } } } pconn.idleAt = time.Now() @@ -1377,7 +1380,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil + return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 857f0d59288f5..44a935960ee6a 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { <-reqComplete } +func TestTransportMaxConnsPerHost(t *testing.T) { + defer afterTest(t) + if runtime.GOOS == "js" { + t.Skipf("skipping test on js/wasm") + } + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + testMaxConns := func(scheme string, ts *httptest.Server) { + defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatalf("ExportHttp2ConfigureTransport: %v", err) + } + + connCh := make(chan net.Conn, 1) + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + connCh <- c + return c, err + } + + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, + } + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } + } + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + + (<-connCh).Close() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + } + + testMaxConns("http", httptest.NewServer(h)) + testMaxConns("https", httptest.NewTLSServer(h)) + + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + testMaxConns("http2", ts) +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { setParallel(t) defer afterTest(t)