From 14374426521c3bec977396622fef13535d008778 Mon Sep 17 00:00:00 2001 From: mozillazg Date: Sat, 27 Apr 2024 19:46:42 +0800 Subject: [PATCH] add new flag: -c --- cmd/options.go | 1 + cmd/root.go | 6 +++++- internal/consumer/net.go | 9 ++++++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cmd/options.go b/cmd/options.go index 089a0c65..bb8a3271 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -10,6 +10,7 @@ type Options struct { listInterfaces bool version bool print bool + maxPacketCount uint } func (o Options) WritePath() string { diff --git a/cmd/root.go b/cmd/root.go index f2e3854d..9b8bb528 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -53,6 +53,7 @@ func init() { "Print the ptcpdump and libpcap version strings and exit") rootCmd.Flags().BoolVar(&opts.print, "print", false, "Print parsed packet output, even if the raw packets are being saved to a file with the -w flag") + rootCmd.Flags().UintVarP(&opts.maxPacketCount, "receive-count", "c", 0, "Exit after receiving count packets") } func Execute() error { @@ -104,7 +105,10 @@ func run(cmd *cobra.Command, args []string) error { execConsumer := consumer.NewExecEventConsumer(pcache) go execConsumer.Start(ctx, execEventReader) packetConsumer := consumer.NewPacketEventConsumer(writers, devices) - go packetConsumer.Start(ctx, packetEventReader) + go func() { + packetConsumer.Start(ctx, packetEventReader, opts.maxPacketCount) + stop() + }() runtime.Gosched() diff --git a/internal/consumer/net.go b/internal/consumer/net.go index 1fb853bb..08d6f974 100644 --- a/internal/consumer/net.go +++ b/internal/consumer/net.go @@ -22,7 +22,8 @@ func NewPacketEventConsumer(writers []writer.PacketWriter, devices map[int]dev.D } } -func (c *PacketEventConsumer) Start(ctx context.Context, reader *perf.Reader) { +func (c *PacketEventConsumer) Start(ctx context.Context, reader *perf.Reader, maxPacketCount uint) { + var n uint for { select { case <-ctx.Done(): @@ -43,6 +44,12 @@ func (c *PacketEventConsumer) Start(ctx context.Context, reader *perf.Reader) { log.Printf("[PacketEventConsumer] lost samples: %d", record.LostSamples) } c.parsePacketEvent(record.RawSample) + + n++ + if maxPacketCount > 0 && n == maxPacketCount { + log.Printf("%d packets captured", n) + break + } } }