Skip to content

Commit

Permalink
[ebpfless] Make UDP tests pass in ebpfless test suite (#30934)
Browse files Browse the repository at this point in the history
  • Loading branch information
pimlu authored Nov 12, 2024
1 parent 2a4999b commit 631608e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 63 deletions.
21 changes: 14 additions & 7 deletions pkg/network/tracer/connection/ebpfless_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,20 @@ func (t *ebpfLessTracer) processConnection(
conn.Duration = time.Duration(time.Now().UnixNano())
}

if ip4 == nil && ip6 == nil {
return nil
}
var err error
switch conn.Type {
case network.UDP:
if (ip4 != nil && !t.config.CollectUDPv4Conns) || (ip6 != nil && !t.config.CollectUDPv6Conns) {
return nil
}
err = t.udp.process(conn, pktType, udp)
case network.TCP:
if (ip4 != nil && !t.config.CollectTCPv4Conns) || (ip6 != nil && !t.config.CollectTCPv6Conns) {
return nil
}
err = t.tcp.process(conn, pktType, ip4, ip6, tcp)
default:
err = fmt.Errorf("unsupported connection type %d", conn.Type)
Expand Down Expand Up @@ -242,13 +251,11 @@ func (t *ebpfLessTracer) determineConnectionDirection(conn *network.ConnectionSt
return
}

if conn.Type == network.TCP {
switch pktType {
case unix.PACKET_HOST:
conn.Direction = network.INCOMING
case unix.PACKET_OUTGOING:
conn.Direction = network.OUTGOING
}
switch pktType {
case unix.PACKET_HOST:
conn.Direction = network.INCOMING
case unix.PACKET_OUTGOING:
conn.Direction = network.OUTGOING
}
}

Expand Down
70 changes: 42 additions & 28 deletions pkg/network/tracer/tracer_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,24 +576,31 @@ func (s *TracerSuite) TestUnconnectedUDPSendIPv6() {
tr := setupTracer(t, cfg)
linkLocal, err := offsetguess.GetIPv6LinkLocalAddress()
require.NoError(t, err)
remoteAddr := linkLocal[0]
remoteAddr.Port = rand.Int()%5000 + 15000

remotePort := rand.Int()%5000 + 15000
remoteAddr := &net.UDPAddr{IP: net.ParseIP(offsetguess.InterfaceLocalMulticastIPv6), Port: remotePort}
conn, err := net.ListenUDP("udp6", linkLocal[0])
conn, err := net.ListenUDP("udp6", remoteAddr)
require.NoError(t, err)
defer conn.Close()
message := []byte("payload")
bytesSent, err := conn.WriteTo(message, remoteAddr)
require.NoError(t, err)

connections := getConnections(t, tr)
outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool {
return cs.DPort == uint16(remotePort)
})
require.EventuallyWithT(t, func(ct *assert.CollectT) {
connections := getConnections(t, tr)
outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool {
if cs.Type != network.UDP {
return false
}
return cs.DPort == uint16(remoteAddr.Port)
})
if !assert.Len(ct, outgoing, 1) {
return
}
assert.Equal(ct, remoteAddr.IP.String(), outgoing[0].Dest.String())
assert.Equal(ct, bytesSent, int(outgoing[0].Monotonic.SentBytes))
}, 3*time.Second, 100*time.Millisecond)

require.Len(t, outgoing, 1)
assert.Equal(t, remoteAddr.IP.String(), outgoing[0].Dest.String())
assert.Equal(t, bytesSent, int(outgoing[0].Monotonic.SentBytes))
}

func (s *TracerSuite) TestGatewayLookupNotEnabled() {
Expand Down Expand Up @@ -1471,28 +1478,35 @@ func testUDPReusePort(t *testing.T, udpnet string, ip string) {

// Iterate through active connections until we find connection created above, and confirm send + recv counts
t.Logf("port: %d", assignedPort)
connections := getConnections(t, tr)
for _, c := range connections.Conns {
t.Log(c)
}

incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections)
if assert.True(t, ok, "unable to find incoming connection") {
assert.Equal(t, network.INCOMING, incoming.Direction)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// use t instead of ct because getConnections uses require (not assert), and we get a better error message that way
connections := getConnections(t, tr)

// make sure the inverse values are seen for the other message
assert.Equal(t, serverMessageSize, int(incoming.Monotonic.SentBytes), "incoming sent")
assert.Equal(t, clientMessageSize, int(incoming.Monotonic.RecvBytes), "incoming recv")
assert.True(t, incoming.IntraHost, "incoming intrahost")
}
incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections)
if assert.True(t, ok, "unable to find incoming connection") {
assert.Equal(t, network.INCOMING, incoming.Direction)

// make sure the inverse values are seen for the other message
assert.Equal(t, serverMessageSize, int(incoming.Monotonic.SentBytes), "incoming sent")
assert.Equal(t, clientMessageSize, int(incoming.Monotonic.RecvBytes), "incoming recv")
assert.True(t, incoming.IntraHost, "incoming intrahost")
}

