diff --git a/e2e_http1_test.go b/e2e_http1_test.go index bdb1c515..a62a547a 100644 --- a/e2e_http1_test.go +++ b/e2e_http1_test.go @@ -12,13 +12,13 @@ type http1Flavour struct { T *testing.T } -func (f http1Flavour) Serve(svc Service) Server { +func (f http1Flavour) Serve(svc Service) *Server { s, err := Listen(svc, "localhost:0") require.NoError(f.T, err) return s } -func (f http1Flavour) URL(s Server) string { +func (f http1Flavour) URL(s *Server) string { return fmt.Sprintf("http://%s", s.Listener().Addr()) } @@ -31,7 +31,7 @@ type http1TLSFlavour struct { cert tls.Certificate } -func (f http1TLSFlavour) Serve(svc Service) Server { +func (f http1TLSFlavour) Serve(svc Service) *Server { l, err := tls.Listen("tcp", "localhost:0", &tls.Config{ Certificates: []tls.Certificate{f.cert}, ClientAuth: tls.NoClientCert}) @@ -41,7 +41,7 @@ func (f http1TLSFlavour) Serve(svc Service) Server { return s } -func (f http1TLSFlavour) URL(s Server) string { +func (f http1TLSFlavour) URL(s *Server) string { return fmt.Sprintf("https://%s", s.Listener().Addr()) } diff --git a/e2e_http2_test.go b/e2e_http2_test.go index 6b51de40..0e33a627 100644 --- a/e2e_http2_test.go +++ b/e2e_http2_test.go @@ -13,14 +13,14 @@ type http2H2cFlavour struct { client Service } -func (f http2H2cFlavour) Serve(svc Service) Server { +func (f http2H2cFlavour) Serve(svc Service) *Server { svc = svc.Filter(H2cFilter) s, err := Listen(svc, "localhost:0") require.NoError(f.T, err) return s } -func (f http2H2cFlavour) URL(s Server) string { +func (f http2H2cFlavour) URL(s *Server) string { return fmt.Sprintf("http://%s", s.Listener().Addr()) } @@ -34,7 +34,7 @@ type http2H2Flavour struct { cert tls.Certificate } -func (f http2H2Flavour) Serve(svc Service) Server { +func (f http2H2Flavour) Serve(svc Service) *Server { l, err := tls.Listen("tcp", "localhost:0", &tls.Config{ Certificates: []tls.Certificate{f.cert}, ClientAuth: tls.NoClientCert, @@ -45,7 +45,7 @@ func (f http2H2Flavour) Serve(svc Service) Server { 
return s } -func (f http2H2Flavour) URL(s Server) string { +func (f http2H2Flavour) URL(s *Server) string { return fmt.Sprintf("https://%s", s.Listener().Addr()) } diff --git a/e2e_test.go b/e2e_test.go index f76758ad..1ab20386 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -21,8 +21,8 @@ import ( ) type e2eFlavour interface { - Serve(Service) Server - URL(Server) string + Serve(Service) *Server + URL(*Server) string Proto() string } @@ -79,7 +79,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour) DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }} - defer transport.CloseIdleConnections() Client = HttpService(transport).Filter(ErrorFilter) impl(t, http2H2cFlavour{T: t}) }) @@ -92,7 +91,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour) AllowHTTP: false, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true}} - defer transport.CloseIdleConnections() Client = HttpService(transport).Filter(ErrorFilter) impl(t, http2H2Flavour{ T: t, @@ -115,7 +113,7 @@ func TestE2E(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{ "a": "b"}) @@ -156,7 +154,7 @@ func TestE2EStreaming(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) rsp := req.Send().Response() @@ -190,7 +188,7 @@ func TestE2EStreaming(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) reqS := Streamer() @@ -233,7 +231,7 @@ func TestE2EDomainSocket(t *testing.T) { s, err := Serve(svc, l) require.NoError(t, err) - defer s.Stop() + defer s.Stop(context.Background()) sockTransport := &httpcontrol.Transport{ Dial: func(network, 
address string) (net.Conn, error) { @@ -262,7 +260,7 @@ func TestE2EError(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) rsp := req.Send().Response() @@ -290,7 +288,7 @@ func TestE2ECancellation(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) @@ -322,7 +320,7 @@ func TestE2ENoFollowRedirect(t *testing.T) { return rsp }) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -354,14 +352,14 @@ func TestE2EProxiedStreamer(t *testing.T) { return rsp }) s := flav.Serve(downstream) - defer s.Stop() + defer s.Stop(context.Background()) proxy := Service(func(req Request) Response { proxyReq := NewRequest(req, "GET", flav.URL(s), nil) return proxyReq.Send().Response() }) ps := flav.Serve(proxy) - defer ps.Stop() + defer ps.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(ps), nil) rsp := req.Send().Response() @@ -400,7 +398,7 @@ func TestE2EInfiniteContext(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{ "a": "b"}) @@ -435,7 +433,7 @@ func TestE2ERequestAutoChunking(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -484,7 +482,7 @@ func TestE2EResponseAutoChunking(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -548,7 +546,7 @@ func 
TestE2EStreamingCancellation(t *testing.T) { }) svc = svc.Filter(ErrorFilter) s := flav.Serve(svc) - defer s.Stop() + defer s.Stop(context.Background()) ctx, cancel := context.WithCancel(context.Background()) req := NewRequest(ctx, "GET", flav.URL(s), nil) @@ -573,7 +571,7 @@ func BenchmarkRequestResponse(b *testing.B) { l, _ := net.ListenUnix("unix", addr) defer l.Close() s, _ := Serve(svc, l) - defer s.Stop() + defer s.Stop(context.Background()) sockTransport := &httpcontrol.Transport{ Dial: func(network, address string) (net.Conn, error) {
diff --git a/h2c.go b/h2c.go
index 94bfdc86..c4588c1c 100644
--- a/h2c.go
+++ b/h2c.go
@@ -1,14 +1,20 @@
 package typhon
 
 import (
+	"bufio"
+	"context"
+	"net"
+	"net/http"
 	"net/textproto"
+	"sync"
 
+	"github.com/monzo/terrors"
 	"golang.org/x/net/http/httpguts"
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"
 )
 
-// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 Sections 3.2, 3.4).
+// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 §3.2, §3.4).
 func H2cFilter(req Request, svc Service) Response {
 	h := req.Header
 	// h2c with prior knowledge (RFC 7540 Section 3.4)
@@ -18,9 +24,153 @@ func H2cFilter(req Request, svc Service) Response {
 		httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
 	if isPrior || isUpgrade {
 		rsp := NewResponse(req)
-		h2s := &http2.Server{}
-		h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rsp.Writer(), &req.Request)
+		rw, h2s, err := setupH2cHijacker(req, rsp.Writer())
+		if err != nil {
+			return Response{Error: err}
+		}
+		h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rw, &req.Request)
 		return rsp
 	}
 	return svc(req)
 }
+
+// Dear reader: I'm sorry, the code below isn't fun. This is because Go's h2c implementation doesn't have support for
+// connection draining, and all the hooks that would make this easy are unexported.
+//
+// If this ticket gets resolved this code can be dramatically simplified, but it is not a priority for the Go team:
+// https://github.com/golang/go/issues/26682
+//
+// 🤢
+
+var h2cConns sync.Map // map[*Server]*h2cInfo
+
+// h2cInfo stores information about connections that have been upgraded by a single Typhon server. The embedded mutex
+// guards conns.
+type h2cInfo struct {
+	sync.Mutex
+	conns []*hijackedConn
+	h2s   *http2.Server
+}
+
+// hijackedConn represents a network connection that has been hijacked for a h2c upgrade. This is necessary because we
+// need to know when the connection has been closed, to know if/when graceful shutdown completes.
+type hijackedConn struct {
+	net.Conn
+	closed    chan struct{}
+	closeOnce sync.Once
+}
+
+// Close closes the underlying connection and then closes the closed channel to signal that the connection is gone.
+// It may be called more than once: closeOnce ensures the channel is only closed by the first call.
+func (c *hijackedConn) Close() error {
+	defer c.closeOnce.Do(func() { close(c.closed) })
+	return c.Conn.Close()
+}
+
+// h2cHijacker is a http.ResponseWriter whose Hijack method wraps each hijacked connection in a *hijackedConn and
+// reports it to the hijacked callback.
+type h2cHijacker struct {
+	http.ResponseWriter
+	http.Hijacker
+	hijacked func(*hijackedConn)
+}
+
+// Hijack takes over the underlying connection and returns it wrapped in a *hijackedConn, after passing the wrapper to
+// the hijacked callback. If the underlying hijack fails, the error is returned and nothing is registered: a wrapper
+// around a nil connection would never signal closure (stalling shutdownH2c until its context expires) and would then
+// panic when forcibly closed.
+func (h h2cHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	c, r, err := h.Hijacker.Hijack()
+	if err != nil {
+		return nil, nil, err
+	}
+	conn := &hijackedConn{
+		Conn:   c,
+		closed: make(chan struct{})}
+	h.hijacked(conn)
+	return conn, r, nil
+}
+
+// shutdownH2c waits for the h2c connections hijacked from srv to close. Once ctx expires, any that are still open are
+// closed forcefully. It does nothing if no h2c connection tracking was ever set up for srv.
+func shutdownH2c(ctx context.Context, srv *Server) {
+	v, ok := h2cConns.Load(srv)
+	if !ok {
+		return
+	}
+	info := v.(*h2cInfo)
+
+	// Take ownership of the tracked connections and release the lock before waiting on them; holding it while waiting
+	// would block any request which is registering a newly-hijacked connection until shutdown had finished.
+	info.Lock()
+	remaining := info.conns
+	info.conns = nil
+	info.Unlock()
+
+gracefulCloseLoop:
+	for len(remaining) > 0 {
+		select {
+		case <-ctx.Done():
+			break gracefulCloseLoop
+		case <-remaining[0].closed:
+			remaining = remaining[1:]
+		}
+	}
+	// If any connections remain after gracefulCloseLoop, we need to forcefully close them
+	for _, c := range remaining {
+		c.Close() // best effort: the connection is being discarded, so there is nothing to do with an error
+	}
+	h2cConns.Delete(srv)
+}
+
+// setupH2cHijacker prepares for a h2c upgrade of req, returning the ResponseWriter and http2.Server to pass to the h2c
+// handler. It fails if rw cannot be hijacked.
+//
+// If req has no associated Server, rw is returned as-is together with a fresh http2.Server. Otherwise the returned
+// ResponseWriter tracks each connection it hijacks, and the first upgrade on a given Server also registers a shutdown
+// function with it so that stopping the Server waits for (and if need be closes) those connections.
+func setupH2cHijacker(req Request, rw http.ResponseWriter) (http.ResponseWriter, *http2.Server, error) {
+	hijacker, ok := rw.(http.Hijacker)
+	if !ok {
+		err := terrors.InternalService("hijack_impossible", "Cannot hijack response; h2c upgrade impossible", nil)
+		return nil, nil, err
+	}
+	srv := req.server
+	if srv == nil {
+		return rw, &http2.Server{}, nil
+	}
+
+	v, loaded := h2cConns.LoadOrStore(srv, &h2cInfo{
+		h2s: &http2.Server{}})
+	info := v.(*h2cInfo)
+	if !loaded {
+		// http2.ConfigureServer wires up an unexported method within the http2 library so it gracefully drains h2c
+		// connections when the http1 server is stopped. However, this happens asynchronously: the http1 server will
+		// think it has shut down before the h2c connections have finished draining. To work around this, we add
+		// a shutdown function of our own in the Typhon server which waits for connections to be drained, or if things
+		// timeout before then to terminate them forcefully.
+		http2.ConfigureServer(srv.srv, info.h2s)
+		srv.addShutdownFunc(func(ctx context.Context) {
+			shutdownH2c(ctx, srv)
+		})
+	}
+
+	h := h2cHijacker{
+		ResponseWriter: rw,
+		Hijacker:       hijacker,
+		hijacked: func(c *hijackedConn) {
+			info.Lock()
+			defer info.Unlock()
+			// Drop connections that have already closed, so the list doesn't keep growing while the server lives
+			live := info.conns[:0]
+			for _, existing := range info.conns {
+				select {
+				case <-existing.closed:
+				default:
+					live = append(live, existing)
+				}
+			}
+			info.conns = append(live, c)
+		}}
+	return h, info.h2s, nil
+}
diff --git a/request.go b/request.go
index a804741f..23df85fe 100644
--- a/request.go
+++ b/request.go
@@ -17,6 +17,7 @@ type Request struct {
 	context.Context
 	err      error // Any error from request construction; read by Client
 	hijacker http.Hijacker
+	server   *Server // The Server handling this request; set by the filter installed in Serve
 }
 
 // unwrappedContext returns the most "unwrapped" Context possible for that in the request.
diff --git a/server.go b/server.go
index 34543510..f9492949 100644
--- a/server.go
+++ b/server.go
@@ -1,52 +1,119 @@
 package typhon
 
 import (
+	"context"
 	"fmt"
 	"net"
+	"net/http"
 	"os"
 	"strconv"
-	"time"
+	"sync"
 
-	"github.com/facebookgo/httpdown"
+	"github.com/monzo/slog"
 )
 
-const DefaultListenAddr = "127.0.0.1:0"
+// Server is a Typhon HTTP server bound to a single network listener. Obtain one from Serve or Listen, and shut it down
+// with Stop.
+type Server struct {
+	l              net.Listener
+	srv            *http.Server
+	shuttingDown   chan struct{} // closed when shutdown begins
+	shutdownOnce   sync.Once     // guarantees the shutdown sequence runs only once
+	shutdownFuncs  []func(context.Context)
+	shutdownFuncsM sync.Mutex // guards shutdownFuncs
+}
 
-type Server interface {
-	httpdown.Server
-	Listener() net.Listener
-	WaitC() <-chan struct{}
+// Listener returns the network listener that this server is active on.
+func (s *Server) Listener() net.Listener {
+	return s.l
 }
 
-type server struct {
-	httpdown.Server
-	l net.Listener
+// Done returns a channel that will be closed when the server begins to shutdown. The server may still be draining its
+// connections at the time the channel is closed.
+func (s *Server) Done() <-chan struct{} {
+	return s.shuttingDown
 }
 
-func (s server) Listener() net.Listener {
-	return s.l
+// Stop shuts down the server, returning when there are no more connections still open. Graceful shutdown will be
+// attempted until the passed context expires, at which time all connections will be forcibly terminated.
+//
+// Only the first call to Stop performs the shutdown. Calls made while it is in progress block until it completes, and
+// calls made afterwards return immediately.
+func (s *Server) Stop(ctx context.Context) {
+	s.shutdownOnce.Do(func() {
+		close(s.shuttingDown)
+
+		// Snapshot the shutdown functions and release the mutex before doing anything that blocks. If the mutex were
+		// held for the whole shutdown, a request still in flight that called addShutdownFunc would block on it, its
+		// connection would never become idle, and the graceful Shutdown below could not finish before ctx expired.
+		s.shutdownFuncsM.Lock()
+		funcs := make([]func(context.Context), len(s.shutdownFuncs))
+		copy(funcs, s.shutdownFuncs)
+		s.shutdownFuncsM.Unlock()
+
+		// Shut down the HTTP server in parallel to calling any custom shutdown functions
+		wg := sync.WaitGroup{}
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := s.srv.Shutdown(ctx); err != nil {
+				slog.Debug(ctx, "Graceful shutdown failed; forcibly closing connections 👢")
+				if err := s.srv.Close(); err != nil {
+					slog.Critical(ctx, "Forceful shutdown failed, exiting 😱: %v", err)
+					panic(err) // Something is super hosed here
+				}
+			}
+		}()
+		for _, f := range funcs {
+			f := f // capture range variable
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				f(ctx)
+			}()
+		}
+		wg.Wait()
+	})
 }
 
-func (s server) WaitC() <-chan struct{} {
-	c := make(chan struct{}, 0)
-	go func() {
-		s.Server.Wait()
-		close(c)
-	}()
-	return c
+// addShutdownFunc registers a function that will be called when the server is stopped. The function is expected to try
+// to shutdown gracefully until the context expires, at which time it should terminate its work forcefully.
+//
+// A function registered after Stop has already taken its snapshot of the registered functions is never called.
+func (s *Server) addShutdownFunc(f func(context.Context)) {
+	s.shutdownFuncsM.Lock()
+	defer s.shutdownFuncsM.Unlock()
+	s.shutdownFuncs = append(s.shutdownFuncs, f)
 }
 
-func Serve(svc Service, l net.Listener) (Server, error) {
-	downer := &httpdown.HTTP{
-		StopTimeout: 20 * time.Second,
-		KillTimeout: 25 * time.Second}
-	downerServer := downer.Serve(HttpServer(svc), l)
-	return server{
-		Server: downerServer,
-		l:      l}, nil
+// Serve starts a HTTP server, binding the passed Service to the passed listener.
+//
+// The server is run in a background goroutine, so Serve returns immediately; the returned error is always nil. If the
+// HTTP server stops for any reason other than being shut down, the Server is stopped forcefully.
+func Serve(svc Service, l net.Listener) (*Server, error) {
+	s := &Server{
+		l:            l,
+		shuttingDown: make(chan struct{})}
+	// Tag every request with this server, so that filters further down the chain can find it
+	svc = svc.Filter(func(req Request, svc Service) Response {
+		req.server = s
+		return svc(req)
+	})
+	s.srv = HttpServer(svc)
+	go func() {
+		err := s.srv.Serve(l)
+		if err != nil && err != http.ErrServerClosed {
+			slog.Error(context.Background(), "HTTP server error: %v", err)
+			// Stopping with an already-closed context means we go immediately to "forceful" mode
+			ctx, cancel := context.WithCancel(context.Background())
+			cancel()
+			s.Stop(ctx)
+		}
+	}()
+	return s, nil
 }
 
-func Listen(svc Service, addr string) (Server, error) {
+func Listen(svc Service, addr string) (*Server, error) {
 	// Determine on which address to listen, choosing in order one of:
 	// 1. The passed addr
 	// 2. PORT variable (listening on all interfaces)