Skip to content

Commit

Permalink
Add trace module and trace subcommand
Browse files Browse the repository at this point in the history
  • Loading branch information
jschwinger233 committed Jan 21, 2024
1 parent 32ea550 commit 22d565f
Show file tree
Hide file tree
Showing 6 changed files with 679 additions and 0 deletions.
69 changes: 69 additions & 0 deletions cmd/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2024, daeuniverse Organization <[email protected]>
*/

package cmd

import (
"context"
"os/signal"
"syscall"

"github.com/daeuniverse/dae/cmd/internal"
"github.com/daeuniverse/dae/trace"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)

var (
IPv4, IPv6 bool
L4Proto string
Port int
OutputFile string
)

func init() {
traceCmd := &cobra.Command{
Use: "trace",
Short: "To trace traffic",
Run: func(cmd *cobra.Command, args []string) {
internal.AutoSu()

if IPv4 && IPv6 {
logrus.Fatalln("IPv4 and IPv6 cannot be set at the same time")
}
if !IPv4 && !IPv6 {
IPv4 = true
}
IPVersion := 4
if IPv6 {
IPVersion = 6
}

var L4ProtoNo uint16
switch L4Proto {
case "tcp":
L4ProtoNo = syscall.IPPROTO_TCP
case "udp":
L4ProtoNo = syscall.IPPROTO_UDP
default:
logrus.Fatalf("Unknown L4 protocol: %s\n", L4Proto)
}

ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, OutputFile); err != nil {
logrus.Fatalln(err)
}
},
}

traceCmd.PersistentFlags().BoolVarP(&IPv4, "ipv4", "4", false, "Capture IPv4 traffic")
traceCmd.PersistentFlags().BoolVarP(&IPv6, "ipv6", "6", false, "Capture IPv6 traffic")
traceCmd.PersistentFlags().StringVarP(&L4Proto, "l4-proto", "p", "tcp", "Layer 4 protocol")
traceCmd.PersistentFlags().IntVarP(&Port, "port", "P", 80, "Port")
traceCmd.PersistentFlags().StringVarP(&OutputFile, "output", "o", "/dev/stdout", "Output file")

rootCmd.AddCommand(traceCmd)
}
71 changes: 71 additions & 0 deletions trace/kallsyms.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2024, daeuniverse Organization <[email protected]>
*/

package trace

import (
"bufio"
"os"
"sort"
"strconv"
"strings"

"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)

type Symbol struct {
Type string
Name string
Addr uint64
}

var kallsyms []Symbol
var kallsymsByName map[string]Symbol = make(map[string]Symbol)
var kallsymsByAddr map[uint64]Symbol = make(map[uint64]Symbol)

func init() {
readKallsyms()
}

func readKallsyms() {
file, err := os.Open("/proc/kallsyms")
if err != nil {
logrus.Fatalf("failed to open /proc/kallsyms: %v", err)
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) < 3 {
continue
}
addr, err := strconv.ParseUint(parts[0], 16, 64)
if err != nil {
continue
}
typ, name := parts[1], parts[2]
kallsyms = append(kallsyms, Symbol{typ, name, addr})
kallsymsByName[name] = Symbol{typ, name, addr}
kallsymsByAddr[addr] = Symbol{typ, name, addr}
}
sort.Slice(kallsyms, func(i, j int) bool {
return kallsyms[i].Addr < kallsyms[j].Addr
})
}

func NearestSymbol(addr uint64) Symbol {
idx, _ := slices.BinarySearchFunc(kallsyms, addr, func(x Symbol, addr uint64) int { return int(x.Addr - addr) })
if idx == len(kallsyms) {
return kallsyms[idx-1]
}
if kallsyms[idx].Addr == addr {
return kallsyms[idx]
}
if idx == 0 {
return kallsyms[0]
}
return kallsyms[idx-1]
}
1 change: 1 addition & 0 deletions trace/kern/headers
Submodule headers added at d72c67
233 changes: 233 additions & 0 deletions trace/kern/trace.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
#include "headers/if_ether_defs.h"
#include "headers/vmlinux.h"

