Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

On shutdown try to drain connections for 30s and close them #967

Merged
merged 10 commits into from
Nov 28, 2024
5 changes: 5 additions & 0 deletions bind/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ func HTTPServerConfig(fs *pflag.FlagSet, cfg *forwarder.HTTPServerConfig, prefix
namePrefix+"read-header-timeout", cfg.ReadHeaderTimeout,
"The amount of time allowed to read request headers.")

fs.DurationVar(&cfg.ShutdownTimeout,
namePrefix+"shutdown-timeout", cfg.ShutdownTimeout,
"The maximum amount of time to wait for the server to drain connections before closing. "+
"Zero means no limit. ")

fs.VarP(anyflag.NewValueWithRedact[*url.Userinfo](cfg.BasicAuth, &cfg.BasicAuth, forwarder.ParseUserinfo, RedactUserinfo),
namePrefix+"basic-auth", "", "<username[:password]>"+
"Basic authentication credentials to protect the server. ")
Expand Down
51 changes: 37 additions & 14 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func DefaultHTTPProxyConfig() *HTTPProxyConfig {
Protocol: HTTPScheme,
IdleTimeout: 1 * time.Hour,
ReadHeaderTimeout: 1 * time.Minute,
shutdownConfig: defaultShutdownConfig(),
TLSServerConfig: TLSServerConfig{
HandshakeTimeout: 10 * time.Second,
},
Expand Down Expand Up @@ -534,10 +535,20 @@ func (hp *HTTPProxy) runHTTPHandler(ctx context.Context) error {
var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
ctxErr := ctx.Err()

var cancel context.CancelFunc
ctx, cancel = shutdownContext(hp.config.shutdownConfig)
defer cancel()

if err := srv.Shutdown(ctx); err != nil {
hp.log.Debugf("failed to gracefully shutdown server error=%s", err)
if err := srv.Close(); err != nil {
hp.log.Debugf("failed to close server error=%s", err)
}
}
return ctx.Err()

return ctxErr
})
for i := range hp.listeners {
l := hp.listeners[i]
Expand All @@ -556,8 +567,29 @@ func (hp *HTTPProxy) run(ctx context.Context) error {
var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
hp.Close()
return ctx.Err()
ctxErr := ctx.Err()

// Close listeners first to prevent new connections.
if err := hp.Close(); err != nil {
hp.log.Debugf("failed to close listeners error=%s", err)
}

var cancel context.CancelFunc
ctx, cancel = shutdownContext(hp.config.shutdownConfig)
defer cancel()

if err := hp.proxy.Shutdown(ctx); err != nil {
hp.log.Debugf("failed to gracefully shutdown server error=%s", err)
if err := hp.proxy.Close(); err != nil {
hp.log.Debugf("failed to close server error=%s", err)
}
}

if tr, ok := hp.transport.(*http.Transport); ok {
tr.CloseIdleConnections()
}

return ctxErr
})
for i := range hp.listeners {
l := hp.listeners[i]
Expand Down Expand Up @@ -617,20 +649,11 @@ func (hp *HTTPProxy) Addr() (addrs []string, ok bool) {
}

func (hp *HTTPProxy) Close() error {
// Close listeners first to prevent new connections.
var err error
for _, l := range hp.listeners {
if e := l.Close(); e != nil {
err = multierr.Append(err, e)
}
}

// Close the proxy to stop serving requests.
hp.proxy.Close()

if tr, ok := hp.transport.(*http.Transport); ok {
tr.CloseIdleConnections()
}

return err
}
21 changes: 15 additions & 6 deletions http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/saucelabs/forwarder/httplog"
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/middleware"
"go.uber.org/multierr"
)

type Scheme string
Expand Down Expand Up @@ -80,8 +79,9 @@ type HTTPServerConfig struct {
ReadTimeout time.Duration
ReadHeaderTimeout time.Duration
WriteTimeout time.Duration
LogHTTPMode httplog.Mode
BasicAuth *url.Userinfo
shutdownConfig
LogHTTPMode httplog.Mode
BasicAuth *url.Userinfo
PromConfig
}

Expand All @@ -91,6 +91,7 @@ func DefaultHTTPServerConfig() *HTTPServerConfig {
Protocol: HTTPScheme,
IdleTimeout: 1 * time.Hour,
ReadHeaderTimeout: 1 * time.Minute,
shutdownConfig: defaultShutdownConfig(),
}
}

Expand Down Expand Up @@ -202,8 +203,16 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
defer wg.Done()

<-ctx.Done()
if err := hs.srv.Shutdown(context.Background()); err != nil {
hs.log.Errorf("failed to shutdown server error=%s", err)

var cancel context.CancelFunc
ctx, cancel = shutdownContext(hs.config.shutdownConfig)
defer cancel()

if err := hs.srv.Shutdown(ctx); err != nil {
hs.log.Debugf("failed to gracefully shutdown server error=%s", err)
if err := hs.srv.Close(); err != nil {
hs.log.Debugf("failed to close server error=%s", err)
}
}
}()

Expand Down Expand Up @@ -250,5 +259,5 @@ func (hs *HTTPServer) Addr() string {
}

