diff --git a/cmd/mal/mal.go b/cmd/mal/mal.go index 52815932a..9de0b82d2 100644 --- a/cmd/mal/mal.go +++ b/cmd/mal/mal.go @@ -8,6 +8,7 @@ package main import ( "context" + "errors" "fmt" "io/fs" "log/slog" @@ -80,6 +81,16 @@ var riskMap = map[string]int{ "critical": 4, } +func showError(err error) { + emoji := "💣" + if errors.Is(err, action.ErrMatchedCondition) { + emoji = "👋" + err = errors.Unwrap(err) + } + + fmt.Fprintf(os.Stderr, "%s %s\n", emoji, err.Error()) +} + //nolint:cyclop // ignore complexity of 40 func main() { returnCode := ExitOK @@ -398,7 +409,7 @@ func main() { ps, err := action.ActiveProcesses(ctx) if err != nil { returnCode = ExitActionFailed - return fmt.Errorf("process paths: %w", err) + return err } for _, p := range ps { // in the future, we'll also want to attach process info directly @@ -409,7 +420,7 @@ func main() { res, err = action.Scan(ctx, mc) if err != nil { returnCode = ExitActionFailed - return fmt.Errorf("scan: %w", err) + return err } err = renderer.Full(ctx, res) @@ -530,7 +541,13 @@ func main() { } if err := app.Run(os.Args); err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - returnCode = ExitActionFailed + if returnCode != 0 { + returnCode = ExitActionFailed + } + if errors.Is(err, action.ErrMatchedCondition) { + returnCode = ExitOK + } + + showError(err) } } diff --git a/pkg/action/scan.go b/pkg/action/scan.go index dcade5e6f..e043f8c2a 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -5,6 +5,7 @@ package action import ( "context" + "errors" "fmt" "io/fs" "log/slog" @@ -31,7 +32,8 @@ var ( // compiledRuleCache are a cache of previously compiled rules. compiledRuleCache *yara.Rules // compileOnce ensures that we compile rules only once even across threads. - compileOnce sync.Once + compileOnce sync.Once + ErrMatchedCondition = errors.New("matched requested condition") ) // findFilesRecursively returns a list of files found recursively within a path. @@ -233,7 +235,6 @@ func cachedRules(ctx context.Context, fss []fs.FS) (*yara.Rules, error) { //nolint:gocognit,cyclop // ignoring complexity of 101,38 func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) { logger := clog.FromContext(ctx) - logger.Debug("recursive scan", slog.Any("config", c)) r := &malcontent.Report{ Files: orderedmap.New[string, *malcontent.FileReport](), } @@ -243,11 +244,12 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report var scanPathFindings sync.Map + var waitErr error + for _, scanPath := range c.ScanPaths { if c.Renderer != nil { c.Renderer.Scanning(ctx, scanPath) } - logger.Debug("recursive scan", slog.Any("scanPath", scanPath)) imageURI := "" ociExtractPath := "" var err error @@ -323,18 +325,19 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report fr, err := processFile(ctx, c, c.RuleFS, path, scanPath, trimPath, logger) if err != nil { scanPathFindings.Store(path, &malcontent.FileReport{}) - return err + return fmt.Errorf("process: %w", err) } - if fr != nil { - scanPathFindings.Store(path, fr) - if !c.OCI { - var frMap sync.Map - frMap.Store(path, fr) - if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { - logger.Debugf("match short circuit: %s", err) - scanPathFindings.Store(path, fr) - return err - } + if fr == nil { + return nil + } + + scanPathFindings.Store(path, fr) + if !c.OCI { + var frMap sync.Map + frMap.Store(path, fr) + if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { + scanPathFindings.Store(path, fr) + return fmt.Errorf("%q: %w", path, ErrMatchedCondition) } } return nil @@ -351,8 +354,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report } if err := g.Wait(); err != nil { - logger.Errorf("error with processing %v\n", err) - return nil, err + waitErr = err } var pathKeys []string @@ -396,6 +398,11 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report } } } + + // short-circuit out + if waitErr != nil { + return r, waitErr + } } // loop: next scan path return r, nil } @@ -460,9 +467,6 @@ func processFile(ctx context.Context, c malcontent.Config, ruleFS []fs.FS, path func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) { r, err := recursiveScan(ctx, c) if err != nil { - if strings.Contains(err.Error(), "no matching capabilities") { - return r, nil - } return r, err } for files := r.Files.Oldest(); files != nil; files = files.Next() { @@ -473,7 +477,7 @@ func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) if c.Stats { err = render.Statistics(r) if err != nil { - return r, err + return r, fmt.Errorf("stats: %w", err) } } return r, nil