outgoing, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections)
if assert.True(t, ok, "unable to find outgoing connection") {
assert.Equal(t, network.OUTGOING, outgoing.Direction)
outgoing, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections)
if assert.True(t, ok, "unable to find outgoing connection") {
assert.Equal(t, network.OUTGOING, outgoing.Direction)

assert.Equal(t, clientMessageSize, int(outgoing.Monotonic.SentBytes), "outgoing sent")
assert.Equal(t, serverMessageSize, int(outgoing.Monotonic.RecvBytes), "outgoing recv")
assert.True(t, outgoing.IntraHost, "outgoing intrahost")
assert.Equal(t, clientMessageSize, int(outgoing.Monotonic.SentBytes), "outgoing sent")
assert.Equal(t, serverMessageSize, int(outgoing.Monotonic.RecvBytes), "outgoing recv")
assert.True(t, outgoing.IntraHost, "outgoing intrahost")
}
}, 3*time.Second, 100*time.Millisecond)

// log the connections at the end in case the test failed
connections := getConnections(t, tr)
for _, c := range connections.Conns {
t.Log(c)
}
}

Expand Down
62 changes: 34 additions & 28 deletions pkg/network/tracer/tracer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,26 +452,29 @@ func testUDPSendAndReceive(t *testing.T, tr *Tracer, addr string) {
require.NoError(t, err)

// Iterate through active connections until we find connection created above, and confirm send + recv counts
connections := getConnections(t, tr)

incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections)
if assert.True(t, ok, "unable to find incoming connection") {
assert.Equal(t, network.INCOMING, incoming.Direction)
require.EventuallyWithT(t, func(ct *assert.CollectT) {
// use t instead of ct because getConnections uses require (not assert), and we get a better error message
connections := getConnections(t, tr)
incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections)
if assert.True(ct, ok, "unable to find incoming connection") {
assert.Equal(ct, network.INCOMING, incoming.Direction)

// make sure the inverse values are seen for the other message
assert.Equal(ct, serverMessageSize, int(incoming.Monotonic.SentBytes), "incoming sent")
assert.Equal(ct, clientMessageSize, int(incoming.Monotonic.RecvBytes), "incoming recv")
assert.True(ct, incoming.IntraHost, "incoming intrahost")
}

// make sure the inverse values are seen for the other message
assert.Equal(t, serverMessageSize, int(incoming.Monotonic.SentBytes), "incoming sent")
assert.Equal(t, clientMessageSize, int(incoming.Monotonic.RecvBytes), "incoming recv")
assert.True(t, incoming.IntraHost, "incoming intrahost")
}
outgoing, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections)
if assert.True(t, ok, "unable to find outgoing connection") {
assert.Equal(t, network.OUTGOING, outgoing.Direction)

outgoing, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections)
if assert.True(t, ok, "unable to find outgoing connection") {
assert.Equal(t, network.OUTGOING, outgoing.Direction)
assert.Equal(t, clientMessageSize, int(outgoing.Monotonic.SentBytes), "outgoing sent")
assert.Equal(t, serverMessageSize, int(outgoing.Monotonic.RecvBytes), "outgoing recv")
assert.True(t, outgoing.IntraHost, "outgoing intrahost")
}

assert.Equal(t, clientMessageSize, int(outgoing.Monotonic.SentBytes), "outgoing sent")
assert.Equal(t, serverMessageSize, int(outgoing.Monotonic.RecvBytes), "outgoing recv")
assert.True(t, outgoing.IntraHost, "outgoing intrahost")
}
}, 3*time.Second, 100*time.Millisecond)
}

func (s *TracerSuite) TestUDPDisabled() {
Expand Down Expand Up @@ -1141,13 +1144,15 @@ func (s *TracerSuite) TestUnconnectedUDPSendIPv4() {
bytesSent, err := conn.WriteTo(message, remoteAddr)
require.NoError(t, err)

connections := getConnections(t, tr)
outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool {
return cs.DPort == uint16(remotePort)
})
require.EventuallyWithT(t, func(ct *assert.CollectT) {
connections := getConnections(t, tr)
outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool {
return cs.DPort == uint16(remotePort)
})

require.Len(t, outgoing, 1)
assert.Equal(t, bytesSent, int(outgoing[0].Monotonic.SentBytes))
assert.Len(ct, outgoing, 1)
assert.Equal(ct, bytesSent, int(outgoing[0].Monotonic.SentBytes))
}, 3*time.Second, 100*time.Millisecond)
}

func (s *TracerSuite) TestConnectedUDPSendIPv6() {
Expand All @@ -1168,18 +1173,19 @@ func (s *TracerSuite) TestConnectedUDPSendIPv6() {
require.NoError(t, err)

var outgoing []network.ConnectionStats
require.Eventually(t, func() bool {
require.EventuallyWithT(t, func(ct *assert.CollectT) {
connections := getConnections(t, tr)
outgoing = network.FilterConnections(connections, func(cs network.ConnectionStats) bool {
return cs.DPort == uint16(remotePort)
})
if !assert.Len(ct, outgoing, 1) {
return
}

return len(outgoing) == 1
assert.Equal(ct, remoteAddr.IP.String(), outgoing[0].Dest.String())
assert.Equal(ct, bytesSent, int(outgoing[0].Monotonic.SentBytes))
}, 3*time.Second, 100*time.Millisecond, "failed to find connection")

require.Len(t, outgoing, 1)
assert.Equal(t, remoteAddr.IP.String(), outgoing[0].Dest.String())
assert.Equal(t, bytesSent, int(outgoing[0].Monotonic.SentBytes))
}

func (s *TracerSuite) TestTCPDirection() {
Expand Down

0 comments on commit 631608e

Please sign in to comment.