From 1569ba3dd9a0b435b49f8201e58ec3b32fd28c9c Mon Sep 17 00:00:00 2001 From: Oliver Beattie Date: Thu, 25 Oct 2018 18:18:16 +0100 Subject: [PATCH] =?UTF-8?q?Connection=20draining=20for=20h2c=20connections?= =?UTF-8?q?=20=F0=9F=9A=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Connections that were upgraded to HTTP/2 by use of the H2cFilter can now be drained properly. The implementation is pretty ugly because Go does not have native support for connection draining on h2c connections, and as per golang/go#26682 this isn't a priority for the project. --- e2e_http1_test.go | 8 +-- e2e_http2_test.go | 8 +-- e2e_test.go | 34 ++++++------- h2c.go | 124 ++++++++++++++++++++++++++++++++++++++++++++-- request.go | 1 + server.go | 105 ++++++++++++++++++++++++++++----------- 6 files changed, 223 insertions(+), 57 deletions(-) diff --git a/e2e_http1_test.go b/e2e_http1_test.go index bdb1c515..a62a547a 100644 --- a/e2e_http1_test.go +++ b/e2e_http1_test.go @@ -12,13 +12,13 @@ type http1Flavour struct { T *testing.T } -func (f http1Flavour) Serve(svc Service) Server { +func (f http1Flavour) Serve(svc Service) *Server { s, err := Listen(svc, "localhost:0") require.NoError(f.T, err) return s } -func (f http1Flavour) URL(s Server) string { +func (f http1Flavour) URL(s *Server) string { return fmt.Sprintf("http://%s", s.Listener().Addr()) } @@ -31,7 +31,7 @@ type http1TLSFlavour struct { cert tls.Certificate } -func (f http1TLSFlavour) Serve(svc Service) Server { +func (f http1TLSFlavour) Serve(svc Service) *Server { l, err := tls.Listen("tcp", "localhost:0", &tls.Config{ Certificates: []tls.Certificate{f.cert}, ClientAuth: tls.NoClientCert}) @@ -41,7 +41,7 @@ func (f http1TLSFlavour) Serve(svc Service) Server { return s } -func (f http1TLSFlavour) URL(s Server) string { +func (f http1TLSFlavour) URL(s *Server) string { return fmt.Sprintf("https://%s", s.Listener().Addr()) } diff --git a/e2e_http2_test.go b/e2e_http2_test.go index 6b51de40..0e33a627 100644 --- a/e2e_http2_test.go +++ b/e2e_http2_test.go @@ -13,14 +13,14 @@ type http2H2cFlavour struct { client Service } -func (f http2H2cFlavour) Serve(svc Service) Server { +func (f http2H2cFlavour) Serve(svc Service) *Server { svc = svc.Filter(H2cFilter) s, err := Listen(svc, "localhost:0") require.NoError(f.T, err) return s } -func (f http2H2cFlavour) URL(s Server) string { +func (f http2H2cFlavour) URL(s *Server) string { return fmt.Sprintf("http://%s", s.Listener().Addr()) } @@ -34,7 +34,7 @@ type http2H2Flavour struct { cert tls.Certificate } -func (f http2H2Flavour) Serve(svc Service) Server { +func (f http2H2Flavour) Serve(svc Service) *Server { l, err := tls.Listen("tcp", "localhost:0", &tls.Config{ Certificates: []tls.Certificate{f.cert}, ClientAuth: tls.NoClientCert, @@ -45,7 +45,7 @@ func (f http2H2Flavour) Serve(svc Service) Server { return s } -func (f http2H2Flavour) URL(s Server) string { +func (f http2H2Flavour) URL(s *Server) string { return fmt.Sprintf("https://%s", s.Listener().Addr()) } diff --git a/e2e_test.go b/e2e_test.go index f76758ad..1ab20386 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -21,8 +21,8 @@ import ( ) type e2eFlavour interface { - Serve(Service) Server - URL(Server) string + Serve(Service) *Server + URL(*Server) string Proto() string } @@ -79,7 +79,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour) DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }} - defer transport.CloseIdleConnections() Client = HttpService(transport).Filter(ErrorFilter) impl(t, http2H2cFlavour{T: t}) }) @@ -92,7 +91,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour) AllowHTTP: false, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() Client = HttpService(transport).Filter(ErrorFilter) impl(t, http2H2Flavour{ T: t, @@ -115,7 +113,7 @@ func TestE2E(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{ "a": "b"}) @@ -156,7 +154,7 @@ func TestE2EStreaming(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) rsp := req.Send().Response() @@ -190,7 +188,7 @@ func TestE2EStreaming(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) reqS := Streamer() @@ -233,7 +231,7 @@ func TestE2EDomainSocket(t *testing.T) { s, err := Serve(svc, l) require.NoError(t, err) - defer s.Stop() + defer s.Stop(context.Background()) sockTransport := &httpcontrol.Transport{ Dial: func(network, address string) (net.Conn, error) { @@ -262,7 +260,7 @@ func TestE2EError(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) rsp := req.Send().Response() @@ -290,7 +288,7 @@ func TestE2ECancellation(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) @@ -322,7 +320,7 @@ func TestE2ENoFollowRedirect(t *testing.T) { return rsp }) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -354,14 +352,14 @@ func TestE2EProxiedStreamer(t *testing.T) { return rsp }) s := flav.Serve(downstream) - defer s.Stop() + defer s.Stop(context.Background()) proxy := Service(func(req Request) Response { proxyReq := NewRequest(req, "GET", flav.URL(s), nil) return proxyReq.Send().Response() }) ps := flav.Serve(proxy) - defer ps.Stop() + defer ps.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(ps), nil) rsp := req.Send().Response() @@ -400,7 +398,7 @@ func TestE2EInfiniteContext(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{ "a": "b"}) @@ -435,7 +433,7 @@ func TestE2ERequestAutoChunking(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -484,7 +482,7 @@ func TestE2EResponseAutoChunking(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -548,7 +546,7 @@ func TestE2EStreamingCancellation(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) @@ -573,7 +571,7 @@ func BenchmarkRequestResponse(b *testing.B) { l, _ := net.ListenUnix("unix", addr) defer l.Close() s, _ := Serve(svc, l) - defer s.Stop() + defer s.Stop(context.Background()) sockTransport := &httpcontrol.Transport{ Dial: func(network, address string) (net.Conn, error) { diff --git a/h2c.go b/h2c.go index 94bfdc86..c4588c1c 100644 --- a/h2c.go +++ b/h2c.go @@ -1,14 +1,20 @@ package typhon import ( + "bufio" + "context" + "net" + "net/http" "net/textproto" + "sync" + "github.com/monzo/terrors" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) -// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 Sections 3.2, 3.4). +// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 §3.2, §3.4). func H2cFilter(req Request, svc Service) Response { h := req.Header // h2c with prior knowledge (RFC 7540 Section 3.4) @@ -18,9 +24,121 @@ func H2cFilter(req Request, svc Service) Response { httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings") if isPrior || isUpgrade { rsp := NewResponse(req) - h2s := &http2.Server{} - h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rsp.Writer(), &req.Request) + rw, h2s, err := setupH2cHijacker(req, rsp.Writer()) + if err != nil { + return Response{Error: err} + } + h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rw, &req.Request) return rsp } return svc(req) } + +// Dear reader: I'm sorry, the code below isn't fun. This is because Go's h2c implementation doesn't have support for +// connection draining, and all the hooks that make would make this easy are unexported. +// +// If this ticket gets resolved this code can be dramatically simplified, but it is not a priority for the Go team: +// https://github.com/golang/go/issues/26682 +// +// 🤢 + +var h2cConns sync.Map // map[*Server]*h2cInfo + +// h2cInfo stores information about connections that have been upgraded by a single Typhon server +type h2cInfo struct { + sync.Mutex + conns []*hijackedConn + h2s *http2.Server +} + +// hijackedConn represents a network connection that has been hijacked for a h2c upgrade. This is necessary because we +// need to know when the connection has been closed, to know if/when graceful shutdown completes. +type hijackedConn struct { + net.Conn + closed chan struct{} + closeOnce sync.Once +} + +func (c *hijackedConn) Close() error { + defer c.closeOnce.Do(func() { close(c.closed) }) + return c.Conn.Close() +} + +type h2cHijacker struct { + http.ResponseWriter + http.Hijacker + hijacked func(*hijackedConn) +} + +func (h h2cHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + c, r, err := h.Hijacker.Hijack() + conn := &hijackedConn{ + Conn: c, + closed: make(chan struct{})} + h.hijacked(conn) + return conn, r, err +} + +func shutdownH2c(ctx context.Context, srv *Server) { + _h2c, ok := h2cConns.Load(srv) + if !ok { + return + } + h2c := _h2c.(*h2cInfo) + h2c.Lock() + defer h2c.Unlock() + +gracefulCloseLoop: + for _, c := range h2c.conns { + select { + case <-ctx.Done(): + break gracefulCloseLoop + case <-c.closed: + h2c.conns = h2c.conns[1:] + } + } + // If any connections remain after gracefulCloseLoop, we need to forcefully close them + for _, c := range h2c.conns { + c.Close() + h2c.conns = h2c.conns[1:] + } + h2cConns.Delete(srv) +} + +func setupH2cHijacker(req Request, rw http.ResponseWriter) (http.ResponseWriter, *http2.Server, error) { + hijacker, ok := rw.(http.Hijacker) + if !ok { + err := terrors.InternalService("hijack_impossible", "Cannot hijack response; h2c upgrade impossible", nil) + return nil, nil, err + } + srv := req.server + if srv == nil { + return rw, &http2.Server{}, nil + } + + h2c := &h2cInfo{ + h2s: &http2.Server{}} + _h2c, loaded := h2cConns.LoadOrStore(srv, h2c) + h2c = _h2c.(*h2cInfo) + if !loaded { + // http2.ConfigureServer wires up an unexported method within the http2 library so it gracefully drains h2c + // connections when the http1 server is stopped. However, this happens asynchronously: the http1 server will + // think it has shut down before the h2c connections have finished draining. To work around this, we add + // a shutdown function of our own in the Typhon server which waits for connections to be drained, or if things + // timeout before then to terminate them forcefully. + http2.ConfigureServer(srv.srv, h2c.h2s) + srv.addShutdownFunc(func(ctx context.Context) { + shutdownH2c(ctx, srv) + }) + } + + h := h2cHijacker{ + ResponseWriter: rw, + Hijacker: hijacker, + hijacked: func(c *hijackedConn) { + h2c.Lock() + defer h2c.Unlock() + h2c.conns = append(h2c.conns, c) + }} + return h, h2c.h2s, nil +} diff --git a/request.go b/request.go index a804741f..23df85fe 100644 --- a/request.go +++ b/request.go @@ -17,6 +17,7 @@ type Request struct { context.Context err error // Any error from request construction; read by Client hijacker http.Hijacker + server *Server } // unwrappedContext returns the most "unwrapped" Context possible for that in the request. diff --git a/server.go b/server.go index 34543510..f9492949 100644 --- a/server.go +++ b/server.go @@ -1,52 +1,101 @@ package typhon import ( + "context" "fmt" "net" + "net/http" "os" "strconv" - "time" + "sync" - "github.com/facebookgo/httpdown" + "github.com/monzo/slog" ) -const DefaultListenAddr = "127.0.0.1:0" +type Server struct { + l net.Listener + srv *http.Server + shuttingDown chan struct{} + shutdownOnce sync.Once + shutdownFuncs []func(context.Context) + shutdownFuncsM sync.Mutex +} -type Server interface { - httpdown.Server - Listener() net.Listener - WaitC() <-chan struct{} +// Listener returns the network listener that this server is active on. +func (s *Server) Listener() net.Listener { + return s.l } -type server struct { - httpdown.Server - l net.Listener +// Done returns a channel that will be closed when the server begins to shutdown. The server may still be draining its +// connections at the time the channel is closed. +func (s *Server) Done() <-chan struct{} { + return s.shuttingDown } -func (s server) Listener() net.Listener { - return s.l +// Stop shuts down the server, returning when there are no more connections still open. Graceful shutdown will be +// attempted until the passed context expires, at which time all connections will be forcibly terminated. +func (s *Server) Stop(ctx context.Context) { + s.shutdownFuncsM.Lock() + defer s.shutdownFuncsM.Unlock() + s.shutdownOnce.Do(func() { + close(s.shuttingDown) + // Shut down the HTTP server in parallel to calling any custom shutdown functions + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + if err := s.srv.Shutdown(ctx); err != nil { + slog.Debug(ctx, "Graceful shutdown failed; forcibly closing connections 👢") + if err := s.srv.Close(); err != nil { + slog.Critical(ctx, "Forceful shutdown failed, exiting 😱: %v", err) + panic(err) // Something is super hosed here + } + } + }() + for _, f := range s.shutdownFuncs { + f := f // capture range variable + wg.Add(1) + go func() { + defer wg.Done() + f(ctx) + }() + } + wg.Wait() + }) } -func (s server) WaitC() <-chan struct{} { - c := make(chan struct{}, 0) - go func() { - s.Server.Wait() - close(c) - }() - return c +// addShutdownFunc registers a function that will be called when the server is stopped. The function is expected to try +// to shutdown gracefully until the context expires, at which time it should terminate its work forcefully. +func (s *Server) addShutdownFunc(f func(context.Context)) { + s.shutdownFuncsM.Lock() + defer s.shutdownFuncsM.Unlock() + s.shutdownFuncs = append(s.shutdownFuncs, f) } -func Serve(svc Service, l net.Listener) (Server, error) { - downer := &httpdown.HTTP{ - StopTimeout: 20 * time.Second, - KillTimeout: 25 * time.Second} - downerServer := downer.Serve(HttpServer(svc), l) - return server{ - Server: downerServer, - l: l}, nil +// Serve starts a HTTP server, binding the passed Service to the passed listener. +func Serve(svc Service, l net.Listener) (*Server, error) { + s := &Server{ + l: l, + shuttingDown: make(chan struct{})} + svc = svc.Filter(func(req Request, svc Service) Response { + req.server = s + return svc(req) + }) + s.srv = HttpServer(svc) + go func() { + err := s.srv.Serve(l) + if err != nil && err != http.ErrServerClosed { + slog.Error(nil, "HTTP server error: %v", err) + // Stopping with an already-closed context means we go immediately to "forceful" mode + ctx, cancel := context.WithCancel(context.Background()) + cancel() + s.Stop(ctx) + } + }() + return s, nil } -func Listen(svc Service, addr string) (Server, error) { +func Listen(svc Service, addr string) (*Server, error) { // Determine on which address to listen, choosing in order one of: // 1. The passed addr // 2. PORT variable (listening on all interfaces)