Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify package addition and vulnerability scanning #6579

Merged
merged 4 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions pkg/scanner/langpkg/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ var (
)

type Scanner interface {
Packages(target types.ScanTarget, options types.ScanOptions) types.Results
Scan(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, error)
}

Expand All @@ -34,24 +33,7 @@ func NewScanner() Scanner {
return &scanner{}
}

func (s *scanner) Packages(target types.ScanTarget, _ types.ScanOptions) types.Results {
var results types.Results
for _, app := range target.Applications {
if len(app.Packages) == 0 {
continue
}

results = append(results, types.Result{
Target: targetName(app.Type, app.FilePath),
Class: types.ClassLangPkg,
Type: app.Type,
Packages: app.Packages,
})
}
return results
}

func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.ScanOptions) (types.Results, error) {
func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.ScanOptions) (types.Results, error) {
apps := target.Applications
log.Info("Number of language-specific files", log.Int("num", len(apps)))
if len(apps) == 0 {
Expand All @@ -66,34 +48,53 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.Sca
}

ctx = log.WithContextPrefix(ctx, string(app.Type))
result := types.Result{
Target: targetName(app.Type, app.FilePath),
Class: types.ClassLangPkg,
Type: app.Type,
}

// Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok {
log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{}
if opts.ListAllPackages {
sort.Sort(app.Packages)
result.Packages = app.Packages
}

log.DebugContext(ctx, "Scanning packages from the file", log.String("file_path", app.FilePath))
vulns, err := library.Detect(ctx, app.Type, app.Packages)
if err != nil {
return nil, xerrors.Errorf("failed vulnerability detection of packages: %w", err)
} else if len(vulns) == 0 {
continue
if opts.Scanners.Enabled(types.VulnerabilityScanner) {
var err error
result.Vulnerabilities, err = s.scanVulnerabilities(ctx, app, printedTypes)
if err != nil {
return nil, err
}
}

results = append(results, types.Result{
Target: targetName(app.Type, app.FilePath),
Vulnerabilities: vulns,
Class: types.ClassLangPkg,
Type: app.Type,
})
if len(result.Packages) == 0 && len(result.Vulnerabilities) == 0 {
continue
}
results = append(results, result)
}
sort.Slice(results, func(i, j int) bool {
return results[i].Target < results[j].Target
})
return results, nil
}

func (s *scanner) scanVulnerabilities(ctx context.Context, app ftypes.Application, printedTypes map[ftypes.LangType]struct{}) (
[]types.DetectedVulnerability, error) {

// Prevent the same log messages from being displayed many times for the same type.
if _, ok := printedTypes[app.Type]; !ok {
log.InfoContext(ctx, "Detecting vulnerabilities...")
printedTypes[app.Type] = struct{}{}
}

log.DebugContext(ctx, "Scanning packages for vulnerabilities", log.String("file_path", app.FilePath))
vulns, err := library.Detect(ctx, app.Type, app.Packages)
if err != nil {
return nil, xerrors.Errorf("failed vulnerability detection of libraries: %w", err)
}
return vulns, err
}

