filesystem: scan directory entries concurrently

Init now stores its concurrency argument on the Source, and scanDir hands each
scanFile call to an errgroup worker pool limited to that value (clamped to at
least 1, because a limit of zero would block every Go call). The test drains
the chunks channel and asserts that exactly one chunk for "filesystem.go"
matches the expected metadata.

Blob IDs on the index lines are those of the original submission.

diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go
index ca4edd947bd3..1d3d0cc8b7d7 100644
--- a/pkg/sources/filesystem/filesystem.go
+++ b/pkg/sources/filesystem/filesystem.go
@@ -10,6 +10,7 @@ import (
 	"github.com/go-errors/errors"
 	"github.com/go-logr/logr"
 	diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
+	"golang.org/x/sync/errgroup"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/types/known/anypb"
 
@@ -26,13 +27,14 @@ import (
 const SourceType = sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM
 
 type Source struct {
-	name     string
-	sourceId sources.SourceID
-	jobId    sources.JobID
-	verify   bool
-	paths    []string
-	log      logr.Logger
-	filter   *common.Filter
+	name        string
+	sourceId    sources.SourceID
+	jobId       sources.JobID
+	concurrency int
+	verify      bool
+	paths       []string
+	log         logr.Logger
+	filter      *common.Filter
 	sources.Progress
 	sources.CommonSourceUnitUnmarshaller
 }
@@ -57,9 +59,10 @@ func (s *Source) JobID() sources.JobID {
 }
 
 // Init returns an initialized Filesystem source.
-func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, _ int) error {
+func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
 	s.log = aCtx.Logger()
 
+	s.concurrency = concurrency
 	s.name = name
 	s.sourceId = sourceId
 	s.jobId = jobId
@@ -102,16 +105,28 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
 			err = s.scanFile(ctx, cleanPath, chunksChan)
 		}
 
-		if err != nil && err != io.EOF {
+		if err != nil && !errors.Is(err, io.EOF) {
 			logger.Info("error scanning filesystem", "error", err)
 		}
 	}
+
 	return nil
 }
 
 func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
+	// SetLimit(0) would make every Go call block forever, so fall back to
+	// scanning one file at a time when Init was given no usable concurrency.
+	limit := s.concurrency
+	if limit < 1 {
+		limit = 1
+	}
+	workerPool := new(errgroup.Group)
+	workerPool.SetLimit(limit)
+	defer func() { _ = workerPool.Wait() }()
+
 	return fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error {
 		if err != nil {
+			ctx.Logger().Error(err, "error walking directory")
 			return nil
 		}
 		fullPath := filepath.Join(path, relativePath)
@@ -126,9 +141,13 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
 			return nil
 		}
 
-		if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
-			ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
-		}
+		workerPool.Go(func() error {
+			if err := s.scanFile(ctx, fullPath, chunksChan); err != nil {
+				ctx.Logger().Error(err, "error scanning file", "path", fullPath)
+			}
+			return nil
+		})
+
 		return nil
 	})
 }
diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go
index 0dccb3ee5393..4a846a6737cc 100644
--- a/pkg/sources/filesystem/filesystem_test.go
+++ b/pkg/sources/filesystem/filesystem_test.go
@@ -53,6 +53,7 @@ func TestSource_Scan(t *testing.T) {
 			wantErr: false,
 		},
 	}
+
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			s := Source{}
@@ -71,16 +72,23 @@ func TestSource_Scan(t *testing.T) {
 			// TODO: this is kind of bad, if it errors right away we don't see it as a test failure.
 			// Debugging this usually requires setting a breakpoint on L78 and running test w/ debug.
 			go func() {
+				defer close(chunksCh)
 				err = s.Chunks(ctx, chunksCh)
 				if (err != nil) != tt.wantErr {
 					t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
 					return
 				}
 			}()
-			gotChunk := <-chunksCh
-			if diff := pretty.Compare(gotChunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
-				t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
+			var counter int
+			for chunk := range chunksCh {
+				if chunk.SourceMetadata.GetFilesystem().GetFile() == "filesystem.go" {
+					counter++
+					if diff := pretty.Compare(chunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
+						t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
+					}
+				}
 			}
+			assert.Equal(t, 1, counter)
 		})
 	}
 }