Skip to content

Commit

Permalink
processes: improve results on Linux (#499)
Browse files Browse the repository at this point in the history
* pip: add known good list

* process: make non-existent paths non-fatal, sort scan paths

* processes: improve results on Linux

* processes: improve results on Linux

* processes: improve results on Linux

* improve comment
  • Loading branch information
tstromberg authored Oct 7, 2024
1 parent 6020473 commit 1fe7d0f
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 41 deletions.
13 changes: 7 additions & 6 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,14 @@ func main() {
// When scanning processes, load all of the valid commands (paths)
// and store them as the ScanPaths
if mc.Processes {
processPaths, err := action.GetAllProcessPaths(ctx)
ps, err := action.ActiveProcesses(ctx)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("process paths: %w", err)
}
for _, p := range processPaths {
mc.ScanPaths = append(mc.ScanPaths, p.Path)
for _, p := range ps {
// in the future, we'll also want to attach process info directly
mc.ScanPaths = append(mc.ScanPaths, p.ScanPath)
}
}

Expand Down Expand Up @@ -458,13 +459,13 @@ func main() {
// When scanning processes, load all of the valid commands (paths)
// and store them as the ScanPaths
if mc.Processes {
processPaths, err := action.GetAllProcessPaths(ctx)
ps, err := action.ActiveProcesses(ctx)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("process paths: %w", err)
}
for _, p := range processPaths {
mc.ScanPaths = append(mc.ScanPaths, p.Path)
for _, p := range ps {
mc.ScanPaths = append(mc.ScanPaths, p.ScanPath)
}
}

Expand Down
113 changes: 86 additions & 27 deletions pkg/action/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,120 @@ import (
"context"
"fmt"
"os"
"runtime"
"sort"
"strings"

"github.com/chainguard-dev/clog"
"github.com/shirou/gopsutil/v4/process"
)

type Process struct {
PID int32
Path string
type ProcessInfo struct {
PID int32
PPID int32
Name string
ScanPath string
AdvertisedPath string
CmdLine []string
}

// GetAllProcessPaths is an exported function that returns a slice of Process PIDs and commands (path).
func GetAllProcessPaths(ctx context.Context) ([]Process, error) {
// ActiveProcesses is an exported function that returns a list of active processes.
func ActiveProcesses(ctx context.Context) ([]*ProcessInfo, error) {
// Retrieve all of the active PIDs
procs, err := process.ProcessesWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("processes: %w", err)
}

// Store PIDs and their respective commands (paths) in a map of paths and their Process structs
processMap := make(map[string]Process, len(procs))
found := map[string]*ProcessInfo{}
for _, p := range procs {
path, err := p.Exe()
// Executable resolution is non-fatal
pi, err := processInfo(ctx, p)
if err != nil {
name, _ := p.Name()
clog.Errorf("%s[%d]: %v", name, p.Pid, err)
clog.Warnf("skipping pid %d: %v", p.Pid, err)
continue
}
if _, exists := processMap[path]; !exists && path != "" && isValidPath(path) {
processMap[path] = Process{
PID: p.Pid,
Path: path,
}
if pi == nil {
continue
}
}

return procMapSlice(processMap), nil
}
found[pi.ScanPath] = pi
}

// procMapSlice converts a map of paths and their Process structs to a slice of Processes.
func procMapSlice(m map[string]Process) []Process {
ps := make([]Process, 0, len(m))
for _, v := range m {
ps := make([]*ProcessInfo, 0, len(found))
for _, v := range found {
ps = append(ps, v)
}

sort.Slice(ps, func(i, j int) bool {
return ps[i].Path < ps[j].Path
return ps[i].ScanPath < ps[j].ScanPath
})
return ps

return ps, nil
}

// isValidPath checks if the given path is valid.
func isValidPath(path string) bool {
// canStat reports whether os.Stat succeeds for path.
// Any stat error (missing file, permission denied, empty path, ...) yields false.
func canStat(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}

// processInfo returns information about a process tuned for scanning.
//
// It returns (nil, nil) for Linux kernel threads (PID 2 and its direct
// children), which have no backing executable; callers must check for a nil
// result in addition to the error.
//
// ScanPath is set to the first of these candidates that can be stat'ed:
//  1. the executable path reported for the process,
//  2. /proc/<pid>/exe (Linux only),
//  3. the first command-line argument, when it is an absolute path.
//
// An error is returned when none of the candidates can be stat'ed.
func processInfo(ctx context.Context, p *process.Process) (*ProcessInfo, error) {
	pi := &ProcessInfo{
		PID: p.Pid,
	}

	// Name lookup is best-effort; use a placeholder so log/error messages stay readable.
	name, err := p.NameWithContext(ctx)
	if err != nil {
		name = "<unknown>"
	}
	pi.Name = name

	// Parent lookup is best-effort; -1 marks an unknown parent.
	parent, err := p.PpidWithContext(ctx)
	if err != nil {
		parent = -1
	}

	// Skip Linux kernel threads that have no backing executable
	if runtime.GOOS == "linux" && (p.Pid == 2 || parent == 2) {
		return nil, nil
	}
	pi.PPID = parent

	// The command line is optional: it only supplies a last-resort path candidate.
	cmd, err := p.CmdlineSliceWithContext(ctx)
	if err == nil {
		pi.CmdLine = cmd
		if len(cmd) > 0 && strings.HasPrefix(cmd[0], "/") {
			pi.AdvertisedPath = cmd[0]
		}
	}

	// on Linux, this is effectively readlink(/proc/X/exe), but it isn't fully resolved either
	path, err := p.ExeWithContext(ctx)
	if err == nil {
		pi.ScanPath = path
		if canStat(pi.ScanPath) {
			return pi, nil
		}
	}

	// fallback if the reported executable path fails to be stattable
	if runtime.GOOS == "linux" {
		pi.ScanPath = fmt.Sprintf("/proc/%d/exe", p.Pid)

		if canStat(pi.ScanPath) {
			return pi, nil
		}
	}

	// Settle for whatever binary we may have found in the process table
	if canStat(pi.AdvertisedPath) {
		pi.ScanPath = pi.AdvertisedPath
		return pi, nil
	}

	if pi.AdvertisedPath != "" {
		return nil, fmt.Errorf("%s[%d]: unable to stat %q or %q", pi.Name, pi.PID, pi.ScanPath, pi.AdvertisedPath)
	}

	return nil, fmt.Errorf("%s[%d]: unable to stat %q", pi.Name, pi.PID, pi.ScanPath)
}
25 changes: 17 additions & 8 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,31 @@ var (
)

// findFilesRecursively returns a list of files found recursively within a path.
func findFilesRecursively(ctx context.Context, root string) ([]string, error) {
clog.FromContext(ctx).Infof("finding files in %s ...", root)
func findFilesRecursively(ctx context.Context, rootPath string) ([]string, error) {
clog.FromContext(ctx).Infof("finding files in %s ...", rootPath)
var files []string

// Follow symlink if provided at the root
root, err := filepath.EvalSymlinks(root)
root, err := filepath.EvalSymlinks(rootPath)
if err != nil {
return nil, err
// Allow /proc/XXX/exe to be scanned even if symlink is not resolvable
if strings.HasPrefix(rootPath, "/proc/") {
root = rootPath
} else {
return nil, fmt.Errorf("eval %q: %w", rootPath, err)
}
}

err = filepath.WalkDir(root,
func(path string, info os.DirEntry, err error) error {
if err != nil {
clog.FromContext(ctx).Errorf("walk %s: %v", path, err)
return err
return fmt.Errorf("walk: %w", err)
}
if info.IsDir() || strings.Contains(path, "/.git/") {
return nil
}

files = append(files, path)

return nil
})
return files, err
Expand Down Expand Up @@ -241,8 +244,14 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report

paths, err := findFilesRecursively(ctx, scanPath)
if err != nil {
return nil, fmt.Errorf("find: %w", err)
if len(c.ScanPaths) == 1 {
return nil, fmt.Errorf("find: %w", err)
}
// try to scan remaining scan paths
logger.Errorf("find failed: %v", err)
continue
}

logger.Debug("files found", slog.Any("path count", len(paths)), slog.Any("scanPath", scanPath))

maxConcurrency := c.Concurrency
Expand Down

0 comments on commit 1fe7d0f

Please sign in to comment.