diff --git a/api.go b/api.go index 50264662..58de0b12 100644 --- a/api.go +++ b/api.go @@ -73,7 +73,7 @@ func (h *APIHandler) healthz(w http.ResponseWriter, _ *http.Request) { } func (h *APIHandler) readyz(w http.ResponseWriter, r *http.Request) { - if h.ready(r.Context()) { + if h.ready == nil || h.ready(r.Context()) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "text/plain") w.Write([]byte("OK")) diff --git a/cmd/forwarder/httpbin/httpbin.go b/cmd/forwarder/httpbin/httpbin.go index 6ce41a9e..ed3f17b3 100644 --- a/cmd/forwarder/httpbin/httpbin.go +++ b/cmd/forwarder/httpbin/httpbin.go @@ -40,12 +40,14 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error { if err != nil { return err } + defer s.Close() r := prometheus.NewRegistry() - a, err := forwarder.NewHTTPServer(c.apiServerConfig, forwarder.NewAPIHandler(r, s.Ready, config, ""), logger.Named("api")) + a, err := forwarder.NewHTTPServer(c.apiServerConfig, forwarder.NewAPIHandler(r, nil, config, ""), logger.Named("api")) if err != nil { return err } + defer a.Close() return runctx.NewGroup(s.Run, a.Run).Run() } diff --git a/cmd/forwarder/pac/server/server.go b/cmd/forwarder/pac/server/server.go index beba27cc..1e10a998 100644 --- a/cmd/forwarder/pac/server/server.go +++ b/cmd/forwarder/pac/server/server.go @@ -68,6 +68,7 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error { if err != nil { return err } + defer s.Close() return runctx.NewGroup(s.Run).Run() } diff --git a/cmd/forwarder/run/run.go b/cmd/forwarder/run/run.go index 16f667ce..88c2224b 100644 --- a/cmd/forwarder/run/run.go +++ b/cmd/forwarder/run/run.go @@ -120,18 +120,22 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error { } var g runctx.Group - p, err := forwarder.NewHTTPProxy(c.httpProxyConfig, pr, cm, rt, logger.Named("proxy")) - if err != nil { - return err + { + p, err := forwarder.NewHTTPProxy(c.httpProxyConfig, pr, cm, rt, logger.Named("proxy")) + if err != nil { + return err + } + defer p.Close() + g.Add(p.Run) } - g.Add(p.Run) if c.apiServerConfig.Addr != "" { - h := forwarder.NewAPIHandler(c.promReg, p.Ready, config, script) + h := forwarder.NewAPIHandler(c.promReg, nil, config, script) a, err := forwarder.NewHTTPServer(c.apiServerConfig, h, logger.Named("api")) if err != nil { return err } + defer a.Close() g.Add(a.Run) } diff --git a/go.mod b/go.mod index cf270c19..0dc622ce 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,12 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.37.0 + github.com/spf13/cast v1.4.1 github.com/spf13/cobra v1.6.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.10.0 go.uber.org/goleak v1.2.0 + go.uber.org/multierr v1.11.0 golang.org/x/exp v0.0.0-20230314191032-db074128a8ec golang.org/x/net v0.7.0 golang.org/x/sync v0.1.0 @@ -40,7 +42,6 @@ require ( github.com/prometheus/procfs v0.8.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect github.com/spf13/afero v1.9.2 // indirect - github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index c2b6a93f..b4d4f1ee 100644 --- a/go.sum +++ b/go.sum @@ -298,6 +298,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/http_proxy.go b/http_proxy.go index 8385eb8d..a6b1956b 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -18,7 +18,6 @@ import ( "net/url" "regexp" "sync" - "sync/atomic" "time" "github.com/saucelabs/forwarder/httplog" @@ -121,12 +120,13 @@ type HTTPProxy struct { log log.Logger proxy *martian.Proxy proxyFunc ProxyFunc - addr atomic.Pointer[string] + listener net.Listener TLSConfig *tls.Config - Listener net.Listener } +// NewHTTPProxy creates a new HTTP proxy. +// It is the caller's responsibility to call Close on the returned server. func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher, rt http.RoundTripper, log log.Logger) (*HTTPProxy, error) { if err := cfg.Validate(); err != nil { return nil, err @@ -161,6 +161,12 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher, return nil, err } + l, err := hp.listen() + if err != nil { + return nil, err + } + hp.listener = l + return hp, nil } @@ -473,15 +479,7 @@ func (hp *HTTPProxy) Handler() http.Handler { } func (hp *HTTPProxy) Run(ctx context.Context) error { - listener, err := hp.listener() - if err != nil { - return err - } - defer listener.Close() - - addr := listener.Addr().String() - hp.addr.Store(&addr) - hp.log.Infof("server listen address=%s protocol=%s", addr, hp.config.Protocol) + var srv *http.Server var wg sync.WaitGroup wg.Add(1) @@ -490,22 +488,27 @@ func (hp *HTTPProxy) Run(ctx context.Context) error { defer wg.Done() <-ctx.Done() - hp.proxy.Close() - listener.Close() + if srv != nil { + if err := srv.Shutdown(context.Background()); err != nil { + hp.log.Errorf("failed to shutdown server error=%s", err) + } + } else { + hp.Close() + } }() var srvErr error if hp.config.TestingHTTPHandler { hp.log.Infof("using http handler") - s := http.Server{ + srv = &http.Server{ Handler: hp.Handler(), ReadTimeout: hp.config.ReadTimeout, ReadHeaderTimeout: hp.config.ReadHeaderTimeout, WriteTimeout: hp.config.WriteTimeout, } - srvErr = s.Serve(listener) + srvErr = srv.Serve(hp.listener) } else { - srvErr = hp.proxy.Serve(listener) + srvErr = hp.proxy.Serve(hp.listener) } if srvErr != nil { if errors.Is(srvErr, net.ErrClosed) { @@ -518,11 +521,7 @@ func (hp *HTTPProxy) Run(ctx context.Context) error { return nil } -func (hp *HTTPProxy) listener() (net.Listener, error) { - if hp.Listener != nil { - return hp.Listener, nil - } - +func (hp *HTTPProxy) listen() (net.Listener, error) { listener, err := net.Listen("tcp", hp.config.Addr) if err != nil { return nil, fmt.Errorf("failed to open listener on address %s: %w", hp.config.Addr, err) @@ -539,16 +538,15 @@ func (hp *HTTPProxy) listener() (net.Listener, error) { } } -// Addr returns the address the server is listening on or an empty string if the server is not running. +// Addr returns the address the server is listening on. func (hp *HTTPProxy) Addr() string { - addr := hp.addr.Load() - if addr == nil { - return "" - } - return *addr + return hp.listener.Addr().String() } -// Ready returns true if the server is running and ready to accept requests. -func (hp *HTTPProxy) Ready(_ context.Context) bool { - return hp.Addr() != "" +func (hp *HTTPProxy) Close() error { + err := hp.listener.Close() + if !hp.proxy.Closing() { + hp.proxy.Close() + } + return err } diff --git a/http_proxy_test.go b/http_proxy_test.go index 62028284..ce958711 100644 --- a/http_proxy_test.go +++ b/http_proxy_test.go @@ -29,6 +29,7 @@ func TestAbortIf(t *testing.T) { if err != nil { t.Fatal(err) } + defer p.Close() check := func(t *testing.T, rt http.RoundTripper) { t.Helper() @@ -93,9 +94,7 @@ func TestNopDialer(t *testing.T) { if err != nil { t.Fatal(err) } - if err := p.configureProxy(); err != nil { - t.Fatal(err) - } + defer p.Close() req := &http.Request{ Method: http.MethodGet, diff --git a/http_server.go b/http_server.go index e72a76d3..9dbd49cc 100644 --- a/http_server.go +++ b/http_server.go @@ -15,13 +15,13 @@ import ( "net/http" "net/url" "sync" - "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus" "github.com/saucelabs/forwarder/httplog" "github.com/saucelabs/forwarder/log" "github.com/saucelabs/forwarder/middleware" + "go.uber.org/multierr" ) type Scheme string @@ -105,14 +105,14 @@ func (c *HTTPServerConfig) Validate() error { } type HTTPServer struct { - config HTTPServerConfig - log log.Logger - srv *http.Server - addr atomic.Pointer[string] - - Listener net.Listener + config HTTPServerConfig + log log.Logger + srv *http.Server + listener net.Listener } +// NewHTTPServer creates a new HTTP server. +// It is the caller's responsibility to call Close on the returned server. func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTPServer, error) { if err := cfg.Validate(); err != nil { return nil, err @@ -143,6 +143,14 @@ func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTP } } + l, err := hs.listen() + if err != nil { + return nil, err + } + hs.listener = l + + hs.log.Infof("HTTP server listen address=%s protocol=%s", l.Addr(), hs.config.Protocol) + return hs, nil } @@ -190,16 +198,6 @@ func (hs *HTTPServer) configureHTTP2() error { } func (hs *HTTPServer) Run(ctx context.Context) error { - listener, err := hs.listener() - if err != nil { - return err - } - defer listener.Close() - - addr := listener.Addr().String() - hs.addr.Store(&addr) - hs.log.Infof("HTTP server listen address=%s protocol=%s", addr, hs.config.Protocol) - var wg sync.WaitGroup wg.Add(1) @@ -215,9 +213,9 @@ func (hs *HTTPServer) Run(ctx context.Context) error { var srvErr error switch hs.config.Protocol { case HTTPScheme: - srvErr = hs.srv.Serve(listener) + srvErr = hs.srv.Serve(hs.listener) case HTTP2Scheme, HTTPSScheme: - srvErr = hs.srv.ServeTLS(listener, "", "") + srvErr = hs.srv.ServeTLS(hs.listener, "", "") default: return fmt.Errorf("invalid protocol %q", hs.config.Protocol) } @@ -233,11 +231,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error { return nil } -func (hs *HTTPServer) listener() (net.Listener, error) { - if hs.Listener != nil { - return hs.Listener, nil - } - +func (hs *HTTPServer) listen() (net.Listener, error) { switch hs.config.Protocol { case HTTPScheme, HTTPSScheme, HTTP2Scheme: listener, err := net.Listen("tcp", hs.srv.Addr) @@ -250,16 +244,11 @@ func (hs *HTTPServer) listener() (net.Listener, error) { } } -// Addr returns the address the server is listening on or an empty string if the server is not running. +// Addr returns the address the server is listening on. func (hs *HTTPServer) Addr() string { - addr := hs.addr.Load() - if addr == nil { - return "" - } - return *addr + return hs.listener.Addr().String() } -// Ready returns true if the server is running and ready to accept requests. -func (hs *HTTPServer) Ready(_ context.Context) bool { - return hs.Addr() != "" +func (hs *HTTPServer) Close() error { + return multierr.Combine(hs.listener.Close(), hs.srv.Close()) }