Skip to content

Commit

Permalink
[feat] - concurently scan the filesystem source (#2364)
Browse files Browse the repository at this point in the history
* concurently scan the filesystem source

Co-authored-by: Miccah Castorina <[email protected]>

* fix test

* update test

* remove return

* use error not info

* address comment

---------

Co-authored-by: Miccah Castorina <[email protected]>
  • Loading branch information
ahrav and mcastorina authored Feb 3, 2024
1 parent 27b30e6 commit a22874f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
37 changes: 25 additions & 12 deletions pkg/sources/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-errors/errors"
"github.com/go-logr/logr"
diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

Expand All @@ -26,13 +27,14 @@ import (
const SourceType = sourcespb.SourceType_SOURCE_TYPE_FILESYSTEM

type Source struct {
name string
sourceId sources.SourceID
jobId sources.JobID
verify bool
paths []string
log logr.Logger
filter *common.Filter
name string
sourceId sources.SourceID
jobId sources.JobID
concurrency int
verify bool
paths []string
log logr.Logger
filter *common.Filter
sources.Progress
sources.CommonSourceUnitUnmarshaller
}
Expand All @@ -57,9 +59,10 @@ func (s *Source) JobID() sources.JobID {
}

// Init returns an initialized Filesystem source.
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, _ int) error {
func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error {
s.log = aCtx.Logger()

s.concurrency = concurrency
s.name = name
s.sourceId = sourceId
s.jobId = jobId
Expand Down Expand Up @@ -102,16 +105,22 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
err = s.scanFile(ctx, cleanPath, chunksChan)
}

if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logger.Info("error scanning filesystem", "error", err)
}
}

return nil
}

func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
workerPool := new(errgroup.Group)
workerPool.SetLimit(s.concurrency)
defer func() { _ = workerPool.Wait() }()

return fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error {
if err != nil {
ctx.Logger().Error(err, "error walking directory")
return nil
}
fullPath := filepath.Join(path, relativePath)
Expand All @@ -126,9 +135,13 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
return nil
}

if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Info("error scanning file", "path", fullPath, "error", err)
}
workerPool.Go(func() error {
if err = s.scanFile(ctx, fullPath, chunksChan); err != nil {
ctx.Logger().Error(err, "error scanning file", "path", fullPath, "error", err)
}
return nil
})

return nil
})
}
Expand Down
14 changes: 11 additions & 3 deletions pkg/sources/filesystem/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func TestSource_Scan(t *testing.T) {
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := Source{}
Expand All @@ -71,16 +72,23 @@ func TestSource_Scan(t *testing.T) {
// TODO: this is kind of bad, if it errors right away we don't see it as a test failure.
// Debugging this usually requires setting a breakpoint on L78 and running test w/ debug.
go func() {
defer close(chunksCh)
err = s.Chunks(ctx, chunksCh)
if (err != nil) != tt.wantErr {
t.Errorf("Source.Chunks() error = %v, wantErr %v", err, tt.wantErr)
return
}
}()
gotChunk := <-chunksCh
if diff := pretty.Compare(gotChunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
var counter int
for chunk := range chunksCh {
if chunk.SourceMetadata.GetFilesystem().GetFile() == "filesystem.go" {
counter++
if diff := pretty.Compare(chunk.SourceMetadata, tt.wantSourceMetadata); diff != "" {
t.Errorf("Source.Chunks() %s diff: (-got +want)\n%s", tt.name, diff)
}
}
}
assert.Equal(t, 1, counter)
})
}
}
Expand Down

0 comments on commit a22874f

Please sign in to comment.