Skip to content

Commit

Permalink
net: export and document TrackedConn fields
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatczuk committed Feb 20, 2024
1 parent 13fae05 commit a1ad34c
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,13 @@ func (l *Listener) Accept() (net.Conn, error) {
return nil, err
}

conn = NewTrackedConn(conn)
conn = &TrackedConn{
Conn: conn,
}

if l.TLSConfig == nil {
l.metrics.accept()
conn.(*TrackedConn).onClose = l.metrics.close //nolint:forcetypeassert // we know it's a TrackedConn
conn.(*TrackedConn).OnClose = l.metrics.close //nolint:forcetypeassert // we know it's a TrackedConn
return conn, nil
}

Expand All @@ -152,7 +154,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}

l.metrics.accept()
conn.(*TrackedConn).onClose = l.metrics.close //nolint:forcetypeassert // we know it's a TrackedConn
conn.(*TrackedConn).OnClose = l.metrics.close //nolint:forcetypeassert // we know it's a TrackedConn
return tconn, nil
}
}
Expand Down Expand Up @@ -186,43 +188,54 @@ func (l *Listener) Close() error {
return l.listener.Close()
}

// TrackedConn is a net.Conn that tracks the number of bytes read and written.
// It needs to be configured before first use by setting TrackTraffic and OnClose if needed.
type TrackedConn struct {
net.Conn
rx atomic.Uint64
tx atomic.Uint64
onClose func()
}

func NewTrackedConn(c net.Conn) *TrackedConn {
return &TrackedConn{
Conn: c,
}
// TrackTraffic enables counting of bytes read and written by the connection.
// Use Rx and Tx to get the number of bytes read and written.
TrackTraffic bool

// OnClose is called after the underlying connection is closed and before the Close method returns.
OnClose func()

rx atomic.Uint64
tx atomic.Uint64
}

func (c *TrackedConn) Read(p []byte) (int, error) {
n, err := c.Conn.Read(p)
c.rx.Add(uint64(n))
if c.TrackTraffic {
c.rx.Add(uint64(n))
}
return n, err
}

func (c *TrackedConn) Write(p []byte) (int, error) {
n, err := c.Conn.Write(p)
c.tx.Add(uint64(n))
if c.TrackTraffic {
c.tx.Add(uint64(n))
}
return n, err
}

// Rx returns the number of bytes read from the connection.
// It requires TrackTraffic to be set to true, otherwise it returns 0.
func (c *TrackedConn) Rx() uint64 {
return c.rx.Load()
}

// Tx returns the number of bytes written to the connection.
// It requires TrackTraffic to be set to true, otherwise it returns 0.
func (c *TrackedConn) Tx() uint64 {
return c.tx.Load()
}

func (c *TrackedConn) Close() error {
err := c.Conn.Close()
if c.onClose != nil {
c.onClose()
if c.OnClose != nil {
c.OnClose()
}
return err
}

0 comments on commit a1ad34c

Please sign in to comment.