#include "headers/bpf_core_read.h"
#include "headers/bpf_endian.h"
#include "headers/bpf_helpers.h"
#include "headers/bpf_tracing.h"

#define IFNAMSIZ 16
#define PNAME_LEN 32

static const bool TRUE = true;

union addr {
u32 v4addr;
struct {
u64 d1;
u64 d2;
} v6addr;
} __attribute__((packed));

struct tuple {
union addr saddr;
union addr daddr;
u16 sport;
u16 dport;
u16 l3_proto;
u8 l4_proto;
u8 pad;
} __attribute__((packed));

struct meta {
u64 pc;
u64 skb;
u32 mark;
u32 netns;
u32 ifindex;
u32 pid;
unsigned char ifname[IFNAMSIZ];
unsigned char pname[PNAME_LEN];
} __attribute__((packed));

struct event {
struct meta meta;
struct tuple tuple;
u64 second_param;
} __attribute__((packed));

const struct event *_ __attribute__((unused));

struct tracing_config {
u16 port;
u16 l4_proto;
u8 ip_vsn;
};

static volatile const struct tracing_config tracing_cfg;

struct {
__uint(type, BPF_MAP_TYPE_HASH);
__type(key, __u64);
__type(value, bool);
__uint(max_entries, 1024);
} skb_addresses SEC(".maps");

struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 1<<29);
} events SEC(".maps");

static __always_inline u32
get_netns(struct sk_buff *skb)
{
u32 netns = BPF_CORE_READ(skb, dev, nd_net.net, ns.inum);

// if skb->dev is not initialized, try to get ns from sk->__sk_common.skc_net.net->ns.inum
if (netns == 0) {
struct sock *sk = BPF_CORE_READ(skb, sk);
if (sk != NULL)
netns = BPF_CORE_READ(sk, __sk_common.skc_net.net, ns.inum);
}

return netns;
}

static __always_inline bool
filter_l3_and_l4(struct sk_buff *skb)
{
void *skb_head = BPF_CORE_READ(skb, head);
u16 l3_off = BPF_CORE_READ(skb, network_header);
u16 l4_off = BPF_CORE_READ(skb, transport_header);

struct iphdr *l3_hdr = (struct iphdr *) (skb_head + l3_off);
u8 ip_vsn = BPF_CORE_READ_BITFIELD_PROBED(l3_hdr, version);
if (ip_vsn != tracing_cfg.ip_vsn)
return false;

bpf_printk("ip_vsn %d, tracing_cfg.ip_vsn %d\n", ip_vsn, tracing_cfg.ip_vsn);
u16 l4_proto;
if (ip_vsn == 4) {
struct iphdr *ip4 = (struct iphdr *) l3_hdr;
l4_proto = BPF_CORE_READ(ip4, protocol);
} else if (ip_vsn == 6) {
struct ipv6hdr *ip6 = (struct ipv6hdr *) l3_hdr;
l4_proto = BPF_CORE_READ(ip6, nexthdr);
} else {
return false;
}

bpf_printk("l4_proto %d, tracing_cfg.l4_proto %d\n", l4_proto, tracing_cfg.l4_proto);
if (l4_proto != tracing_cfg.l4_proto)
return false;

u16 sport, dport;
if (l4_proto == IPPROTO_TCP) {
struct tcphdr *tcp = (struct tcphdr *) (skb_head + l4_off);
sport = BPF_CORE_READ(tcp, source);
dport = BPF_CORE_READ(tcp, dest);
} else if (l4_proto == IPPROTO_UDP) {
struct udphdr *udp = (struct udphdr *) (skb_head + l4_off);
sport = BPF_CORE_READ(udp, source);
dport = BPF_CORE_READ(udp, dest);
} else {
return false;
}

bpf_printk("sport %d, dport %d, tracing_cfg.port %d\n", sport, dport, tracing_cfg.port);
if (dport != tracing_cfg.port && sport != tracing_cfg.port)
return false;

return true;
}

