diff --git a/cmd/ebpf.go b/cmd/ebpf.go index 9b9ed0a9..b5438bb2 100644 --- a/cmd/ebpf.go +++ b/cmd/ebpf.go @@ -6,33 +6,33 @@ import ( "github.com/mozillazg/ptcpdump/internal/dev" ) -func attachHooks(opts Options) (map[int]dev.Device, *bpf.BPF, error) { +func attachHooks(opts Options) (*bpf.BPF, error) { devices, err := dev.GetDevices(opts.ifaces) if err != nil { - return nil, nil, err + return nil, err } if err := rlimit.RemoveMemlock(); err != nil { - return devices, nil, err + return nil, err } bf, err := bpf.NewBPF() if err != nil { - return devices, nil, err + return nil, err } if err := bf.Load(bpf.NewOptions(opts.pid, opts.comm, opts.followForks, opts.pcapFilter)); err != nil { - return devices, nil, err + return nil, err } if err := bf.AttachKprobes(); err != nil { - return devices, bf, err + return bf, err } if err := bf.AttachTracepoints(); err != nil { - return devices, bf, err + return bf, err } for _, iface := range devices { if err := bf.AttachTcHooks(iface.Ifindex, opts.DirectionOut(), opts.DirectionIn()); err != nil { - return devices, bf, err + return bf, err } } - return devices, bf, nil + return bf, nil } diff --git a/cmd/root.go b/cmd/root.go index c384c829..f8c2731c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -83,7 +83,7 @@ func run(cmd *cobra.Command, args []string) error { }() go pcache.Start() - devices, bf, err := attachHooks(opts) + bf, err := attachHooks(opts) if err != nil { if bf != nil { bf.Close() @@ -108,7 +108,7 @@ func run(cmd *cobra.Command, args []string) error { execConsumer := consumer.NewExecEventConsumer(pcache) go execConsumer.Start(ctx, execEvensCh) - packetConsumer := consumer.NewPacketEventConsumer(writers, devices) + packetConsumer := consumer.NewPacketEventConsumer(writers) go func() { packetConsumer.Start(ctx, packetEvensCh, opts.maxPacketCount) stop() diff --git a/internal/consumer/net.go b/internal/consumer/net.go index f688e851..2a39ba56 100644 --- a/internal/consumer/net.go +++ b/internal/consumer/net.go @@ -14,7 +14,8 @@ type PacketEventConsumer struct { devices map[int]dev.Device } -func NewPacketEventConsumer(writers []writer.PacketWriter, devices map[int]dev.Device) *PacketEventConsumer { +func NewPacketEventConsumer(writers []writer.PacketWriter) *PacketEventConsumer { + devices, _ := dev.GetDevices([]string{}) return &PacketEventConsumer{ writers: writers, devices: devices, @@ -47,7 +48,7 @@ func (c *PacketEventConsumer) parsePacketEvent(pt bpf.BpfPacketEventT) { for _, w := range c.writers { if err := w.Write(pevent); err != nil { - log.Printf("[PacketEventConsumer] write packet failed: %s", err) + log.Printf("[PacketEventConsumer] write packet failed: %s, device: %#v", err, pevent.Device) } w.Flush() } diff --git a/internal/dev/dev.go b/internal/dev/dev.go index a264dce6..28b75e8a 100644 --- a/internal/dev/dev.go +++ b/internal/dev/dev.go @@ -3,19 +3,31 @@ package dev import ( "github.com/vishvananda/netlink" "golang.org/x/xerrors" + "sync" ) +var allLinks []netlink.Link +var once sync.Once + type Device struct { Name string Ifindex int } +func getAllLinks() ([]netlink.Link, error) { + var err error + once.Do(func() { + allLinks, err = netlink.LinkList() + }) + return allLinks, err +} + func GetDevices(names []string) (map[int]Device, error) { var links []netlink.Link var err error ifindexMap := make(map[int]Device) - allLinks, err := netlink.LinkList() + allLinks, err := getAllLinks() if err != nil { return nil, xerrors.Errorf(": %w", err) }