func targetName(appType ftypes.LangType, filePath string) string {
if t, ok := PkgTargets[appType]; ok && filePath == "" {
// When the file path is empty, we will overwrite it with the pre-defined value.
Expand Down
66 changes: 18 additions & 48 deletions pkg/scanner/local/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors"

dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
ospkgDetector "github.com/aquasecurity/trivy/pkg/detector/ospkg"
"github.com/aquasecurity/trivy/pkg/fanal/analyzer"
"github.com/aquasecurity/trivy/pkg/fanal/applier"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
Expand Down Expand Up @@ -105,39 +106,19 @@ func (s Scanner) Scan(ctx context.Context, targetName, artifactKey string, blobK
}

func (s Scanner) ScanTarget(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Results, ftypes.OS, error) {
var eosl bool
var results, pkgResults types.Results
var err error
var results types.Results

// By default, we need to remove dev dependencies from the result
// IncludeDevDeps option allows you not to remove them
excludeDevDeps(target.Applications, options.IncludeDevDeps)

// Fill OS packages and language-specific packages
if options.ListAllPackages {
if res := s.osPkgScanner.Packages(target, options); len(res.Packages) != 0 {
pkgResults = append(pkgResults, res)
}
pkgResults = append(pkgResults, s.langPkgScanner.Packages(target, options)...)
}

// Scan packages for vulnerabilities
if options.Scanners.Enabled(types.VulnerabilityScanner) {
var vulnResults types.Results
vulnResults, eosl, err = s.scanVulnerabilities(ctx, target, options)
if err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("failed to detect vulnerabilities: %w", err)
}
target.OS.Eosl = eosl

// Merge package results into vulnerability results
mergedResults := s.fillPkgsInVulns(pkgResults, vulnResults)

results = append(results, mergedResults...)
} else {
// If vulnerability scanning is not enabled, it just adds package results.
results = append(results, pkgResults...)
// Add packages if needed and scan packages for vulnerabilities
vulnResults, eosl, err := s.scanVulnerabilities(ctx, target, options)
if err != nil {
return nil, ftypes.OS{}, xerrors.Errorf("failed to detect vulnerabilities: %w", err)
}
target.OS.Eosl = eosl
results = append(results, vulnResults...)

// Store misconfigurations
results = append(results, s.misconfsToResults(target.Misconfigurations, options)...)
Expand Down Expand Up @@ -172,17 +153,24 @@ func (s Scanner) ScanTarget(ctx context.Context, target types.ScanTarget, option

func (s Scanner) scanVulnerabilities(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (
types.Results, bool, error) {
if !options.ListAllPackages && !options.Scanners.Enabled(types.VulnerabilityScanner) {
return nil, false, nil
}

var eosl bool
var results types.Results

if slices.Contains(options.VulnType, types.VulnTypeOS) {
vuln, detectedEOSL, err := s.osPkgScanner.Scan(ctx, target, options)
if err != nil {
switch {
case errors.Is(err, ospkgDetector.ErrUnsupportedOS):
// do nothing
case err != nil:
return nil, false, xerrors.Errorf("unable to scan OS packages: %w", err)
} else if vuln.Target != "" {
case vuln.Target != "":
results = append(results, vuln)
eosl = detectedEOSL
}
eosl = detectedEOSL
}

if slices.Contains(options.VulnType, types.VulnTypeLibrary) {
Expand All @@ -196,24 +184,6 @@ func (s Scanner) scanVulnerabilities(ctx context.Context, target types.ScanTarge
return results, eosl, nil
}

func (s Scanner) fillPkgsInVulns(pkgResults, vulnResults types.Results) types.Results {
var results types.Results
if len(pkgResults) == 0 { // '--list-all-pkgs' == false or packages not found
return vulnResults
}
for _, result := range pkgResults {
if r, found := lo.Find(vulnResults, func(r types.Result) bool {
return r.Class == result.Class && r.Target == result.Target && r.Type == result.Type
}); found {
r.Packages = result.Packages
results = append(results, r)
} else { // when package result has no vulnerabilities we still need to add it to result(for 'list-all-pkgs')
results = append(results, result)
}
}
return results
}

func (s Scanner) misconfsToResults(misconfs []ftypes.Misconfiguration, options types.ScanOptions) types.Results {
if !ShouldScanMisconfigOrRbac(options.Scanners) &&
!options.ImageConfigScanners.Enabled(types.MisconfigScanner) {
Expand Down
50 changes: 22 additions & 28 deletions pkg/scanner/ospkg/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ospkg

import (
"context"
"errors"
"fmt"
"sort"
"time"
Expand All @@ -15,7 +14,6 @@ import (
)

type Scanner interface {
Packages(target types.ScanTarget, options types.ScanOptions) types.Result
Scan(ctx context.Context, target types.ScanTarget, options types.ScanOptions) (types.Result, bool, error)
}

Expand All @@ -25,21 +23,7 @@ func NewScanner() Scanner {
return &scanner{}
}

func (s *scanner) Packages(target types.ScanTarget, _ types.ScanOptions) types.Result {
if len(target.Packages) == 0 || !target.OS.Detected() {
return types.Result{}
}

sort.Sort(target.Packages)
return types.Result{
Target: fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name),
Class: types.ClassOSPkg,
Type: target.OS.Family,
Packages: target.Packages,
}
}

func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.ScanOptions) (types.Result, bool, error) {
func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, opts types.ScanOptions) (types.Result, bool, error) {
if !target.OS.Detected() {
log.Debug("Detected OS: unknown")
return types.Result{}, false, nil
Expand All @@ -52,19 +36,29 @@ func (s *scanner) Scan(ctx context.Context, target types.ScanTarget, _ types.Sca
target.OS.Name += "-ESM"
}

result := types.Result{
Target: fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name),
Class: types.ClassOSPkg,
Type: target.OS.Family,
}

if opts.ListAllPackages {
sort.Sort(target.Packages)
result.Packages = target.Packages
}

if !opts.Scanners.Enabled(types.VulnerabilityScanner) {
// Return packages only
return result, false, nil
}

vulns, eosl, err := ospkgDetector.Detect(ctx, "", target.OS.Family, target.OS.Name, target.Repository, time.Time{},
target.Packages)
if errors.Is(err, ospkgDetector.ErrUnsupportedOS) {
return types.Result{}, false, nil
} else if err != nil {
return types.Result{}, false, xerrors.Errorf("failed vulnerability detection of OS packages: %w", err)
if err != nil {
// Return a result for those who want to override the error handling.
return result, false, xerrors.Errorf("failed vulnerability detection of OS packages: %w", err)
}
result.Vulnerabilities = vulns

artifactDetail := fmt.Sprintf("%s (%s %s)", target.Name, target.OS.Family, target.OS.Name)
return types.Result{
Target: artifactDetail,
Vulnerabilities: vulns,
Class: types.ClassOSPkg,
Type: target.OS.Family,
}, eosl, nil
return result, eosl, nil
}