diff --git a/cmd/mal/mal.go b/cmd/mal/mal.go index 7aa4acd2..23754535 100644 --- a/cmd/mal/mal.go +++ b/cmd/mal/mal.go @@ -381,13 +381,14 @@ func main() { // When scanning processes, load all of the valid commands (paths) // and store them as the ScanPaths if mc.Processes { - processPaths, err := action.GetAllProcessPaths(ctx) + ps, err := action.ActiveProcesses(ctx) if err != nil { returnCode = ExitActionFailed return fmt.Errorf("process paths: %w", err) } - for _, p := range processPaths { - mc.ScanPaths = append(mc.ScanPaths, p.Path) + for _, p := range ps { + // in the future, we'll also want to attach process info directly + mc.ScanPaths = append(mc.ScanPaths, p.ScanPath) } } @@ -458,13 +459,13 @@ func main() { // When scanning processes, load all of the valid commands (paths) // and store them as the ScanPaths if mc.Processes { - processPaths, err := action.GetAllProcessPaths(ctx) + ps, err := action.ActiveProcesses(ctx) if err != nil { returnCode = ExitActionFailed return fmt.Errorf("process paths: %w", err) } - for _, p := range processPaths { - mc.ScanPaths = append(mc.ScanPaths, p.Path) + for _, p := range ps { + mc.ScanPaths = append(mc.ScanPaths, p.ScanPath) } } diff --git a/pkg/action/process.go b/pkg/action/process.go index f6b9b1cc..e427d610 100644 --- a/pkg/action/process.go +++ b/pkg/action/process.go @@ -4,61 +4,120 @@ import ( "context" "fmt" "os" + "runtime" "sort" + "strings" "github.com/chainguard-dev/clog" "github.com/shirou/gopsutil/v4/process" ) -type Process struct { - PID int32 - Path string +type ProcessInfo struct { + PID int32 + PPID int32 + Name string + ScanPath string + AdvertisedPath string + CmdLine []string } -// GetAllProcessPaths is an exported function that returns a slice of Process PIDs and commands (path). -func GetAllProcessPaths(ctx context.Context) ([]Process, error) { +// ActiveProcesses is an exported function that returns a list of active processes. 
+func ActiveProcesses(ctx context.Context) ([]*ProcessInfo, error) { // Retrieve all of the active PIDs procs, err := process.ProcessesWithContext(ctx) if err != nil { return nil, fmt.Errorf("processes: %w", err) } - // Store PIDs and their respective commands (paths) in a map of paths and their Process structs - processMap := make(map[string]Process, len(procs)) + found := map[string]*ProcessInfo{} for _, p := range procs { - path, err := p.Exe() - // Executable resolution is non-fatal + pi, err := processInfo(ctx, p) if err != nil { - name, _ := p.Name() - clog.Errorf("%s[%d]: %v", name, p.Pid, err) + clog.Warnf("skipping pid %d: %v", p.Pid, err) continue } - if _, exists := processMap[path]; !exists && path != "" && isValidPath(path) { - processMap[path] = Process{ - PID: p.Pid, - Path: path, - } + if pi == nil { + continue } - } - return procMapSlice(processMap), nil -} + found[pi.ScanPath] = pi + } -// procMapSlice converts a map of paths and their Process structs to a slice of Processes. -func procMapSlice(m map[string]Process) []Process { - ps := make([]Process, 0, len(m)) - for _, v := range m { + ps := make([]*ProcessInfo, 0, len(found)) + for _, v := range found { ps = append(ps, v) } sort.Slice(ps, func(i, j int) bool { - return ps[i].Path < ps[j].Path + return ps[i].ScanPath < ps[j].ScanPath }) - return ps + + return ps, nil } -// isValidPath checks if the given path is valid. -func isValidPath(path string) bool { +// canStat checks if stat() works on a given path. +func canStat(path string) bool { _, err := os.Stat(path) return err == nil } + +// processInfo returns information about a process tuned for scanning. 
+func processInfo(ctx context.Context, p *process.Process) (*ProcessInfo, error) { + pi := &ProcessInfo{ + PID: p.Pid, + } + name, err := p.Name() + if err != nil { + name = "" + } + pi.Name = name + + parent, err := p.PpidWithContext(ctx) + if err != nil { + parent = -1 + } + + // Skip Linux kernel threads that have no backing executable + if runtime.GOOS == "linux" && (p.Pid == 2 || parent == 2) { + return nil, nil + } + pi.PPID = parent + + cmd, err := p.CmdlineSliceWithContext(ctx) + if err == nil { + pi.CmdLine = cmd + if len(cmd) > 0 && strings.HasPrefix(cmd[0], "/") { + pi.AdvertisedPath = cmd[0] + } + } + + // on Linux, this is effectively readlink(/proc/X/exe), but it isn't fully resolved either + path, err := p.Exe() + if err == nil { + pi.ScanPath = path + if canStat(pi.ScanPath) { + return pi, nil + } + } + + // fallback if p.Exe fails to be stattable + if runtime.GOOS == "linux" { + pi.ScanPath = fmt.Sprintf("/proc/%d/exe", p.Pid) + + if canStat(pi.ScanPath) { + return pi, nil + } + } + + // Settle for whatever binary we may have found in the process table + if canStat(pi.AdvertisedPath) { + pi.ScanPath = pi.AdvertisedPath + return pi, nil + } + + if pi.AdvertisedPath != "" { + return nil, fmt.Errorf("%s[%d]: unable to stat %q or %q", pi.Name, pi.PID, pi.ScanPath, pi.AdvertisedPath) + } + + return nil, fmt.Errorf("%s: unable to stat %q", pi.Name, pi.ScanPath) +} diff --git a/pkg/action/scan.go b/pkg/action/scan.go index 85209685..461cf4b0 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -34,28 +34,31 @@ var ( ) // findFilesRecursively returns a list of files found recursively within a path. 
-func findFilesRecursively(ctx context.Context, root string) ([]string, error) { - clog.FromContext(ctx).Infof("finding files in %s ...", root) +func findFilesRecursively(ctx context.Context, rootPath string) ([]string, error) { + clog.FromContext(ctx).Infof("finding files in %s ...", rootPath) var files []string // Follow symlink if provided at the root - root, err := filepath.EvalSymlinks(root) + root, err := filepath.EvalSymlinks(rootPath) if err != nil { - return nil, err + // Allow /proc/XXX/exe to be scanned even if symlink is not resolvable + if strings.HasPrefix(rootPath, "/proc/") { + root = rootPath + } else { + return nil, fmt.Errorf("eval %q: %w", rootPath, err) + } } err = filepath.WalkDir(root, func(path string, info os.DirEntry, err error) error { if err != nil { - clog.FromContext(ctx).Errorf("walk %s: %v", path, err) - return err + return fmt.Errorf("walk: %w", err) } if info.IsDir() || strings.Contains(path, "/.git/") { return nil } files = append(files, path) - return nil }) return files, err @@ -241,8 +244,14 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report paths, err := findFilesRecursively(ctx, scanPath) if err != nil { - return nil, fmt.Errorf("find: %w", err) + if len(c.ScanPaths) == 1 { + return nil, fmt.Errorf("find: %w", err) + } + // try to scan remaining scan paths + logger.Errorf("find failed: %v", err) + continue } + logger.Debug("files found", slog.Any("path count", len(paths)), slog.Any("scanPath", scanPath)) maxConcurrency := c.Concurrency