diff --git a/plugins/common/socket/socket_test.go b/plugins/common/socket/socket_test.go
index 9a74a18050677..feefd0b590630 100644
--- a/plugins/common/socket/socket_test.go
+++ b/plugins/common/socket/socket_test.go
@@ -8,6 +8,7 @@ import (
 	"os"
 	"runtime"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -477,14 +478,6 @@ func TestClosingConnections(t *testing.T) {
 		return acc.NMetrics() >= 1
 	}, time.Second, 100*time.Millisecond, "did not receive metric")
 
-	// This has to be a stream-listener...
-	listener, ok := sock.listener.(*streamListener)
-	require.True(t, ok)
-	listener.Lock()
-	conns := len(listener.connections)
-	listener.Unlock()
-	require.NotZero(t, conns)
-
 	sock.Close()
 
 	// Verify that plugin.Stop() closed the client's connection
@@ -605,6 +598,64 @@ func TestNoSplitter(t *testing.T) {
 	testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics())
 }
 
+func TestMemoryLeak(t *testing.T) {
+	cfg := &Config{
+		ServerConfig: *pki.TLSServerConfig(),
+	}
+
+	sock, err := cfg.NewSocket("tcp://127.0.0.1:0", nil, &testutil.Logger{})
+	require.NoError(t, err)
+	require.NoError(t, sock.Setup())
+	sock.ListenConnection(func(_ net.Addr, r io.ReadCloser) {
+		_, _ = io.Copy(io.Discard, r)
+	}, func(_ error) {})
+	defer sock.Close()
+
+	clientTLS := pki.TLSClientConfig()
+	tlsConfig, err := clientTLS.TLSConfig()
+	require.NoError(t, err)
+	msg := []byte("test v=1i")
+	client := func() {
+		conn, err := tls.Dial("tcp", sock.Address().String(), tlsConfig)
+		require.NoError(t, err)
+		_, err = conn.Write(msg)
+		require.NoError(t, err)
+		require.NoError(t, conn.Close())
+	}
+
+	var memStart, mem runtime.MemStats
+
+	// Spread the connections across GOMAXPROCS worker goroutines.
+	run := func(nClients int) {
+		nWorkers := runtime.GOMAXPROCS(0)
+		var wg sync.WaitGroup
+		for i := 0; i < nWorkers; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				for j := 0; j < nClients/nWorkers; j++ {
+					client()
+				}
+			}()
+		}
+		wg.Wait()
+	}
+
+	// Warm up so one-time allocations do not skew the baseline.
+	run(100)
+	// Collect twice so objects released via finalizers are reclaimed too.
+	runtime.GC()
+	runtime.GC()
+	runtime.ReadMemStats(&memStart)
+
+	n := 5000
+	run(n)
+	runtime.GC()
+	runtime.GC()
+	runtime.ReadMemStats(&mem)
+
+	// Some fluctuation is unavoidable, but a genuine leak will retain at
+	// least one object per connection, so use half the connection count
+	// as the threshold.
+	require.Less(t, mem.HeapObjects, memStart.HeapObjects+uint64(n/2))
+}
+
 func createClient(endpoint string, addr net.Addr, tlsCfg *tls.Config) (net.Conn, error) {
 	// Determine the protocol in a crude fashion
 	parts := strings.SplitN(endpoint, "://", 2)
diff --git a/plugins/common/socket/stream.go b/plugins/common/socket/stream.go
index 4a5c51ccce4b8..96150c9f5ec39 100644
--- a/plugins/common/socket/stream.go
+++ b/plugins/common/socket/stream.go
@@ -2,6 +2,7 @@ package socket
 
 import (
 	"bufio"
+	"context"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -15,6 +16,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -38,9 +40,9 @@ type streamListener struct {
 	Splitter        bufio.SplitFunc
 	Log             telegraf.Logger
 
-	listener    net.Listener
-	connections map[net.Conn]bool
-	path        string
+	cancel   func()
+	listener net.Listener
+	path     string
 
 	wg sync.WaitGroup
 	sync.Mutex
@@ -126,17 +128,6 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
 		conn = c.NetConn()
 	}
 
-	addr := conn.RemoteAddr().String()
-	l.Lock()
-	if l.MaxConnections > 0 && len(l.connections) >= l.MaxConnections {
-		l.Unlock()
-		// Ignore the returned error as we cannot do anything about it anyway
-		_ = conn.Close()
-		return fmt.Errorf("unable to accept connection from %q: too many connections", addr)
-	}
-	l.connections[conn] = true
-	l.Unlock()
-
 	if l.ReadBufferSize > 0 {
 		if rb, ok := conn.(hasSetReadBuffer); ok {
 			if err := rb.SetReadBuffer(l.ReadBufferSize); err != nil {
@@ -171,14 +162,6 @@ func (l *streamListener) setupConnection(conn net.Conn) error {
 	return nil
 }
 
-func (l *streamListener) closeConnection(conn net.Conn) {
-	addr := conn.RemoteAddr().String()
-	if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, syscall.EPIPE) {
-		l.Log.Warnf("Cannot close connection to %q: %v", addr, err)
-	}
-	delete(l.connections, conn)
-}
-
 func (l *streamListener) address() net.Addr {
 	return l.listener.Addr()
 }
@@ -188,11 +171,7 @@ func (l *streamListener) close() error {
 		return err
 	}
 
-	l.Lock()
-	for conn := range l.connections {
-		l.closeConnection(conn)
-	}
-	l.Unlock()
+	l.cancel()
 	l.wg.Wait()
 
 	if l.path != "" {
@@ -208,12 +187,16 @@ func (l *streamListener) close() error {
 	return nil
 }
 
-func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
-	l.connections = make(map[net.Conn]bool)
+func (l *streamListener) listen(connFunc func(c net.Conn), onError CallbackError) {
+	// Derive a context whose cancellation (triggered by close()) tears down
+	// the listener and every active connection.
+	var ctx context.Context
+	ctx, l.cancel = context.WithCancel(context.Background())
 
 	l.wg.Add(1)
 	go func() {
 		defer l.wg.Done()
 
+		// Closing the listener unblocks Accept() once the context is canceled.
+		defer context.AfterFunc(ctx, func() { _ = l.listener.Close() })()
+
+		var connCount int32
 		var wg sync.WaitGroup
 		for {
@@ -225,76 +208,59 @@ func (l *streamListener) listenData(onData CallbackData, onError CallbackError)
 				break
 			}
 
-			if err := l.setupConnection(conn); err != nil && onError != nil {
-				onError(err)
+			if l.MaxConnections > 0 && int(atomic.LoadInt32(&connCount)) >= l.MaxConnections {
+				if onError != nil {
+					onError(fmt.Errorf("unable to accept connection from %q: too many connections", conn.RemoteAddr().String()))
+				}
+				// Ignore the returned error as we cannot do anything about it anyway
+				_ = conn.Close()
 				continue
 			}
 
+			atomic.AddInt32(&connCount, 1)
+
 			wg.Add(1)
 			go func(c net.Conn) {
 				defer wg.Done()
-				defer func() {
-					l.Lock()
-					l.closeConnection(conn)
-					l.Unlock()
-				}()
-
-				reader := l.read
-				if l.Splitter == nil {
-					reader = l.readAll
-				}
-				if err := reader(c, onData); err != nil {
-					if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
-						if onError != nil {
-							onError(err)
-						}
-					}
+				defer func() { _ = c.Close() }()
+				// Cancellation closes this connection to unblock a pending
+				// read; the deferred stop detaches the hook when the
+				// handler exits on its own.
+				defer context.AfterFunc(ctx, func() { _ = c.Close() })()
+				defer atomic.AddInt32(&connCount, -1)
+
+				if err := l.setupConnection(c); err != nil && onError != nil {
+					onError(err)
+					return
+				}
+
+				connFunc(c)
 			}(conn)
 		}
 
 		wg.Wait()
 	}()
 }
 
-func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
-	l.connections = make(map[net.Conn]bool)
-
-	l.wg.Add(1)
-	go func() {
-		defer l.wg.Done()
-
-		var wg sync.WaitGroup
-		for {
-			conn, err := l.listener.Accept()
-			if err != nil {
-				if !errors.Is(err, net.ErrClosed) && onError != nil {
+func (l *streamListener) listenData(onData CallbackData, onError CallbackError) {
+	l.listen(func(c net.Conn) {
+		reader := l.read
+		if l.Splitter == nil {
+			reader = l.readAll
+		}
+		if err := reader(c, onData); err != nil {
+			if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
+				if onError != nil {
 					onError(err)
 				}
-				break
-			}
-
-			if err := l.setupConnection(conn); err != nil && onError != nil {
-				onError(err)
-				continue
 			}
+		}
+	}, onError)
+}
 
-			wg.Add(1)
-			go func(c net.Conn) {
-				defer wg.Done()
-				if err := l.handleConnection(c, onConnection); err != nil {
-					if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
-						if onError != nil {
-							onError(err)
-						}
-					}
+func (l *streamListener) listenConnection(onConnection CallbackConnection, onError CallbackError) {
+	l.listen(func(c net.Conn) {
+		if err := l.handleConnection(c, onConnection); err != nil {
+			if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) {
+				if onError != nil {
+					onError(err)
 				}
-				l.Lock()
-				l.closeConnection(conn)
-				l.Unlock()
-			}(conn)
+			}
 		}
-		wg.Wait()
-	}()
+	}, onError)
 }
 
 func (l *streamListener) read(conn net.Conn, onData CallbackData) error {
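Note for reviewers: `context.AfterFunc` (Go 1.21+) is what replaces the `connections` map in this patch, so the idiom is worth spelling out. Below is a minimal, self-contained sketch (a standalone program; none of the names come from this plugin) showing how an AfterFunc hook that closes a connection unblocks a pending read, and why the deferred stop call matters.

package main

import (
	"context"
	"fmt"
	"io"
	"net"
	"time"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())

	// Client that connects but never writes, so the server-side read blocks.
	go func() {
		conn, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			return
		}
		defer conn.Close()
		time.Sleep(5 * time.Second)
	}()

	conn, err := ln.Accept()
	if err != nil {
		panic(err)
	}

	// Close the connection as soon as ctx is canceled. The returned stop
	// function detaches the hook if the handler finishes first; this is the
	// `defer context.AfterFunc(ctx, ...)()` idiom used in listen().
	stop := context.AfterFunc(ctx, func() { _ = conn.Close() })
	defer stop()

	// Cancel shortly after; without the hook this read would block for 5s.
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()

	_, err = io.ReadAll(conn)
	fmt.Println("read returned:", err) // "use of closed network connection"
}

Without the deferred stop, every finished handler would leave its hook registered until the listener context is canceled, which is exactly the kind of per-connection growth TestMemoryLeak is designed to catch.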