Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eager listener init #405

Merged
merged 4 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (h *APIHandler) healthz(w http.ResponseWriter, _ *http.Request) {
}

func (h *APIHandler) readyz(w http.ResponseWriter, r *http.Request) {
if h.ready(r.Context()) {
if h.ready == nil || h.ready(r.Context()) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte("OK"))
Expand Down
4 changes: 3 additions & 1 deletion cmd/forwarder/httpbin/httpbin.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
defer s.Close()

r := prometheus.NewRegistry()
a, err := forwarder.NewHTTPServer(c.apiServerConfig, forwarder.NewAPIHandler(r, s.Ready, config, ""), logger.Named("api"))
a, err := forwarder.NewHTTPServer(c.apiServerConfig, forwarder.NewAPIHandler(r, nil, config, ""), logger.Named("api"))
if err != nil {
return err
}
defer a.Close()

return runctx.NewGroup(s.Run, a.Run).Run()
}
Expand Down
1 change: 1 addition & 0 deletions cmd/forwarder/pac/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
defer s.Close()

return runctx.NewGroup(s.Run).Run()
}
Expand Down
14 changes: 9 additions & 5 deletions cmd/forwarder/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,22 @@ func (c *command) runE(cmd *cobra.Command, _ []string) error {
}

var g runctx.Group
p, err := forwarder.NewHTTPProxy(c.httpProxyConfig, pr, cm, rt, logger.Named("proxy"))
if err != nil {
return err
{
p, err := forwarder.NewHTTPProxy(c.httpProxyConfig, pr, cm, rt, logger.Named("proxy"))
if err != nil {
return err
}
defer p.Close()
g.Add(p.Run)
}
g.Add(p.Run)

if c.apiServerConfig.Addr != "" {
h := forwarder.NewAPIHandler(c.promReg, p.Ready, config, script)
h := forwarder.NewAPIHandler(c.promReg, nil, config, script)
a, err := forwarder.NewHTTPServer(c.apiServerConfig, h, logger.Named("api"))
if err != nil {
return err
}
defer a.Close()
g.Add(a.Run)
}

Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ require (
github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0
github.com/prometheus/common v0.37.0
github.com/spf13/cast v1.4.1
github.com/spf13/cobra v1.6.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.10.0
go.uber.org/goleak v1.2.0
go.uber.org/multierr v1.11.0
golang.org/x/exp v0.0.0-20230314191032-db074128a8ec
golang.org/x/net v0.7.0
golang.org/x/sync v0.1.0
Expand All @@ -40,7 +42,6 @@ require (
github.com/prometheus/procfs v0.8.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
golang.org/x/sys v0.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
Expand Down
60 changes: 29 additions & 31 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"net/url"
"regexp"
"sync"
"sync/atomic"
"time"

"github.com/saucelabs/forwarder/httplog"
Expand Down Expand Up @@ -121,12 +120,13 @@ type HTTPProxy struct {
log log.Logger
proxy *martian.Proxy
proxyFunc ProxyFunc
addr atomic.Pointer[string]
listener net.Listener

TLSConfig *tls.Config
Listener net.Listener
}

// NewHTTPProxy creates a new HTTP proxy.
// It is the caller's responsibility to call Close on the returned server.
func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher, rt http.RoundTripper, log log.Logger) (*HTTPProxy, error) {
if err := cfg.Validate(); err != nil {
return nil, err
Expand Down Expand Up @@ -161,6 +161,12 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher,
return nil, err
}

l, err := hp.listen()
if err != nil {
return nil, err
}
hp.listener = l

return hp, nil
}

Expand Down Expand Up @@ -473,15 +479,7 @@ func (hp *HTTPProxy) Handler() http.Handler {
}

func (hp *HTTPProxy) Run(ctx context.Context) error {
listener, err := hp.listener()
if err != nil {
return err
}
defer listener.Close()

addr := listener.Addr().String()
hp.addr.Store(&addr)
hp.log.Infof("server listen address=%s protocol=%s", addr, hp.config.Protocol)
var srv *http.Server

var wg sync.WaitGroup
wg.Add(1)
Expand All @@ -490,22 +488,27 @@ func (hp *HTTPProxy) Run(ctx context.Context) error {
defer wg.Done()

<-ctx.Done()
hp.proxy.Close()
listener.Close()
if srv != nil {
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
}
} else {
hp.Close()
}
}()

var srvErr error
if hp.config.TestingHTTPHandler {
hp.log.Infof("using http handler")
s := http.Server{
srv = &http.Server{
Handler: hp.Handler(),
ReadTimeout: hp.config.ReadTimeout,
ReadHeaderTimeout: hp.config.ReadHeaderTimeout,
WriteTimeout: hp.config.WriteTimeout,
}
srvErr = s.Serve(listener)
srvErr = srv.Serve(hp.listener)
} else {
srvErr = hp.proxy.Serve(listener)
srvErr = hp.proxy.Serve(hp.listener)
}
if srvErr != nil {
if errors.Is(srvErr, net.ErrClosed) {
Expand All @@ -518,11 +521,7 @@ func (hp *HTTPProxy) Run(ctx context.Context) error {
return nil
}

func (hp *HTTPProxy) listener() (net.Listener, error) {
if hp.Listener != nil {
return hp.Listener, nil
}

func (hp *HTTPProxy) listen() (net.Listener, error) {
listener, err := net.Listen("tcp", hp.config.Addr)
if err != nil {
return nil, fmt.Errorf("failed to open listener on address %s: %w", hp.config.Addr, err)
Expand All @@ -539,16 +538,15 @@ func (hp *HTTPProxy) listener() (net.Listener, error) {
}
}

// Addr returns the address the server is listening on or an empty string if the server is not running.
// Addr returns the address the server is listening on.
func (hp *HTTPProxy) Addr() string {
addr := hp.addr.Load()
if addr == nil {
return ""
}
return *addr
return hp.listener.Addr().String()
}

// Ready returns true if the server is running and ready to accept requests.
func (hp *HTTPProxy) Ready(_ context.Context) bool {
return hp.Addr() != ""
func (hp *HTTPProxy) Close() error {
err := hp.listener.Close()
if !hp.proxy.Closing() {
hp.proxy.Close()
}
return err
}
5 changes: 2 additions & 3 deletions http_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestAbortIf(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer p.Close()

check := func(t *testing.T, rt http.RoundTripper) {
t.Helper()
Expand Down Expand Up @@ -93,9 +94,7 @@ func TestNopDialer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := p.configureProxy(); err != nil {
t.Fatal(err)
}
defer p.Close()

req := &http.Request{
Method: http.MethodGet,
Expand Down
55 changes: 22 additions & 33 deletions http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/saucelabs/forwarder/httplog"
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/middleware"
"go.uber.org/multierr"
)

type Scheme string
Expand Down Expand Up @@ -105,14 +105,14 @@ func (c *HTTPServerConfig) Validate() error {
}

type HTTPServer struct {
config HTTPServerConfig
log log.Logger
srv *http.Server
addr atomic.Pointer[string]

Listener net.Listener
config HTTPServerConfig
log log.Logger
srv *http.Server
listener net.Listener
}

// NewHTTPServer creates a new HTTP server.
// It is the caller's responsibility to call Close on the returned server.
func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTPServer, error) {
if err := cfg.Validate(); err != nil {
return nil, err
Expand Down Expand Up @@ -143,6 +143,14 @@ func NewHTTPServer(cfg *HTTPServerConfig, h http.Handler, log log.Logger) (*HTTP
}
}

l, err := hs.listen()
if err != nil {
return nil, err
}
hs.listener = l

hs.log.Infof("HTTP server listen address=%s protocol=%s", l.Addr(), hs.config.Protocol)

return hs, nil
}

Expand Down Expand Up @@ -190,16 +198,6 @@ func (hs *HTTPServer) configureHTTP2() error {
}

func (hs *HTTPServer) Run(ctx context.Context) error {
listener, err := hs.listener()
if err != nil {
return err
}
defer listener.Close()

addr := listener.Addr().String()
hs.addr.Store(&addr)
hs.log.Infof("HTTP server listen address=%s protocol=%s", addr, hs.config.Protocol)

var wg sync.WaitGroup
wg.Add(1)

Expand All @@ -215,9 +213,9 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
var srvErr error
switch hs.config.Protocol {
case HTTPScheme:
srvErr = hs.srv.Serve(listener)
srvErr = hs.srv.Serve(hs.listener)
case HTTP2Scheme, HTTPSScheme:
srvErr = hs.srv.ServeTLS(listener, "", "")
srvErr = hs.srv.ServeTLS(hs.listener, "", "")
default:
return fmt.Errorf("invalid protocol %q", hs.config.Protocol)
}
Expand All @@ -233,11 +231,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
return nil
}

func (hs *HTTPServer) listener() (net.Listener, error) {
if hs.Listener != nil {
return hs.Listener, nil
}

func (hs *HTTPServer) listen() (net.Listener, error) {
switch hs.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
listener, err := net.Listen("tcp", hs.srv.Addr)
Expand All @@ -250,16 +244,11 @@ func (hs *HTTPServer) listener() (net.Listener, error) {
}
}

// Addr returns the address the server is listening on or an empty string if the server is not running.
// Addr returns the address the server is listening on.
func (hs *HTTPServer) Addr() string {
addr := hs.addr.Load()
if addr == nil {
return ""
}
return *addr
return hs.listener.Addr().String()
}

// Ready returns true if the server is running and ready to accept requests.
func (hs *HTTPServer) Ready(_ context.Context) bool {
return hs.Addr() != ""
func (hs *HTTPServer) Close() error {
return multierr.Combine(hs.listener.Close(), hs.srv.Close())
}