diff --git a/layer4/listener.go b/layer4/listener.go index 70c2733..b7d532b 100644 --- a/layer4/listener.go +++ b/layer4/listener.go @@ -7,6 +7,7 @@ import ( "net" "runtime" "sync" + "sync/atomic" "time" "github.com/caddyserver/caddy/v2" @@ -65,6 +66,7 @@ func (lw *ListenerWrapper) WrapListener(l net.Listener) net.Listener { Listener: l, logger: lw.logger, compiledRoute: lw.compiledRoute, + done: make(chan struct{}), connChan: connChan, wg: new(sync.WaitGroup), } @@ -114,25 +116,30 @@ type listener struct { logger *zap.Logger compiledRoute Handler + closed atomic.Bool + done chan struct{} // closed when there is a non-recoverable error and all handle goroutines are done connChan chan net.Conn - err error // count running handles wg *sync.WaitGroup } +func (l *listener) Close() error { + l.closed.Store(true) + return l.Listener.Close() +} + // loop accept connection from underlying listener and pipe the connection if there are any func (l *listener) loop() { for { conn, err := l.Listener.Accept() var nerr net.Error - if errors.As(err, &nerr) && nerr.Temporary() { + if errors.As(err, &nerr) && nerr.Temporary() && !l.closed.Load() { l.logger.Error("temporary error accepting connection", zap.Error(err)) continue } if err != nil { - l.err = err break } @@ -145,6 +152,7 @@ func (l *listener) loop() { l.wg.Wait() close(l.connChan) }() + close(l.done) for conn := range l.connChan { _ = conn.Close() } @@ -185,11 +193,15 @@ func (l *listener) handle(conn net.Conn) { } func (l *listener) Accept() (net.Conn, error) { - for conn := range l.connChan { - return conn, nil + select { + case conn, ok := <-l.connChan: + if ok { + return conn, nil + } + return nil, net.ErrClosed + case <-l.done: + return nil, net.ErrClosed } - return nil, l.err - } func (l *listener) pipeConnection(conn *Connection) error {