Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Use netlink for tcpstat collector #2322

Merged
merged 1 commit into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions collector/fixtures/proc/net/tcpstat

This file was deleted.

135 changes: 87 additions & 48 deletions collector/tcpstat_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ package collector

import (
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"syscall"
"unsafe"

"github.com/go-kit/log"
"github.com/mdlayher/netlink"
"github.com/prometheus/client_golang/prometheus"
)

Expand Down Expand Up @@ -80,16 +79,64 @@ func NewTCPStatCollector(logger log.Logger) (Collector, error) {
}, nil
}

// InetDiagSockID (inet_diag_sockid) contains the socket identity.
// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L13
//
// The fields add up to 48 bytes (2+2+16+16+4+8). Structs embedding this type
// are reinterpreted as raw bytes elsewhere in this file (see Serialize and
// parseInetDiagMsg), so field order and widths must not be changed.
type InetDiagSockID struct {
// Ports are kept as raw byte pairs rather than integers.
// NOTE(review): presumably network byte order, as in the kernel header — confirm.
SourcePort [2]byte
DestPort [2]byte
// 16 bytes per address.
// NOTE(review): presumably sized for IPv6, with IPv4 using only the first 4 bytes — confirm against inet_diag.h.
SourceIP [4][4]byte
DestIP [4][4]byte
Interface uint32
Cookie [2]uint32
}

// InetDiagReqV2 (inet_diag_req_v2) is used to request diagnostic data.
// https://github.com/torvalds/linux/blob/v4.0/include/uapi/linux/inet_diag.h#L37
//
// Layout: four one-byte fields, a uint32 and the 48-byte InetDiagSockID,
// i.e. 56 bytes in total, which is the value of sizeOfDiagRequest (0x38).
// Serialize exposes this memory directly, so the layout must stay as is.
type InetDiagReqV2 struct {
// Address family; getTCPStats passes syscall.AF_INET or syscall.AF_INET6.
Family uint8
// Transport protocol; getTCPStats sets syscall.IPPROTO_TCP.
Protocol uint8
// Extension bitmask; getTCPStats sets bit (InetDiagInfo-1).
Ext uint8
// Explicit padding byte so that States starts at offset 4.
Pad uint8
// Bitmask of socket states to report; getTCPStats sets TCPFAll (0xFFF).
States uint32
// Socket identity; left zero-valued by getTCPStats.
ID InetDiagSockID
}

// sizeOfDiagRequest is the size in bytes of InetDiagReqV2:
// 4 one-byte fields + 4 (States) + 48 (InetDiagSockID) = 56 = 0x38.
// It must be kept in sync with the struct definition, because Serialize
// uses it as the length of the raw byte view.
const sizeOfDiagRequest = 0x38

// Serialize returns the request as its raw in-memory bytes, in host byte
// order and with a length of sizeOfDiagRequest.
//
// The returned slice is a view onto req itself, not a copy: it stays valid
// only as long as req does, and writes through either one are visible via
// the other.
func (req *InetDiagReqV2) Serialize() []byte {
	raw := (*[sizeOfDiagRequest]byte)(unsafe.Pointer(req))
	return raw[:]
}

// Len reports the number of bytes Serialize produces. The value is the
// fixed size of the request struct and does not depend on req's contents.
func (req *InetDiagReqV2) Len() int {
	const size int = sizeOfDiagRequest
	return size
}

// InetDiagMsg is the fixed-size header that parseInetDiagMsg reads from the
// start of each netlink response payload.
// NOTE(review): presumably mirrors the kernel's inet_diag_msg — confirm against inet_diag.h.
//
// Layout: four one-byte fields, the 48-byte InetDiagSockID and five uint32
// values, i.e. 72 bytes in total. The struct is filled by reinterpreting raw
// bytes, so field order and widths must not be changed.
type InetDiagMsg struct {
Family uint8
// Numeric socket state; parseTCPStats converts it to tcpConnectionState
// and counts sockets per state.
State uint8
Timer uint8
Retrans uint8
ID InetDiagSockID
Expires uint32
// Receive-queue value; summed into tcpRxQueuedBytes by parseTCPStats.
RQueue uint32
// Send-queue value; summed into tcpTxQueuedBytes by parseTCPStats.
WQueue uint32
UID uint32
Inode uint32
}