func (hs *HTTPServer) Close() error {
return multierr.Combine(hs.listener.Close(), hs.srv.Close())
return hs.listener.Close()
}
3 changes: 2 additions & 1 deletion internal/martian/h2/testing/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package testing

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand Down Expand Up @@ -137,7 +138,7 @@ func New(spf []h2.StreamProcessorFactory) (*Fixture, error) {
func (f *Fixture) Close() error {
f.conn.Close()
f.server.Stop()
f.proxy.Close()
f.proxy.Shutdown(context.Background())
f.wg.Wait()

if err := f.proxyListener.Close(); err != nil {
Expand Down
94 changes: 77 additions & 17 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@ import (
"crypto/tls"
"errors"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

"github.com/saucelabs/forwarder/internal/martian/log"
"github.com/saucelabs/forwarder/internal/martian/mitm"
"github.com/saucelabs/forwarder/internal/martian/proxyutil"
"go.uber.org/multierr"
"golang.org/x/net/http/httpguts"
)

Expand Down Expand Up @@ -114,17 +117,18 @@ type Proxy struct {
// A zero or negative value means there will be no timeout.
WriteTimeout time.Duration

// BaseContex is the base context for all requests.
BaseContex context.Context //nolint:containedctx // It's intended to be used as a base context.
// BaseContext is the base context for all requests.
BaseContext context.Context //nolint:containedctx // It's intended to be used as a base context.

// TestingSkipRoundTrip skips the round trip for requests and returns a 200 OK response.
TestingSkipRoundTrip bool

initOnce sync.Once

rt http.RoundTripper
conns sync.WaitGroup
connsMu sync.Mutex // protects conns.Add/Wait from concurrent access
conns map[net.Conn]struct{}
connsWg atomic.Int32
connsMu sync.Mutex // protects connsWg.Add/Wait and conns from concurrent access
closeCh chan bool
closeOnce sync.Once
}
Expand Down Expand Up @@ -168,31 +172,80 @@ func (p *Proxy) init() {
}).DialContext
}

if p.BaseContex == nil {
p.BaseContex = context.Background()
if p.BaseContext == nil {
p.BaseContext = context.Background()
}

p.conns = make(map[net.Conn]struct{})
p.connsWg.Store(0)
p.closeCh = make(chan bool)
})
}

// Close sets the proxy to the closing state so it stops receiving new connections,
// Shutdown sets the proxy to the closing state so it stops receiving new connections,
// finishes processing any inflight requests, and closes existing connections without
// reading any more requests from them.
func (p *Proxy) Close() {
func (p *Proxy) Shutdown(ctx context.Context) error {
p.init()

p.closeOnce.Do(func() {
log.Infof(context.TODO(), "closing down proxy")
log.Infof(context.TODO(), "shutting down proxy, draining connections")

p.connsMu.Lock()
defer p.connsMu.Unlock()

p.closeOnce.Do(func() {
close(p.closeCh)
})

log.Infof(context.TODO(), "waiting for connections to close")
p.connsMu.Lock()
p.conns.Wait()
p.connsMu.Unlock()
log.Infof(context.TODO(), "all connections closed")
const shutdownPollIntervalMax = 500 * time.Millisecond

pollIntervalBase := time.Millisecond
nextPollInterval := func() time.Duration {
// Add 10% jitter.
interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
// Double and clamp for next time.
pollIntervalBase *= 2
if pollIntervalBase > shutdownPollIntervalMax {
pollIntervalBase = shutdownPollIntervalMax
}
return interval
}

timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
if n := p.connsWg.Load(); n == 0 {
log.Infof(context.TODO(), "all connections closed")
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
timer.Reset(nextPollInterval())
}
}
}

// Close puts the proxy into the closing state and closes every connection
// currently tracked in p.conns.
//
// The closing signal (closing p.closeCh) is issued at most once via
// p.closeOnce, so Close can be called after Shutdown or repeatedly.
// p.connsMu is held for the whole call so the tracked connection set
// cannot change while it is being iterated.
//
// It returns the combined errors from closing individual connections, or
// nil when every close succeeded (or there was nothing to close).
func (p *Proxy) Close() error {
	p.init()

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	// Signal the closing state exactly once, even if Shutdown already did.
	p.closeOnce.Do(func() { close(p.closeCh) })

	var errs error
	for c := range p.conns {
		closeErr := c.Close()
		if closeErr == nil {
			continue
		}
		errs = multierr.Append(errs, closeErr)
	}

	return errs
}

// closing returns whether the proxy is in the closing state.
Expand Down Expand Up @@ -254,9 +307,16 @@ func (p *Proxy) handleLoop(conn net.Conn) {
start := time.Now()

p.connsMu.Lock()
p.conns.Add(1)
p.conns[conn] = struct{}{}
p.connsWg.Add(1)
p.connsMu.Unlock()
defer p.conns.Done()

defer func() {
p.connsMu.Lock()
delete(p.conns, conn)
p.connsMu.Unlock()
}()
defer p.connsWg.Add(-1)
defer conn.Close()
if p.closing() {
return
Expand Down
2 changes: 1 addition & 1 deletion internal/martian/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (p *proxyConn) readRequest() (*http.Request, error) {
if p.secure {
req.TLS = &p.cs
}
req = req.WithContext(withTraceID(p.BaseContex, newTraceID(req.Header.Get(p.RequestIDHeader))))
req = req.WithContext(withTraceID(p.BaseContext, newTraceID(req.Header.Get(p.RequestIDHeader))))

// Adjust the read deadline if necessary.
if !hdrDeadline.Equal(wholeReqDeadline) {
Expand Down
2 changes: 1 addition & 1 deletion internal/martian/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (p *Proxy) Handler() http.Handler {
}

func (p proxyHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
outreq := req.Clone(withTraceID(p.BaseContex, newTraceID(req.Header.Get(p.RequestIDHeader))))
outreq := req.Clone(withTraceID(p.BaseContext, newTraceID(req.Header.Get(p.RequestIDHeader))))
if req.ContentLength == 0 {
outreq.Body = http.NoBody
}
Expand Down
Loading