static __always_inline void
set_tuple(struct tuple *tpl, struct sk_buff *skb)
{
void *skb_head = BPF_CORE_READ(skb, head);
u16 l3_off = BPF_CORE_READ(skb, network_header);
u16 l4_off = BPF_CORE_READ(skb, transport_header);

struct iphdr *l3_hdr = (struct iphdr *) (skb_head + l3_off);
u8 ip_vsn = BPF_CORE_READ_BITFIELD_PROBED(l3_hdr, version);

if (ip_vsn == 4) {
struct iphdr *ip4 = (struct iphdr *) l3_hdr;
BPF_CORE_READ_INTO(&tpl->saddr, ip4, saddr);
BPF_CORE_READ_INTO(&tpl->daddr, ip4, daddr);
tpl->l4_proto = BPF_CORE_READ(ip4, protocol);
tpl->l3_proto = ETH_P_IP;
} else if (ip_vsn == 6) {
struct ipv6hdr *ip6 = (struct ipv6hdr *) l3_hdr;
BPF_CORE_READ_INTO(&tpl->saddr, ip6, saddr);
BPF_CORE_READ_INTO(&tpl->daddr, ip6, daddr);
tpl->l4_proto = BPF_CORE_READ(ip6, nexthdr);
tpl->l3_proto = ETH_P_IPV6;
}

if (tpl->l4_proto == IPPROTO_TCP) {
struct tcphdr *tcp = (struct tcphdr *) (skb_head + l4_off);
tpl->sport= BPF_CORE_READ(tcp, source);
tpl->dport= BPF_CORE_READ(tcp, dest);
} else if (tpl->l4_proto == IPPROTO_UDP) {
struct udphdr *udp = (struct udphdr *) (skb_head + l4_off);
tpl->sport= BPF_CORE_READ(udp, source);
tpl->dport= BPF_CORE_READ(udp, dest);
}
}

static __always_inline void
set_meta(struct meta *meta, struct sk_buff *skb, struct pt_regs *ctx)
{
meta->pc = BPF_CORE_READ(ctx, ip);
meta->skb = (__u64)skb;
meta->mark = BPF_CORE_READ(skb, mark);
meta->netns = get_netns(skb);
meta->ifindex = BPF_CORE_READ(skb, dev, ifindex);
BPF_CORE_READ_STR_INTO(&meta->ifname, skb, dev, name);

struct task_struct *current = (void *)bpf_get_current_task();
meta->pid = BPF_CORE_READ(current, pid);
u64 arg_start = BPF_CORE_READ(current, mm, arg_start);
bpf_probe_read_user_str(&meta->pname, PNAME_LEN, (void *)arg_start);
}

static __always_inline int
handle_skb(struct sk_buff *skb, struct pt_regs *ctx)
{
bool tracked = false;
u64 skb_addr = (u64) skb;
struct event ev = {};
if (bpf_map_lookup_elem(&skb_addresses, &skb_addr)) {
tracked = true;
goto cont;
}

if (!filter_l3_and_l4(skb))
return 0;

if (!tracked)
bpf_map_update_elem(&skb_addresses, &skb_addr, &TRUE, BPF_ANY);

cont:
ev.second_param = PT_REGS_PARM2(ctx);
set_meta(&ev.meta, skb, ctx);
set_tuple(&ev.tuple, skb);

bpf_ringbuf_output(&events, &ev, sizeof(ev), 0);
return 0;
}

#define KPROBE_SKB_AT(X) \
SEC("kprobe/skb-" #X) \
int kprobe_skb_##X(struct pt_regs *ctx) \
{ \
struct sk_buff *skb = (struct sk_buff *) PT_REGS_PARM##X(ctx); \
return handle_skb(skb, ctx); \
}

KPROBE_SKB_AT(1)
KPROBE_SKB_AT(2)
KPROBE_SKB_AT(3)
KPROBE_SKB_AT(4)
KPROBE_SKB_AT(5)

SEC("kprobe/skb_lifetime_termination")
int kprobe_skb_lifetime_termination(struct pt_regs *ctx)
{
u64 skb = (u64) PT_REGS_PARM1(ctx);
bpf_map_delete_elem(&skb_addresses, &skb);
return 0;
}

SEC("license") const char __license[] = "Dual BSD/GPL";
Loading

0 comments on commit 22d565f

Please sign in to comment.