From 94aec80a71e7a3ce8ed3a05d26df2dc10a44d49d Mon Sep 17 00:00:00 2001 From: Evan Gibler <20933572+egibs@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:50:36 -0500 Subject: [PATCH] Make all map operations concurrency-safe; fix nested archive extraction (#424) * Use sync.Map types for all maps involved in concurrent operations Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Scan and nested archive fixes Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * Appease the linter Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --------- Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --- pkg/action/archive.go | 119 +++++++++++++++++++++------------ pkg/action/scan.go | 148 +++++++++++++++++++++++------------------- 2 files changed, 159 insertions(+), 108 deletions(-) diff --git a/pkg/action/archive.go b/pkg/action/archive.go index 3cd295d5b..09ab467b1 100644 --- a/pkg/action/archive.go +++ b/pkg/action/archive.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strings" + "sync" "github.com/chainguard-dev/clog" "github.com/ulikunitz/xz" @@ -37,21 +38,32 @@ func extractTar(ctx context.Context, d string, f string) error { } defer tf.Close() - tr := tar.NewReader(tf) - if strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz") { + var tr *tar.Reader + + switch { + case strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz"): gzStream, err := gzip.NewReader(tf) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) } defer gzStream.Close() tr = tar.NewReader(gzStream) - } - if strings.Contains(filename, ".tar.xz") { + case strings.Contains(filename, ".tar.xz"): + _, err := tf.Seek(0, io.SeekStart) // Seek to start for xz reading + if err != nil { + return fmt.Errorf("failed to seek to start: %w", err) + } xzStream, err := xz.NewReader(tf) if err != nil { return fmt.Errorf("failed to create xz reader: %w", err) } tr = tar.NewReader(xzStream) + default: + _, err := tf.Seek(0, io.SeekStart) // Seek to start for tar reading + if err != nil { + return fmt.Errorf("failed to seek to start: %w", err) + } + tr = tar.NewReader(tf) } for { @@ -84,6 +96,7 @@ func extractTar(ctx context.Context, d string, f string) error { } if _, err := io.Copy(f, io.LimitReader(tr, maxBytes)); err != nil { + f.Close() return fmt.Errorf("failed to copy file: %w", err) } @@ -169,23 +182,25 @@ func extractZip(ctx context.Context, d string, f string) error { open, err := file.Open() if err != nil { - open.Close() return fmt.Errorf("failed to open file in zip: %w", err) } err = os.MkdirAll(filepath.Dir(name), 0o755) if err != nil { + open.Close() return fmt.Errorf("failed to create directory: %w", err) } mode := file.Mode() | 0o200 create, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) if err != nil { - create.Close() + open.Close() return fmt.Errorf("failed to create file: %w", err) } if _, err = io.Copy(create, io.LimitReader(open, maxBytes)); err != nil { + open.Close() + create.Close() return fmt.Errorf("failed to copy file: %w", err) } @@ -199,13 +214,14 @@ func extractNestedArchive( ctx context.Context, d string, f string, - extracted map[string]bool, + extracted *sync.Map, ) error { isArchive := false ext := getExt(f) if _, ok := archiveMap[ext]; ok { isArchive = true } + //nolint:nestif // ignore complexity of 8 if isArchive { // Ensure the file was extracted and exists fullPath := filepath.Join(d, f) @@ -222,13 +238,27 @@ func extractNestedArchive( return fmt.Errorf("extract nested archive: %w", err) } // Mark the file as extracted - extracted[f] = true + extracted.Store(f, true) // Remove the nested archive file // This is done to prevent the file from being scanned if err := os.Remove(fullPath); err != nil { return fmt.Errorf("failed to remove file: %w", err) } + + // Check if there are any newly extracted files that are also archives + files, err := os.ReadDir(d) + if err != nil { + return fmt.Errorf("failed to read directory after extraction: %w", err) + } + for _, file := range files { + relPath := filepath.Join(d, file.Name()) + if _, isExtracted := extracted.Load(relPath); !isExtracted { + if err := extractNestedArchive(ctx, d, file.Name(), extracted); err != nil { + return fmt.Errorf("failed to extract nested archive %s: %w", file.Name(), err) + } + } + } } return nil } @@ -253,53 +283,60 @@ func extractArchiveToTempDir(ctx context.Context, path string) (string, error) { return "", fmt.Errorf("failed to extract %s: %w", path, err) } - extractedFiles := make(map[string]bool) + var extractedFiles sync.Map files, err := os.ReadDir(tmpDir) if err != nil { return "", fmt.Errorf("failed to read files in directory %s: %w", tmpDir, err) } for _, file := range files { - extractedFiles[filepath.Join(tmpDir, file.Name())] = false + extractedFiles.Store(filepath.Join(tmpDir, file.Name()), false) } - for file := range extractedFiles { - ext := getExt(file) - info, err := os.Stat(file) - if err != nil { - return "", fmt.Errorf("failed to stat file %s: %w", file, err) - } - switch mode := info.Mode(); { - case mode.IsDir(): - err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - rel, err := filepath.Rel(tmpDir, path) - if err != nil { - return fmt.Errorf("filepath.Rel: %w", err) - } - if !d.IsDir() { - if err := extractNestedArchive(ctx, tmpDir, rel, extractedFiles); err != nil { - return fmt.Errorf("failed to extract nested archive %s: %w", rel, err) + extractedFiles.Range(func(key, _ any) bool { + //nolint: nestif // ignoring complexity of 11 + if file, ok := key.(string); ok { + ext := getExt(file) + info, err := os.Stat(file) + if err != nil { + return false + } + switch mode := info.Mode(); { + case mode.IsDir(): + err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + rel, err := filepath.Rel(tmpDir, path) + if err != nil { + return fmt.Errorf("filepath.Rel: %w", err) + } + if !d.IsDir() { + if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { + return fmt.Errorf("failed to extract nested archive %s: %w", rel, err) + } } - } - return nil - }) - return tmpDir, err - case mode.IsRegular(): - if _, ok := archiveMap[ext]; ok { - rel, err := filepath.Rel(tmpDir, file) + return nil + }) if err != nil { - return "", fmt.Errorf("filepath.Rel: %w", err) + return false } - if err := extractNestedArchive(ctx, tmpDir, rel, extractedFiles); err != nil { - return "", fmt.Errorf("extract nested archive %s: %w", rel, err) + return true + case mode.IsRegular(): + if _, ok := archiveMap[ext]; ok { + rel, err := filepath.Rel(tmpDir, file) + if err != nil { + return false + } + if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil { + return false + } } + return true } - return tmpDir, nil } - } + return true + }) return tmpDir, nil } diff --git a/pkg/action/scan.go b/pkg/action/scan.go index 76d57f014..818bc8a10 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -125,20 +125,26 @@ func isSupportedArchive(path string) bool { } // errIfMatch generates the right error if a match is encountered. -func errIfHitOrMiss(frs map[string]*bincapz.FileReport, kind string, scanPath string, errIfHit bool, errIfMiss bool) error { - bMap := map[string]bool{} +func errIfHitOrMiss(frs *sync.Map, kind string, scanPath string, errIfHit bool, errIfMiss bool) error { + var bMap sync.Map count := 0 - for _, fr := range frs { - for _, b := range fr.Behaviors { - count++ - bMap[b.ID] = true + frs.Range(func(_, value any) bool { + if fr, ok := value.(*bincapz.FileReport); ok { + for _, b := range fr.Behaviors { + count++ + bMap.Store(b.ID, true) + } } - } + return true + }) bList := []string{} - for b := range bMap { - bList = append(bList, b) - } + bMap.Range(func(key, _ any) bool { + if k, ok := key.(string); ok { + bList = append(bList, k) + } + return true + }) sort.Strings(bList) suffix := "" @@ -158,6 +164,8 @@ func errIfHitOrMiss(frs map[string]*bincapz.FileReport, kind string, scanPath st } // recursiveScan recursively YARA scans the configured paths - handling archives and OCI images. +// + func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, error) { logger := clog.FromContext(ctx) logger.Debug("recursive scan", slog.Any("config", c)) @@ -175,9 +183,8 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro yrs := c.Rules logger.Infof("%d rules loaded", len(yrs.GetRules())) - scanPathFindings := map[string]*bincapz.FileReport{} + var scanPathFindings sync.Map - var results sync.Map for _, scanPath := range c.ScanPaths { logger.Debug("recursive scan", slog.Any("scanPath", scanPath)) imageURI := "" @@ -213,58 +220,65 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro } close(pc) - process := func(path string) error { - //nolint:nestif // ignore complexity of 13 - if isSupportedArchive(path) { - logger.Debug("found archive path", slog.Any("path", path)) - frs, err := processArchive(ctx, c, yrs, path, logger) - if err != nil { - logger.Errorf("unable to process %s: %v", path, err) + handleArchive := func(path string) error { + logger.Debug("found archive path", slog.Any("path", path)) + frs, err := processArchive(ctx, c, yrs, path, logger) + if err != nil { + logger.Errorf("unable to process %s: %v", path, err) + } + // If we're handling an archive within an OCI archive, wait to for other files to declare a miss + if !c.OCI { + if err := errIfHitOrMiss(frs, "archive", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { + return err } + } - // If we're handling an archive within an OCI archive, wait to for other files to declare a miss - if !c.OCI { - if err := errIfHitOrMiss(frs, "archive", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { - results.Store(path, &bincapz.FileReport{}) - return err + frs.Range(func(key, value any) bool { + if k, ok := key.(string); ok { + if fr, ok := value.(*bincapz.FileReport); ok { + scanPathFindings.Store(k, fr) } } + return true + }) + return nil + } - for extractedPath, fr := range frs { - results.Store(extractedPath, fr) - } - } else { - trimPath := "" - if c.OCI { - scanPath = imageURI - trimPath = ociExtractPath - } + handleFile := func(path string) error { + trimPath := "" + if c.OCI { + scanPath = imageURI + trimPath = ociExtractPath + } - logger.Debug("processing path", slog.Any("path", path)) - fr, err := processFile(ctx, c, yrs, path, scanPath, trimPath, logger) - if err != nil { - results.Store(path, &bincapz.FileReport{}) - return err - } - if fr != nil { - results.Store(path, fr) - if !c.OCI { - if err := errIfHitOrMiss(map[string]*bincapz.FileReport{path: fr}, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { - logger.Debugf("match short circuit: %s", err) - results.Store(path, &bincapz.FileReport{}) - } + logger.Debug("processing path", slog.Any("path", path)) + fr, err := processFile(ctx, c, yrs, path, scanPath, trimPath, logger) + if err != nil { + scanPathFindings.Store(path, &bincapz.FileReport{}) + return err + } + if fr != nil { + scanPathFindings.Store(path, fr) + if !c.OCI { + var frMap sync.Map + frMap.Store(path, fr) + if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil { + logger.Debugf("match short circuit: %s", err) + scanPathFindings.Store(path, &bincapz.FileReport{}) } } } return nil } - var g errgroup.Group g.SetLimit(maxConcurrency) for path := range pc { path := path g.Go(func() error { - return process(path) + if isSupportedArchive(path) { + return handleArchive(path) + } + return handleFile(path) }) } @@ -273,14 +287,9 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro } var pathKeys []string - results.Range(func(key, value interface{}) bool { + scanPathFindings.Range(func(key, _ interface{}) bool { if k, ok := key.(string); ok { pathKeys = append(pathKeys, k) - value, ok := value.(*bincapz.FileReport) - if !ok { - return false - } - scanPathFindings[k] = value } return true }) @@ -288,7 +297,7 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro // OCI images hadle their match his/miss logic per scanPath if c.OCI { - if err := errIfHitOrMiss(scanPathFindings, "image", imageURI, c.ErrFirstHit, c.ErrFirstMiss); err != nil { + if err := errIfHitOrMiss(&scanPathFindings, "image", imageURI, c.ErrFirstHit, c.ErrFirstMiss); err != nil { return r, err } @@ -299,14 +308,19 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro // Add the sorted paths and file reports to the parent report and render the results for _, k := range pathKeys { - r.Files.Set(k, scanPathFindings[k]) - if c.Renderer != nil && r.Diff == nil { - if scanPathFindings[k].RiskScore < c.MinFileRisk { - return nil, nil - } - - if err := c.Renderer.File(ctx, scanPathFindings[k]); err != nil { - return nil, fmt.Errorf("render: %w", err) + finding, ok := scanPathFindings.Load(k) + if !ok { + return nil, fmt.Errorf("could not load finding from sync map") + } + if fr, ok := finding.(*bincapz.FileReport); ok { + r.Files.Set(k, fr) + if c.Renderer != nil && r.Diff == nil { + if fr.RiskScore < c.MinFileRisk { + return nil, nil + } + if err := c.Renderer.File(ctx, fr); err != nil { + return nil, fmt.Errorf("render: %w", err) + } } } } @@ -316,11 +330,11 @@ func recursiveScan(ctx context.Context, c bincapz.Config) (*bincapz.Report, erro } // processArchive extracts and scans a single archive file. -func processArchive(ctx context.Context, c bincapz.Config, yrs *yara.Rules, archivePath string, logger *clog.Logger) (map[string]*bincapz.FileReport, error) { +func processArchive(ctx context.Context, c bincapz.Config, yrs *yara.Rules, archivePath string, logger *clog.Logger) (*sync.Map, error) { logger = logger.With("archivePath", archivePath) var err error - frs := map[string]*bincapz.FileReport{} + var frs sync.Map tmpRoot, err := extractArchiveToTempDir(ctx, archivePath) if err != nil { @@ -338,14 +352,14 @@ func processArchive(ctx context.Context, c bincapz.Config, yrs *yara.Rules, arch return nil, err } if fr != nil { - frs[extractedFilePath] = fr + frs.Store(extractedFilePath, fr) } } if err := os.RemoveAll(tmpRoot); err != nil { logger.Errorf("remove %s: %v", tmpRoot, err) } - return frs, nil + return &frs, nil } // processFile scans a single output file, rendering live output if available.