Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[scan-9] Update enumeration logic #3626

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/engine/circleci.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ func (e *Engine) ScanCircleCI(ctx context.Context, token string) (sources.JobPro
if err := circleSource.Init(ctx, "trufflehog - Circle CI", jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, circleSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, circleSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ func (e *Engine) ScanDocker(ctx context.Context, c sources.DockerConfig) (source
if err := dockerSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, dockerSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, dockerSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ func (e *Engine) ScanElasticsearch(ctx context.Context, c sources.ElasticsearchC
if err := elasticsearchSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, elasticsearchSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, elasticsearchSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ func (e *Engine) ScanFileSystem(ctx context.Context, c sources.FilesystemConfig)
if err := fileSystemSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, fileSystemSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, fileSystemSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (e *Engine) ScanGCS(ctx context.Context, c sources.GCSConfig) (sources.JobP
if err := gcsSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, int(c.Concurrency)); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, gcsSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gcsSource)
}

func isAuthValid(ctx context.Context, c sources.GCSConfig, connection *sourcespb.GCS) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) (sources.JobP
return sources.JobProgressRef{}, err
}

return e.sourceManager.Run(ctx, sourceName, gitSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gitSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) (source
return sources.JobProgressRef{}, err
}
githubSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, githubSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, githubSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/github_experimental.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ func (e *Engine) ScanGitHubExperimental(ctx context.Context, c sources.GitHubExp
return sources.JobProgressRef{}, err
}
githubExperimentalSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, githubExperimentalSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, githubExperimentalSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ func (e *Engine) ScanGitLab(ctx context.Context, c sources.GitlabConfig) (source
return sources.JobProgressRef{}, err
}
gitlabSource.WithScanOptions(scanOptions)
return e.sourceManager.Run(ctx, sourceName, gitlabSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, gitlabSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ func (e *Engine) ScanHuggingface(ctx context.Context, c HuggingfaceConfig) (sour
if err := huggingfaceSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, huggingfaceSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, huggingfaceSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/jenkins.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ func (e *Engine) ScanJenkins(ctx context.Context, jenkinsConfig JenkinsConfig) (
if err := jenkinsSource.Init(ctx, "trufflehog - Jenkins", jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, jenkinsSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, jenkinsSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/postman.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ func (e *Engine) ScanPostman(ctx context.Context, c sources.PostmanConfig) (sour
if err := postmanSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, c.Concurrency); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, postmanSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, postmanSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ func (e *Engine) ScanS3(ctx context.Context, c sources.S3Config) (sources.JobPro
if err := s3Source.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, s3Source)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, s3Source)
}
2 changes: 1 addition & 1 deletion pkg/engine/syslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ func (e *Engine) ScanSyslog(ctx context.Context, c sources.SyslogConfig) (source
}
syslogSource.InjectConnection(connection)

return e.sourceManager.Run(ctx, sourceName, syslogSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, syslogSource)
}
2 changes: 1 addition & 1 deletion pkg/engine/travisci.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ func (e *Engine) ScanTravisCI(ctx context.Context, token string) (sources.JobPro
if err := travisSource.Init(ctx, sourceName, jobID, sourceID, true, &conn, runtime.NumCPU()); err != nil {
return sources.JobProgressRef{}, err
}
return e.sourceManager.Run(ctx, sourceName, travisSource)
return e.sourceManager.EnumerateAndScan(ctx, sourceName, travisSource)
}
122 changes: 120 additions & 2 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ func (s *SourceManager) GetIDs(ctx context.Context, sourceName string, kind sour
return s.api.GetIDs(ctx, sourceName, kind)
}

// Run blocks until a resource is available to run the source, then
// EnumerateAndScan blocks until a resource is available to run the source, then
// asynchronously runs it. Error information is stored and accessible via the
// JobProgressRef as it becomes available.
func (s *SourceManager) Run(ctx context.Context, sourceName string, source Source, targets ...ChunkingTarget) (JobProgressRef, error) {
func (s *SourceManager) EnumerateAndScan(ctx context.Context, sourceName string, source Source, targets ...ChunkingTarget) (JobProgressRef, error) {
sourceID, jobID := source.SourceID(), source.JobID()
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx); err != nil {
Expand Down Expand Up @@ -169,6 +169,54 @@ func (s *SourceManager) Run(ctx context.Context, sourceName string, source Sourc
return progress.Ref(), nil
}

// Enumerate runs the preflight checks, acquires one slot from the manager's
// semaphore, and then enumerates the source asynchronously. Units and errors
// reported during enumeration are recorded on the job's progress and forwarded
// to the passed-in reporter. Error information is stored and accessible via
// the returned JobProgressRef as it becomes available.
func (s *SourceManager) Enumerate(ctx context.Context, sourceName string, source Source, reporter UnitReporter) (JobProgressRef, error) {
	sourceID := source.SourceID()
	jobID := source.JobID()

	// Preflight checks happen before we wait on the pool.
	if preflightErr := s.preflightChecks(ctx); preflightErr != nil {
		ref := JobProgressRef{
			SourceName: sourceName,
			SourceID:   sourceID,
			JobID:      jobID,
		}
		return ref, preflightErr
	}

	// Build the JobProgress used to track this enumeration.
	pool := s.sem
	ctx, cancel := context.WithCancelCause(ctx)
	progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
	if acquireErr := pool.Acquire(ctx, 1); acquireErr != nil {
		// Context cancelled.
		fatal := Fatal{acquireErr}
		progress.ReportError(fatal)
		return progress.Ref(), fatal
	}

	// Route everything the source reports through the progress tracker
	// before it reaches the caller's reporter.
	tracked := baseUnitReporter{
		child:    reporter,
		progress: progress,
	}

	s.wg.Add(1)
	go func() {
		// Deferred calls run last-in-first-out, so Finish is called only
		// after the semaphore slot has been released.
		defer progress.Finish()
		defer pool.Release(1)
		defer s.wg.Done()
		workerCtx := context.WithValues(ctx,
			"source_manager_worker_id", common.RandomID(5),
		)
		defer common.Recover(workerCtx)
		defer cancel(nil)

		enumErr := s.enumerate(workerCtx, source, progress, tracked)
		if enumErr == nil {
			return
		}
		// Non-blocking send: the error is dropped when firstErr cannot
		// accept it right now.
		select {
		case s.firstErr <- enumErr:
		default:
		}
	}()
	return progress.Ref(), nil
}

// Chunks returns the read only channel of all the chunks produced by all of
// the sources managed by this manager.
func (s *SourceManager) Chunks() <-chan *Chunk {
Expand Down Expand Up @@ -286,6 +334,75 @@ func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgr
return s.runWithoutUnits(ctx, source, report, targets...)
}

// enumerate is a helper method to enumerate a Source. It records the start
// and end times on the report, reports the context's cancellation cause (if
// any) as a fatal error on exit, and fills in the job_id, source_id,
// source_name and source_type context values when they are missing.
//
// It returns an error when the source does not implement SourceUnitEnumerator
// or when source units are not enabled on the manager; otherwise it returns
// the result of enumerateWithUnits.
func (s *SourceManager) enumerate(ctx context.Context, source Source, report *JobProgress, reporter UnitReporter) error {
	report.Start(time.Now())
	defer func() { report.End(time.Now()) }()

	defer func() {
		if err := context.Cause(ctx); err != nil {
			report.ReportError(Fatal{err})
		}
	}()

	report.TrackProgress(source.GetProgress())

	// Value yields nil, not "", for a key that was never set, so comparing
	// against "" alone would never populate a missing identifier. Treat both
	// nil and the empty string as "unset".
	unset := func(key string) bool {
		v := ctx.Value(key)
		return v == nil || v == ""
	}
	if unset("job_id") {
		ctx = context.WithValue(ctx, "job_id", report.JobID)
	}
	if unset("source_id") {
		ctx = context.WithValue(ctx, "source_id", report.SourceID)
	}
	if unset("source_name") {
		ctx = context.WithValue(ctx, "source_name", report.SourceName)
	}
	if unset("source_type") {
		ctx = context.WithValue(ctx, "source_type", source.Type().String())
	}

	// Check for the preferred method of tracking source units.
	canUseSourceUnits := s.useSourceUnitsFunc != nil
	if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
		ctx.Logger().Info("running source",
			"with_units", true)
		return s.enumerateWithUnits(ctx, enumChunker, report, reporter)
	}
	return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())
}

// enumerateWithUnits enumerates a Source that is also a SourceUnitEnumerator,
// bracketing the call with the report's enumeration start/end timestamps.
// A failure from the source is recorded on the report and returned wrapped
// in Fatal; on success the result is nil.
func (s *SourceManager) enumerateWithUnits(ctx context.Context, source SourceUnitEnumerator, report *JobProgress, reporter UnitReporter) error {
	// TODO: Catch panics and add to report.
	report.StartEnumerating(time.Now())
	defer func() { report.EndEnumerating(time.Now()) }()

	ctx.Logger().V(2).Info("enumerating source with units")

	// Produce units.
	enumErr := source.Enumerate(ctx, reporter)
	if enumErr == nil {
		return nil
	}

	fatal := Fatal{enumErr}
	report.ReportError(fatal)
	return fatal
}

// runWithoutUnits is a helper method to run a Source. It has coarse-grained
// job reporting.
func (s *SourceManager) runWithoutUnits(ctx context.Context, source Source, report *JobProgress, targets ...ChunkingTarget) error {
Expand All @@ -302,6 +419,7 @@ func (s *SourceManager) runWithoutUnits(ctx context.Context, source Source, repo
s.outputChunks <- chunk
}
}()

// Don't return from this function until the goroutine has finished
// outputting chunks to the downstream channel. Closing the channel
// will stop the goroutine, so that needs to happen first in the defer
Expand Down
24 changes: 12 additions & 12 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestSourceManagerRun(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
for i := 0; i < 3; i++ {
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
<-ref.Done()
assert.NoError(t, err)
assert.NoError(t, ref.Snapshot().FatalError())
Expand All @@ -132,7 +132,7 @@ func TestSourceManagerWait(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
// Asynchronously run the source.
_, err = mgr.Run(context.Background(), "dummy", source)
_, err = mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
// Read the 1 chunk we're expecting so Waiting completes.
<-mgr.Chunks()
Expand All @@ -141,15 +141,15 @@ func TestSourceManagerWait(t *testing.T) {
// Run should return an error now.
_, err = buildDummy(&counterChunker{count: 1})
assert.NoError(t, err)
_, err = mgr.Run(context.Background(), "dummy", source)
_, err = mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.Error(t, err)
}

func TestSourceManagerError(t *testing.T) {
mgr := NewManager()
source, err := buildDummy(errorChunker{fmt.Errorf("oops")})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.Error(t, ref.Snapshot().FatalError())
Expand All @@ -165,7 +165,7 @@ func TestSourceManagerReport(t *testing.T) {
mgr := NewManager(opts...)
source, err := buildDummy(&counterChunker{count: 4})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.Equal(t, 0, len(ref.Snapshot().Errors))
Expand Down Expand Up @@ -230,7 +230,7 @@ func TestSourceManagerNonFatalError(t *testing.T) {
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
report := ref.Snapshot()
Expand All @@ -247,7 +247,7 @@ func TestSourceManagerContextCancelled(t *testing.T) {
assert.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
ref, err := mgr.Run(ctx, "dummy", source)
ref, err := mgr.EnumerateAndScan(ctx, "dummy", source)
assert.NoError(t, err)

cancel()
Expand Down Expand Up @@ -291,7 +291,7 @@ func TestSourceManagerCancelRun(t *testing.T) {
}})
assert.NoError(t, err)

ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

cancelErr := fmt.Errorf("abort! abort!")
Expand All @@ -313,7 +313,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, 1337, mgr.AvailableCapacity())
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

<-start // Wait for start signal.
Expand All @@ -338,7 +338,7 @@ func TestSourceManagerUnitHook(t *testing.T) {
)
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
Expand Down Expand Up @@ -399,7 +399,7 @@ func TestSourceManagerUnitHookBackPressure(t *testing.T) {
)
source, err := buildDummy(&unitChunker{input})
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)

var metrics []UnitMetrics
Expand Down Expand Up @@ -428,7 +428,7 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) {
source, err := buildDummy(&counterChunker{count: 5})
assert.NoError(t, err)

ref, err := mgr.Run(context.Background(), "dummy", source)
ref, err := mgr.EnumerateAndScan(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())
Expand Down
24 changes: 24 additions & 0 deletions pkg/sources/sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ type SourceUnitEnumerator interface {
Enumerate(ctx context.Context, reporter UnitReporter) error
}

// baseUnitReporter is a helper struct that implements the UnitReporter interface
// and includes a JobProgress reference.
type baseUnitReporter struct {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes this a "base" unit reporter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the passed-in reporter gets wrapped into this, and so this reporter becomes the parent. 'base' made sense in that respect.

child UnitReporter
progress *JobProgress
}

// UnitOk records the unit on the job progress and then forwards it to the
// wrapped reporter, if there is one. Without a wrapped reporter it returns nil.
func (b baseUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
	b.progress.ReportUnit(unit)
	if b.child == nil {
		return nil
	}
	return b.child.UnitOk(ctx, unit)
}

// UnitErr records the error on the job progress and then forwards it to the
// wrapped reporter, if there is one. Without a wrapped reporter it returns nil.
func (b baseUnitReporter) UnitErr(ctx context.Context, err error) error {
	b.progress.ReportError(err)
	if b.child == nil {
		return nil
	}
	return b.child.UnitErr(ctx, err)
}


// UnitReporter defines the interface a source will use to report whether a
// unit was found during enumeration. Either method may be called any number of
// times. Implementors of this interface should allow for concurrent calls.
Expand Down
Loading