Skip to content

Commit

Permalink
stop the accept loop when listener is closed (#261)
Browse files Browse the repository at this point in the history
* stop the loop when listener is closed

* simplify check

* set error after listener closed

* stop queuing new connections after listener closed

* simplify close condition check
  • Loading branch information
WeidiDeng authored Nov 4, 2024
1 parent d8ba3fb commit ec8fae2
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions layer4/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -65,6 +66,7 @@ func (lw *ListenerWrapper) WrapListener(l net.Listener) net.Listener {
Listener: l,
logger: lw.logger,
compiledRoute: lw.compiledRoute,
done: make(chan struct{}),
connChan: connChan,
wg: new(sync.WaitGroup),
}
Expand Down Expand Up @@ -114,25 +116,30 @@ type listener struct {
logger *zap.Logger
compiledRoute Handler

closed atomic.Bool
done chan struct{}
// closed when there is a non-recoverable error and all handle goroutines are done
connChan chan net.Conn
err error

// count running handles
wg *sync.WaitGroup
}

func (l *listener) Close() error {
l.closed.Store(true)
return l.Listener.Close()
}

// loop accept connection from underlying listener and pipe the connection if there are any
func (l *listener) loop() {
for {
conn, err := l.Listener.Accept()
var nerr net.Error
if errors.As(err, &nerr) && nerr.Temporary() {
if errors.As(err, &nerr) && nerr.Temporary() && !l.closed.Load() {
l.logger.Error("temporary error accepting connection", zap.Error(err))
continue
}
if err != nil {
l.err = err
break
}

Expand All @@ -145,6 +152,7 @@ func (l *listener) loop() {
l.wg.Wait()
close(l.connChan)
}()
close(l.done)
for conn := range l.connChan {
_ = conn.Close()
}
Expand Down Expand Up @@ -185,11 +193,15 @@ func (l *listener) handle(conn net.Conn) {
}

func (l *listener) Accept() (net.Conn, error) {
for conn := range l.connChan {
return conn, nil
select {
case conn, ok := <-l.connChan:
if ok {
return conn, nil
}
return nil, net.ErrClosed
case <-l.done:
return nil, net.ErrClosed
}
return nil, l.err

}

func (l *listener) pipeConnection(conn *Connection) error {
Expand Down

0 comments on commit ec8fae2

Please sign in to comment.