// parseInetDiagMsg decodes the InetDiagMsg header found at the start of a
// netlink response payload b, interpreting the bytes in host byte order.
//
// The header is copied into a freshly allocated struct instead of casting
// &b[0] in place, for three reasons:
//   - an empty payload no longer panics with an index-out-of-range error;
//   - a payload shorter than the header no longer lets callers read memory
//     past the end of b;
//   - a byte slice carries no alignment guarantee for the uint32 fields,
//     whereas the allocated struct is always correctly aligned.
//
// The result is never nil. If b is shorter than the header, the fields (or
// parts of fields) not covered by b keep their zero value. Bytes in b beyond
// the header, such as trailing attributes, are ignored.
func parseInetDiagMsg(b []byte) *InetDiagMsg {
	msg := new(InetDiagMsg)
	// View the struct as a byte array so that copy can bound the transfer
	// to min(len(b), header size).
	dst := (*[unsafe.Sizeof(InetDiagMsg{})]byte)(unsafe.Pointer(msg))
	copy(dst[:], b)
	return msg
}

func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error {
tcpStats, err := getTCPStats(procFilePath("net/tcp"))
tcpStats, err := getTCPStats(syscall.AF_INET)
if err != nil {
return fmt.Errorf("couldn't get tcpstats: %w", err)
}

// if enabled ipv6 system
tcp6File := procFilePath("net/tcp6")
if _, hasIPv6 := os.Stat(tcp6File); hasIPv6 == nil {
tcp6Stats, err := getTCPStats(tcp6File)
if _, hasIPv6 := os.Stat(procFilePath("net/tcp6")); hasIPv6 == nil {
tcp6Stats, err := getTCPStats(syscall.AF_INET6)
if err != nil {
return fmt.Errorf("couldn't get tcp6stats: %w", err)
}
Expand All @@ -102,59 +149,51 @@ func (c *tcpStatCollector) Update(ch chan<- prometheus.Metric) error {
for st, value := range tcpStats {
ch <- c.desc.mustNewConstMetric(value, st.String())
}

return nil
}

func getTCPStats(statsFile string) (map[tcpConnectionState]float64, error) {
file, err := os.Open(statsFile)
func getTCPStats(family uint8) (map[tcpConnectionState]float64, error) {
const TCPFAll = 0xFFF
const InetDiagInfo = 2
const SockDiagByFamily = 20

conn, err := netlink.Dial(syscall.NETLINK_INET_DIAG, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("couldn't connect netlink: %w", err)
}
defer conn.Close()

msg := netlink.Message{
Header: netlink.Header{
Type: SockDiagByFamily,
Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_DUMP,
},
Data: (&InetDiagReqV2{
Family: family,
Protocol: syscall.IPPROTO_TCP,
States: TCPFAll,
Ext: 0 | 1<<(InetDiagInfo-1),
}).Serialize(),
}
defer file.Close()

return parseTCPStats(file)
}

func parseTCPStats(r io.Reader) (map[tcpConnectionState]float64, error) {
tcpStats := map[tcpConnectionState]float64{}
contents, err := ioutil.ReadAll(r)
messages, err := conn.Execute(msg)
if err != nil {
return nil, err
}

for _, line := range strings.Split(string(contents), "\n")[1:] {
parts := strings.Fields(line)
if len(parts) == 0 {
continue
}
if len(parts) < 5 {
return nil, fmt.Errorf("invalid TCP stats line: %q", line)
}

qu := strings.Split(parts[4], ":")
if len(qu) < 2 {
return nil, fmt.Errorf("cannot parse tx_queues and rx_queues: %q", line)
}

tx, err := strconv.ParseUint(qu[0], 16, 64)
if err != nil {
return nil, err
}
tcpStats[tcpConnectionState(tcpTxQueuedBytes)] += float64(tx)

rx, err := strconv.ParseUint(qu[1], 16, 64)
if err != nil {
return nil, err
}
tcpStats[tcpConnectionState(tcpRxQueuedBytes)] += float64(rx)
return parseTCPStats(messages)
}

st, err := strconv.ParseInt(parts[3], 16, 8)
if err != nil {
return nil, err
}
func parseTCPStats(msgs []netlink.Message) (map[tcpConnectionState]float64, error) {
tcpStats := map[tcpConnectionState]float64{}

tcpStats[tcpConnectionState(st)]++
for _, m := range msgs {
msg := parseInetDiagMsg(m.Data)

tcpStats[tcpTxQueuedBytes] += float64(msg.WQueue)
tcpStats[tcpRxQueuedBytes] += float64(msg.RQueue)
tcpStats[tcpConnectionState(msg.State)]++
}

return tcpStats, nil
Expand Down
121 changes: 42 additions & 79 deletions collector/tcpstat_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,56 @@
package collector

import (
"os"
"strings"
"bytes"
"encoding/binary"
"syscall"
"testing"

"github.com/mdlayher/netlink"
)

func Test_parseTCPStatsError(t *testing.T) {
tests := []struct {
name string
in string
}{
{
name: "too few fields",
in: "sl local_address\n 0: 00000000:0016",
},
{
name: "missing colon in tx-rx field",
in: "sl local_address rem_address st tx_queue rx_queue\n" +
" 1: 0F02000A:0016 0202000A:8B6B 01 0000000000000001",
},
{
name: "tx parsing issue",
in: "sl local_address rem_address st tx_queue rx_queue\n" +
" 1: 0F02000A:0016 0202000A:8B6B 01 0000000x:00000001",
},
func Test_parseTCPStats(t *testing.T) {
encode := func(m InetDiagMsg) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, m)
if err != nil {
panic(err)
}
return buf.Bytes()
}

msg := []netlink.Message{
{
name: "rx parsing issue",
in: "sl local_address rem_address st tx_queue rx_queue\n" +
" 1: 0F02000A:0016 0202000A:8B6B 01 00000000:0000000x",
Data: encode(InetDiagMsg{
Family: syscall.AF_INET,
State: uint8(tcpEstablished),
Timer: 0,
Retrans: 0,
ID: InetDiagSockID{},
Expires: 0,
RQueue: 11,
WQueue: 21,
UID: 0,
Inode: 0,
}),
},
{
name: "state parsing issue",
in: "sl local_address rem_address st tx_queue rx_queue\n" +
" 1: 0F02000A:0016 0202000A:8B6B 0H 00000000:00000001",
Data: encode(InetDiagMsg{
Family: syscall.AF_INET,
State: uint8(tcpListen),
Timer: 0,
Retrans: 0,
ID: InetDiagSockID{},
Expires: 0,
RQueue: 11,
WQueue: 21,
UID: 0,
Inode: 0,
}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := parseTCPStats(strings.NewReader(tt.in)); err == nil {
t.Fatal("expected an error, but none occurred")
}
})
}
}

func TestTCPStat(t *testing.T) {

noFile, _ := os.Open("follow the white rabbit")
defer noFile.Close()

if _, err := parseTCPStats(noFile); err == nil {
t.Fatal("expected an error, but none occurred")
}

file, err := os.Open("fixtures/proc/net/tcpstat")
if err != nil {
t.Fatal(err)
}
defer file.Close()

tcpStats, err := parseTCPStats(file)
tcpStats, err := parseTCPStats(msg)
if err != nil {
t.Fatal(err)
}
Expand All @@ -89,35 +79,8 @@ func TestTCPStat(t *testing.T) {
if want, got := 42, int(tcpStats[tcpTxQueuedBytes]); want != got {
t.Errorf("want tcpstat number of bytes in tx queue %d, got %d", want, got)
}
if want, got := 1, int(tcpStats[tcpRxQueuedBytes]); want != got {
if want, got := 22, int(tcpStats[tcpRxQueuedBytes]); want != got {
t.Errorf("want tcpstat number of bytes in rx queue %d, got %d", want, got)
}

}

// Test_getTCPStats checks the error path of getTCPStats: asking for stats
// from a path that does not exist must yield an error.
func Test_getTCPStats(t *testing.T) {
	cases := []struct {
		name      string
		statsFile string
		wantErr   bool
	}{
		{
			name:      "file not found",
			statsFile: "somewhere over the rainbow",
			wantErr:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := getTCPStats(tc.statsFile)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("getTCPStats() error = %v, wantErr %v", err, tc.wantErr)
			}
			// other cases are covered by TestTCPStat()
		})
	}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/jsimonetti/rtnetlink v1.1.1
github.com/lufia/iostat v1.2.1
github.com/mattn/go-xmlrpc v0.0.3
github.com/mdlayher/netlink v1.6.0
github.com/mdlayher/wifi v0.0.0-20220320220353-954ff73a19a5
github.com/prometheus/client_golang v1.12.1
github.com/prometheus/client_model v0.2.0
Expand Down