Skip to content

Commit

Permalink
Connection draining for h2c connections 🚽
Browse files Browse the repository at this point in the history
Connections that were upgraded to HTTP/2 by use of the H2cFilter can now be drained properly. The implementation is pretty ugly because Go does not have native support for connection draining on h2c connections, and as per golang/go#26682 this isn't a priority for the project.
  • Loading branch information
obeattie committed Nov 15, 2018
1 parent a049e9b commit 842491d
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 57 deletions.
8 changes: 4 additions & 4 deletions e2e_http1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ type http1Flavour struct {
T *testing.T
}

func (f http1Flavour) Serve(svc Service) Server {
func (f http1Flavour) Serve(svc Service) *Server {
s, err := Listen(svc, "localhost:0")
require.NoError(f.T, err)
return s
}

func (f http1Flavour) URL(s Server) string {
func (f http1Flavour) URL(s *Server) string {
return fmt.Sprintf("http://%s", s.Listener().Addr())
}

Expand All @@ -31,7 +31,7 @@ type http1TLSFlavour struct {
cert tls.Certificate
}

func (f http1TLSFlavour) Serve(svc Service) Server {
func (f http1TLSFlavour) Serve(svc Service) *Server {
l, err := tls.Listen("tcp", "localhost:0", &tls.Config{
Certificates: []tls.Certificate{f.cert},
ClientAuth: tls.NoClientCert})
Expand All @@ -41,7 +41,7 @@ func (f http1TLSFlavour) Serve(svc Service) Server {
return s
}

func (f http1TLSFlavour) URL(s Server) string {
func (f http1TLSFlavour) URL(s *Server) string {
return fmt.Sprintf("https://%s", s.Listener().Addr())
}

Expand Down
8 changes: 4 additions & 4 deletions e2e_http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ type http2H2cFlavour struct {
client Service
}

func (f http2H2cFlavour) Serve(svc Service) Server {
func (f http2H2cFlavour) Serve(svc Service) *Server {
svc = svc.Filter(H2cFilter)
s, err := Listen(svc, "localhost:0")
require.NoError(f.T, err)
return s
}

func (f http2H2cFlavour) URL(s Server) string {
func (f http2H2cFlavour) URL(s *Server) string {
return fmt.Sprintf("http://%s", s.Listener().Addr())
}

Expand All @@ -34,7 +34,7 @@ type http2H2Flavour struct {
cert tls.Certificate
}

func (f http2H2Flavour) Serve(svc Service) Server {
func (f http2H2Flavour) Serve(svc Service) *Server {
l, err := tls.Listen("tcp", "localhost:0", &tls.Config{
Certificates: []tls.Certificate{f.cert},
ClientAuth: tls.NoClientCert,
Expand All @@ -45,7 +45,7 @@ func (f http2H2Flavour) Serve(svc Service) Server {
return s
}

func (f http2H2Flavour) URL(s Server) string {
func (f http2H2Flavour) URL(s *Server) string {
return fmt.Sprintf("https://%s", s.Listener().Addr())
}

Expand Down
34 changes: 16 additions & 18 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import (
)

type e2eFlavour interface {
Serve(Service) Server
URL(Server) string
Serve(Service) *Server
URL(*Server) string
Proto() string
}

Expand Down Expand Up @@ -79,7 +79,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(network, addr)
}}
defer transport.CloseIdleConnections()
Client = HttpService(transport).Filter(ErrorFilter)
impl(t, http2H2cFlavour{T: t})
})
Expand All @@ -92,7 +91,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)
AllowHTTP: false,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true}}
defer transport.CloseIdleConnections()
Client = HttpService(transport).Filter(ErrorFilter)
impl(t, http2H2Flavour{
T: t,
Expand All @@ -115,7 +113,7 @@ func TestE2E(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{
"a": "b"})
Expand Down Expand Up @@ -156,7 +154,7 @@ func TestE2EStreaming(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
Expand Down Expand Up @@ -190,7 +188,7 @@ func TestE2EStreaming(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(s), nil)
reqS := Streamer()
Expand Down Expand Up @@ -233,7 +231,7 @@ func TestE2EDomainSocket(t *testing.T) {

s, err := Serve(svc, l)
require.NoError(t, err)
defer s.Stop()
defer s.Stop(context.Background())

sockTransport := &httpcontrol.Transport{
Dial: func(network, address string) (net.Conn, error) {
Expand Down Expand Up @@ -262,7 +260,7 @@ func TestE2EError(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(s), nil)
rsp := req.Send().Response()
Expand Down Expand Up @@ -290,7 +288,7 @@ func TestE2ECancellation(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
req := NewRequest(ctx, "GET", flav.URL(s), nil)
Expand Down Expand Up @@ -322,7 +320,7 @@ func TestE2ENoFollowRedirect(t *testing.T) {
return rsp
})
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -354,14 +352,14 @@ func TestE2EProxiedStreamer(t *testing.T) {
return rsp
})
s := flav.Serve(downstream)
defer s.Stop()
defer s.Stop(context.Background())

proxy := Service(func(req Request) Response {
proxyReq := NewRequest(req, "GET", flav.URL(s), nil)
return proxyReq.Send().Response()
})
ps := flav.Serve(proxy)
defer ps.Stop()
defer ps.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(ps), nil)
rsp := req.Send().Response()
Expand Down Expand Up @@ -400,7 +398,7 @@ func TestE2EInfiniteContext(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{
"a": "b"})
Expand Down Expand Up @@ -435,7 +433,7 @@ func TestE2ERequestAutoChunking(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -484,7 +482,7 @@ func TestE2EResponseAutoChunking(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -548,7 +546,7 @@ func TestE2EStreamingCancellation(t *testing.T) {
})
svc = svc.Filter(ErrorFilter)
s := flav.Serve(svc)
defer s.Stop()
defer s.Stop(context.Background())

ctx, cancel := context.WithCancel(context.Background())
req := NewRequest(ctx, "GET", flav.URL(s), nil)
Expand All @@ -573,7 +571,7 @@ func BenchmarkRequestResponse(b *testing.B) {
l, _ := net.ListenUnix("unix", addr)
defer l.Close()
s, _ := Serve(svc, l)
defer s.Stop()
defer s.Stop(context.Background())

sockTransport := &httpcontrol.Transport{
Dial: func(network, address string) (net.Conn, error) {
Expand Down
124 changes: 121 additions & 3 deletions h2c.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package typhon

import (
"bufio"
"context"
"net"
"net/http"
"net/textproto"
"sync"

"github.com/monzo/terrors"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)

// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 Sections 3.2, 3.4).
// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 §3.2, §3.4).
func H2cFilter(req Request, svc Service) Response {
h := req.Header
// h2c with prior knowledge (RFC 7540 Section 3.4)
Expand All @@ -18,9 +24,121 @@ func H2cFilter(req Request, svc Service) Response {
httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
if isPrior || isUpgrade {
rsp := NewResponse(req)
h2s := &http2.Server{}
h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rsp.Writer(), &req.Request)
rw, h2s, err := setupH2cHijacker(req, rsp.Writer())
if err != nil {
return Response{Error: err}
}
h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rw, &req.Request)
return rsp
}
return svc(req)
}

// Dear reader: I'm sorry, the code below isn't fun. This is because Go's h2c implementation doesn't have support for
// connection draining, and all the hooks that make would make this easy are unexported.
//
// If this ticket gets resolved this code can be dramatically simplified, but it is not a priority for the Go team:
// https://github.com/golang/go/issues/26682
//
// 🤢

var h2cConns sync.Map // map[*Server]*h2cInfo

// h2cInfo stores information about connections that have been upgraded by a single Typhon server
type h2cInfo struct {
sync.Mutex
conns []*hijackedConn
h2s *http2.Server
}

// hijackedConn represents a network connection that has been hijacked for a h2c upgrade. This is necessary because we
// need to know when the connection has been closed, to know if/when graceful shutdown completes.
type hijackedConn struct {
net.Conn
closed chan struct{}
closeOnce sync.Once
}

func (c *hijackedConn) Close() error {
defer c.closeOnce.Do(func() { close(c.closed) })
return c.Conn.Close()
}

type h2cHijacker struct {
http.ResponseWriter
http.Hijacker
hijacked func(*hijackedConn)
}

func (h h2cHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
c, r, err := h.Hijacker.Hijack()
conn := &hijackedConn{
Conn: c,
closed: make(chan struct{})}
h.hijacked(conn)
return conn, r, err
}

func shutdownH2c(ctx context.Context, srv *Server) {
_h2c, ok := h2cConns.Load(srv)
if !ok {
return
}
h2c := _h2c.(*h2cInfo)
h2c.Lock()
defer h2c.Unlock()

gracefulCloseLoop:
for _, c := range h2c.conns {
select {
case <-ctx.Done():
break gracefulCloseLoop
case <-c.closed:
h2c.conns = h2c.conns[1:]
}
}
// If any connections remain after gracefulCloseLoop, we need to forcefully close them
for _, c := range h2c.conns {
c.Close()
h2c.conns = h2c.conns[1:]
}
h2cConns.Delete(srv)
}

func setupH2cHijacker(req Request, rw http.ResponseWriter) (http.ResponseWriter, *http2.Server, error) {
hijacker, ok := rw.(http.Hijacker)
if !ok {
err := terrors.InternalService("hijack_impossible", "Cannot hijack response; h2c upgrade impossible", nil)
return nil, nil, err
}
srv := req.server
if srv == nil {
return rw, &http2.Server{}, nil
}

h2c := &h2cInfo{
h2s: &http2.Server{}}
_h2c, loaded := h2cConns.LoadOrStore(srv, h2c)
h2c = _h2c.(*h2cInfo)
if !loaded {
// http2.ConfigureServer wires up an unexported method within the http2 library so it gracefully drains h2c
// connections when the http1 server is stopped. However, this happens asynchronously: the http1 server will
// think it has shut down before the h2c connections have finished draining. To work around this, we add
// a shutdown function of our own in the Typhon server which waits for connections to be drained, or if things
// timeout before then to terminate them forcefully.
http2.ConfigureServer(srv.srv, h2c.h2s)
srv.addShutdownFunc(func(ctx context.Context) {
shutdownH2c(ctx, srv)
})
}

h := h2cHijacker{
ResponseWriter: rw,
Hijacker: hijacker,
hijacked: func(c *hijackedConn) {
h2c.Lock()
defer h2c.Unlock()
h2c.conns = append(h2c.conns, c)
}}
return h, h2c.h2s, nil
}
1 change: 1 addition & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Request struct {
context.Context
err error // Any error from request construction; read by Client
hijacker http.Hijacker
server *Server
}

// unwrappedContext returns the most "unwrapped" Context possible for that in the request.
Expand Down
Loading

0 comments on commit 842491d

Please sign in to comment.