Skip to content

Commit

Permalink
Make all map operations concurrency-safe; fix nested archive extraction (#424)
Browse files Browse the repository at this point in the history

* Use sync.Map types for all maps involved in concurrent operations

Signed-off-by: egibs <[email protected]>

* Scan and nested archive fixes

Signed-off-by: egibs <[email protected]>

* Appease the linter

Signed-off-by: egibs <[email protected]>

---------

Signed-off-by: egibs <[email protected]>
  • Loading branch information
egibs authored Aug 22, 2024
1 parent 31f02a8 commit 94aec80
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 108 deletions.
119 changes: 78 additions & 41 deletions pkg/action/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"

"github.com/chainguard-dev/clog"
"github.com/ulikunitz/xz"
Expand All @@ -37,21 +38,32 @@ func extractTar(ctx context.Context, d string, f string) error {
}
defer tf.Close()

tr := tar.NewReader(tf)
if strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz") {
var tr *tar.Reader

switch {
case strings.Contains(f, ".apk") || strings.Contains(f, ".tar.gz") || strings.Contains(f, ".tgz"):
gzStream, err := gzip.NewReader(tf)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzStream.Close()
tr = tar.NewReader(gzStream)
}
if strings.Contains(filename, ".tar.xz") {
case strings.Contains(filename, ".tar.xz"):
_, err := tf.Seek(0, io.SeekStart) // Seek to start for xz reading
if err != nil {
return fmt.Errorf("failed to seek to start: %w", err)
}
xzStream, err := xz.NewReader(tf)
if err != nil {
return fmt.Errorf("failed to create xz reader: %w", err)
}
tr = tar.NewReader(xzStream)
default:
_, err := tf.Seek(0, io.SeekStart) // Seek to start for tar reading
if err != nil {
return fmt.Errorf("failed to seek to start: %w", err)
}
tr = tar.NewReader(tf)
}

for {
Expand Down Expand Up @@ -84,6 +96,7 @@ func extractTar(ctx context.Context, d string, f string) error {
}

if _, err := io.Copy(f, io.LimitReader(tr, maxBytes)); err != nil {
f.Close()
return fmt.Errorf("failed to copy file: %w", err)
}

Expand Down Expand Up @@ -169,23 +182,25 @@ func extractZip(ctx context.Context, d string, f string) error {

open, err := file.Open()
if err != nil {
open.Close()
return fmt.Errorf("failed to open file in zip: %w", err)
}

err = os.MkdirAll(filepath.Dir(name), 0o755)
if err != nil {
open.Close()
return fmt.Errorf("failed to create directory: %w", err)
}

mode := file.Mode() | 0o200
create, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode)
if err != nil {
create.Close()
open.Close()
return fmt.Errorf("failed to create file: %w", err)
}

if _, err = io.Copy(create, io.LimitReader(open, maxBytes)); err != nil {
open.Close()
create.Close()
return fmt.Errorf("failed to copy file: %w", err)
}

Expand All @@ -199,13 +214,14 @@ func extractNestedArchive(
ctx context.Context,
d string,
f string,
extracted map[string]bool,
extracted *sync.Map,
) error {
isArchive := false
ext := getExt(f)
if _, ok := archiveMap[ext]; ok {
isArchive = true
}
//nolint:nestif // ignore complexity of 8
if isArchive {
// Ensure the file was extracted and exists
fullPath := filepath.Join(d, f)
Expand All @@ -222,13 +238,27 @@ func extractNestedArchive(
return fmt.Errorf("extract nested archive: %w", err)
}
// Mark the file as extracted
extracted[f] = true
extracted.Store(f, true)

// Remove the nested archive file
// This is done to prevent the file from being scanned
if err := os.Remove(fullPath); err != nil {
return fmt.Errorf("failed to remove file: %w", err)
}

// Check if there are any newly extracted files that are also archives
files, err := os.ReadDir(d)
if err != nil {
return fmt.Errorf("failed to read directory after extraction: %w", err)
}
for _, file := range files {
relPath := filepath.Join(d, file.Name())
if _, isExtracted := extracted.Load(relPath); !isExtracted {
if err := extractNestedArchive(ctx, d, file.Name(), extracted); err != nil {
return fmt.Errorf("failed to extract nested archive %s: %w", file.Name(), err)
}
}
}
}
return nil
}
Expand All @@ -253,53 +283,60 @@ func extractArchiveToTempDir(ctx context.Context, path string) (string, error) {
return "", fmt.Errorf("failed to extract %s: %w", path, err)
}

extractedFiles := make(map[string]bool)
var extractedFiles sync.Map
files, err := os.ReadDir(tmpDir)
if err != nil {
return "", fmt.Errorf("failed to read files in directory %s: %w", tmpDir, err)
}
for _, file := range files {
extractedFiles[filepath.Join(tmpDir, file.Name())] = false
extractedFiles.Store(filepath.Join(tmpDir, file.Name()), false)
}

for file := range extractedFiles {
ext := getExt(file)
info, err := os.Stat(file)
if err != nil {
return "", fmt.Errorf("failed to stat file %s: %w", file, err)
}
switch mode := info.Mode(); {
case mode.IsDir():
err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(tmpDir, path)
if err != nil {
return fmt.Errorf("filepath.Rel: %w", err)
}
if !d.IsDir() {
if err := extractNestedArchive(ctx, tmpDir, rel, extractedFiles); err != nil {
return fmt.Errorf("failed to extract nested archive %s: %w", rel, err)
extractedFiles.Range(func(key, _ any) bool {
//nolint: nestif // ignoring complexity of 11
if file, ok := key.(string); ok {
ext := getExt(file)
info, err := os.Stat(file)
if err != nil {
return false
}
switch mode := info.Mode(); {
case mode.IsDir():
err = filepath.WalkDir(file, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(tmpDir, path)
if err != nil {
return fmt.Errorf("filepath.Rel: %w", err)
}
if !d.IsDir() {
if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil {
return fmt.Errorf("failed to extract nested archive %s: %w", rel, err)
}
}
}

return nil
})
return tmpDir, err
case mode.IsRegular():
if _, ok := archiveMap[ext]; ok {
rel, err := filepath.Rel(tmpDir, file)
return nil
})
if err != nil {
return "", fmt.Errorf("filepath.Rel: %w", err)
return false
}
if err := extractNestedArchive(ctx, tmpDir, rel, extractedFiles); err != nil {
return "", fmt.Errorf("extract nested archive %s: %w", rel, err)
return true
case mode.IsRegular():
if _, ok := archiveMap[ext]; ok {
rel, err := filepath.Rel(tmpDir, file)
if err != nil {
return false
}
if err := extractNestedArchive(ctx, tmpDir, rel, &extractedFiles); err != nil {
return false
}
}
return true
}
return tmpDir, nil
}
}
return true
})

return tmpDir, nil
}
Expand Down
Loading

0 comments on commit 94aec80

Please sign in to comment.