From e4956615ad9a60a19b3bdb8fd40d91f65222f5ff Mon Sep 17 00:00:00 2001 From: ahrav Date: Fri, 22 Nov 2024 13:33:34 -0800 Subject: [PATCH] [feat] - Support S3 Source Resumption (#3570) * add config option for s3 resumption * updates * initial progress tracking logic * more testing * revert s3 source file * UpdateScanProgress tests * adjust * updates * invert * updates * updates * fix * update * adjust test * fix * remove progress tracking * cleanup * cleanup * remove dupe * remove context cancellation logic * fix comment format * make resumption logic more clear * rename * fixes * update * add edge case test * remove dupe mu * add comment * fix comment --- pkg/sources/s3/s3.go | 215 +++++++++++++++++++++----- pkg/sources/s3/s3_integration_test.go | 162 ++++++++++++++++++- pkg/sources/s3/s3_test.go | 3 +- 3 files changed, 343 insertions(+), 37 deletions(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index 27c9e9b4e5e7..91970e9fd703 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -2,6 +2,7 @@ package s3 import ( "fmt" + "slices" "strings" "sync" "sync/atomic" @@ -43,8 +44,10 @@ type Source struct { jobID sources.JobID verify bool concurrency int + conn *sourcespb.S3 + + checkpointer *Checkpointer sources.Progress - conn *sourcespb.S3 errorCount *sync.Map jobPool *errgroup.Group @@ -67,7 +70,7 @@ func (s *Source) JobID() sources.JobID { return s.jobID } // Init returns an initialized AWS source func (s *Source) Init( - _ context.Context, + ctx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, @@ -90,6 +93,8 @@ func (s *Source) Init( } s.conn = &conn + s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress) + s.setMaxObjectSize(conn.GetMaxObjectSize()) if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 { @@ -173,9 +178,16 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) { return s3.New(sess), nil } -// IAM identity needs s3:ListBuckets permission +// getBucketsToScan returns a list of S3 buckets to scan. +// If the connection has a list of buckets specified, those are returned. +// Otherwise, it lists all buckets the client has access to and filters out the ignored ones. +// The list of buckets is sorted lexicographically to ensure consistent ordering, +// which allows resuming scanning from the same place if the scan is interrupted. +// +// Note: The IAM identity needs the s3:ListBuckets permission. func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) { if buckets := s.conn.GetBuckets(); len(buckets) > 0 { + slices.Sort(buckets) return buckets, nil } @@ -196,9 +208,73 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) { bucketsToScan = append(bucketsToScan, name) } } + slices.Sort(bucketsToScan) + return bucketsToScan, nil } +// pageMetadata contains metadata about a single page of S3 objects being scanned. +type pageMetadata struct { + bucket string // The name of the S3 bucket being scanned + pageNumber int // Current page number in the pagination sequence + client *s3.S3 // AWS S3 client configured for the appropriate region + page *s3.ListObjectsV2Output // Contains the list of S3 objects in this page +} + +// processingState tracks the state of concurrent S3 object processing. +type processingState struct { + errorCount *sync.Map // Thread-safe map tracking errors per prefix + objectCount *uint64 // Total number of objects processed +} + +// resumePosition tracks where to restart scanning S3 buckets and objects after an interruption. 
+// It encapsulates all the information needed to resume a scan from its last known position. +type resumePosition struct { + bucket string // The bucket name we were processing + index int // Index in the buckets slice where we should resume + startAfter string // The last processed object key within the bucket + isNewScan bool // True if we're starting a fresh scan + exactMatch bool // True if we found the exact bucket we were previously processing +} + +// determineResumePosition calculates where to resume scanning from based on the last saved checkpoint +// and the current list of available buckets to scan. It handles several scenarios: +// +// 1. If getting the resume point fails or there is no previous bucket saved (CurrentBucket is empty), +// we start a new scan from the beginning, this is the safest option. +// +// 2. If the previous bucket exists in our current scan list (exactMatch=true), +// we resume from that exact position and use the StartAfter value +// to continue from the last processed object within that bucket. +// +// 3. If the previous bucket is not found in our current scan list (exactMatch=false), this typically means: +// - The bucket was deleted since our last scan +// - The bucket was explicitly excluded from this scan's configuration +// - The IAM role no longer has access to the bucket +// - The bucket name changed due to a configuration update +// In this case, we use binary search to find the closest position where the bucket would have been, +// allowing us to resume from the nearest available point in our sorted bucket list rather than +// restarting the entire scan. +func determineResumePosition(ctx context.Context, tracker *Checkpointer, buckets []string) resumePosition { + resumePoint, err := tracker.ResumePoint(ctx) + if err != nil { + ctx.Logger().Error(err, "failed to get resume point; starting from the beginning") + return resumePosition{isNewScan: true} + } + + if resumePoint.CurrentBucket == "" { + return resumePosition{isNewScan: true} + } + + startIdx, found := slices.BinarySearch(buckets, resumePoint.CurrentBucket) + return resumePosition{ + bucket: resumePoint.CurrentBucket, + startAfter: resumePoint.StartAfter, + index: startIdx, + exactMatch: found, + } +} + func (s *Source) scanBuckets( ctx context.Context, client *s3.S3, @@ -206,22 +282,48 @@ func (s *Source) scanBuckets( bucketsToScan []string, chunksChan chan *sources.Chunk, ) { - var objectCount uint64 - if role != "" { ctx = context.WithValue(ctx, "role", role) } + var objectCount uint64 - for i, bucket := range bucketsToScan { + pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan) + switch { + case pos.isNewScan: + ctx.Logger().Info("Starting new scan from beginning") + case !pos.exactMatch: + ctx.Logger().Info( + "Resume bucket no longer available, starting from closest position", + "original_bucket", pos.bucket, + "position", pos.index, + ) + default: + ctx.Logger().Info( + "Resuming scan from previous scan's bucket", + "bucket", pos.bucket, + "position", pos.index, + ) + } + + bucketsToScanCount := len(bucketsToScan) + for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ { + bucket := bucketsToScan[bucketIdx] ctx := context.WithValue(ctx, "bucket", bucket) if common.IsDone(ctx) { + ctx.Logger().Error(ctx.Err(), "context done, while scanning bucket") return } - s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "") ctx.Logger().V(3).Info("Scanning bucket") + s.SetProgressComplete( + bucketIdx, + len(bucketsToScan), + 
fmt.Sprintf("Bucket: %s", bucket), + s.Progress.EncodedResumeInfo, + ) + regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket) if err != nil { ctx.Logger().Error(err, "could not get regional client for bucket") @@ -230,10 +332,33 @@ func (s *Source) scanBuckets( errorCount := sync.Map{} + input := &s3.ListObjectsV2Input{Bucket: &bucket} + if bucket == pos.bucket && pos.startAfter != "" { + input.StartAfter = &pos.startAfter + ctx.Logger().V(3).Info( + "Resuming bucket scan", + "start_after", pos.startAfter, + ) + } + + pageNumber := 1 err = regionalClient.ListObjectsV2PagesWithContext( - ctx, &s3.ListObjectsV2Input{Bucket: &bucket}, + ctx, + input, func(page *s3.ListObjectsV2Output, _ bool) bool { - s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount) + pageMetadata := pageMetadata{ + bucket: bucket, + pageNumber: pageNumber, + client: regionalClient, + page: page, + } + processingState := processingState{ + errorCount: &errorCount, + objectCount: &objectCount, + } + s.pageChunker(ctx, pageMetadata, processingState, chunksChan) + + pageNumber++ return true }) @@ -249,6 +374,7 @@ func (s *Source) scanBuckets( } } } + s.SetProgressComplete( len(bucketsToScan), len(bucketsToScan), @@ -289,29 +415,25 @@ func (s *Source) getRegionalClientForBucket( return regionalClient, nil } -// pageChunker emits chunks onto the given channel from a page +// pageChunker emits chunks onto the given channel from a page. func (s *Source) pageChunker( ctx context.Context, - client *s3.S3, + metadata pageMetadata, + state processingState, chunksChan chan *sources.Chunk, - bucket string, - page *s3.ListObjectsV2Output, - errorCount *sync.Map, - pageNumber int, - objectCount *uint64, ) { - for _, obj := range page.Contents { + s.checkpointer.Reset() // Reset the checkpointer for each PAGE + ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber) + + for objIdx, obj := range metadata.page.Contents { if obj == nil { + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for nil object") + } continue } - ctx = context.WithValues( - ctx, - "key", *obj.Key, - "bucket", bucket, - "page", pageNumber, - "size", *obj.Size, - ) + ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size) if common.IsDone(ctx) { return @@ -320,29 +442,44 @@ func (s *Source) pageChunker( // Skip GLACIER and GLACIER_IR objects. if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") { ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass) + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for glacier object") + } continue } // Ignore large files. if *obj.Size > s.maxObjectSize { ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)") + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for large file") + } continue } // File empty file. 
if *obj.Size == 0 { ctx.Logger().V(5).Info("Skipping empty file") + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for empty file") + } continue } // Skip incompatible extensions. if common.SkipFile(*obj.Key) { ctx.Logger().V(5).Info("Skipping file with incompatible extension") + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for incompatible file") + } continue } s.jobPool.Go(func() error { defer common.RecoverWithExit(ctx) + if common.IsDone(ctx) { + return ctx.Err() + } if strings.HasSuffix(*obj.Key, "/") { ctx.Logger().V(5).Info("Skipping directory") @@ -352,7 +489,7 @@ func (s *Source) pageChunker( path := strings.Split(*obj.Key, "/") prefix := strings.Join(path[:len(path)-1], "/") - nErr, ok := errorCount.Load(prefix) + nErr, ok := state.errorCount.Load(prefix) if !ok { nErr = 0 } @@ -366,8 +503,8 @@ func (s *Source) pageChunker( objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout) defer cancel() - res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{ - Bucket: &bucket, + res, err := metadata.client.GetObjectWithContext(objCtx, &s3.GetObjectInput{ + Bucket: &metadata.bucket, Key: obj.Key, }) if err != nil { @@ -382,7 +519,7 @@ func (s *Source) pageChunker( res.Body.Close() } - nErr, ok := errorCount.Load(prefix) + nErr, ok := state.errorCount.Load(prefix) if !ok { nErr = 0 } @@ -391,7 +528,7 @@ func (s *Source) pageChunker( return nil } nErr = nErr.(int) + 1 - errorCount.Store(prefix, nErr) + state.errorCount.Store(prefix, nErr) // too many consecutive errors on this page if nErr.(int) > 3 { ctx.Logger().V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix) @@ -413,9 +550,9 @@ func (s *Source) pageChunker( SourceMetadata: &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_S3{ S3: &source_metadatapb.S3{ - Bucket: bucket, + Bucket: metadata.bucket, File: sanitizer.UTF8(*obj.Key), - Link: sanitizer.UTF8(makeS3Link(bucket, *client.Config.Region, *obj.Key)), + Link: sanitizer.UTF8(makeS3Link(metadata.bucket, *metadata.client.Config.Region, *obj.Key)), Email: sanitizer.UTF8(email), Timestamp: sanitizer.UTF8(modified), }, @@ -429,14 +566,19 @@ func (s *Source) pageChunker( return nil } - atomic.AddUint64(objectCount, 1) - ctx.Logger().V(5).Info("S3 object scanned.", "object_count", objectCount) - nErr, ok = errorCount.Load(prefix) + atomic.AddUint64(state.objectCount, 1) + ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount) + nErr, ok = state.errorCount.Load(prefix) if !ok { nErr = 0 } if nErr.(int) > 0 { - errorCount.Store(prefix, 0) + state.errorCount.Store(prefix, 0) + } + + // Update progress after successful processing. + if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil { + ctx.Logger().Error(err, "could not update progress for scanned object") } return nil @@ -485,6 +627,9 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr // for each role, passing in the default S3 client, the role ARN, and the list of // buckets to scan. // +// The provided function parameter typically implements the core scanning logic +// and must handle context cancellation appropriately. +// // If no roles are configured, it will call the function with an empty role ARN. 
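The Checkpointer wired in here (NewCheckpointer, Reset, ResumePoint, and the UpdateObjectCompletion calls sprinkled through pageChunker above) is defined in a separate file that is not part of this diff. Completion is reported even for objects that are skipped (nil, GLACIER, oversized, empty, incompatible extension) because page objects are processed concurrently, so a resume key is only safe to persist once every earlier object on the page has finished. What follows is a minimal, hypothetical sketch of that contiguous low-water-mark idea; the names and behavior are assumptions for illustration, not the PR's actual Checkpointer.

// Hypothetical sketch, not the Checkpointer from this PR: track the highest
// contiguous completed index within one page so the persisted resume key never
// skips past an object that is still in flight.
package main

import (
	"fmt"
	"sync"
)

type pageCheckpoint struct {
	mu        sync.Mutex
	completed []bool // one flag per object on the page (scanned or skipped)
	lowest    int    // index of the first object not yet completed
}

func newPageCheckpoint(n int) *pageCheckpoint {
	return &pageCheckpoint{completed: make([]bool, n)}
}

// markComplete records that the object at idx is done (scanned or skipped) and
// returns the newest key that is safe to use as ListObjectsV2 StartAfter.
func (p *pageCheckpoint) markComplete(idx int, keys []string) (string, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.completed[idx] = true
	for p.lowest < len(p.completed) && p.completed[p.lowest] {
		p.lowest++
	}
	if p.lowest == 0 {
		return "", false // object 0 still in flight; resuming here could lose it
	}
	return keys[p.lowest-1], true
}

func main() {
	keys := []string{"a.txt", "b.txt", "c.txt"}
	cp := newPageCheckpoint(len(keys))

	// Completions arrive out of order because pageChunker uses a worker pool.
	fmt.Println(cp.markComplete(2, keys)) // "" false  – a.txt not done yet
	fmt.Println(cp.markComplete(0, keys)) // "a.txt" true
	fmt.Println(cp.markComplete(1, keys)) // "c.txt" true – whole page done
}

Calling checkpointer.Reset() at the top of pageChunker lines up with this model: the diff's own comment says the checkpointer is reset for each page, so per-page state such as these completion flags would start fresh on every ListObjectsV2 callback.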
func (s *Source) visitRoles( ctx context.Context, diff --git a/pkg/sources/s3/s3_integration_test.go b/pkg/sources/s3/s3_integration_test.go index 7ca89e4cca94..1801e29e20e2 100644 --- a/pkg/sources/s3/s3_integration_test.go +++ b/pkg/sources/s3/s3_integration_test.go @@ -10,11 +10,12 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/common" - "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) @@ -215,3 +216,162 @@ func TestSource_Validate(t *testing.T) { }) } } + +func TestSourceChunksNoResumption(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + s := Source{} + connection := &sourcespb.S3{ + Credential: &sourcespb.S3_Unauthenticated{}, + Buckets: []string{"integration-resumption-tests"}, + } + conn, err := anypb.New(connection) + if err != nil { + t.Fatal(err) + } + + err = s.Init(ctx, "test name", 0, 0, false, conn, 1) + chunksCh := make(chan *sources.Chunk) + go func() { + defer close(chunksCh) + err = s.Chunks(ctx, chunksCh) + assert.Nil(t, err) + }() + + wantChunkCount := 19787 + got := 0 + + for range chunksCh { + got++ + } + assert.Equal(t, wantChunkCount, got) +} + +func TestSourceChunksResumption(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + src := new(Source) + src.Progress = sources.Progress{ + Message: "Bucket: integration-resumption-tests", + EncodedResumeInfo: "{\"current_bucket\":\"integration-resumption-tests\",\"start_after\":\"test-dir/\"}", + SectionsCompleted: 0, + SectionsRemaining: 1, + } + connection := &sourcespb.S3{ + Credential: &sourcespb.S3_Unauthenticated{}, + Buckets: []string{"integration-resumption-tests"}, + EnableResumption: true, + } + conn, err := anypb.New(connection) + require.NoError(t, err) + + err = src.Init(ctx, "test name", 0, 0, false, conn, 2) + require.NoError(t, err) + + chunksCh := make(chan *sources.Chunk) + var count int + + cancelCtx, ctxCancel := context.WithCancel(ctx) + defer ctxCancel() + + go func() { + defer close(chunksCh) + err = src.Chunks(cancelCtx, chunksCh) + assert.NoError(t, err, "Should not error during scan") + }() + + for range chunksCh { + count++ + } + + // Verify that we processed all remaining data on resume. + // Also verify that we processed less than the total number of chunks for the source. 
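The resumption test above seeds Progress.EncodedResumeInfo with a JSON checkpoint before calling Init. The concrete type behind that payload lives in the checkpointer file, which is outside this diff; only the field names current_bucket and start_after are taken from the test fixture, and the struct below is a hypothetical stand-in for illustration.

// Hypothetical decoding of the resume payload used by the test fixture above.
package main

import (
	"encoding/json"
	"fmt"
)

type resumeInfo struct {
	CurrentBucket string `json:"current_bucket"` // bucket that was being scanned
	StartAfter    string `json:"start_after"`    // last key confirmed complete in that bucket
}

func main() {
	encoded := `{"current_bucket":"integration-resumption-tests","start_after":"test-dir/"}`

	var info resumeInfo
	if err := json.Unmarshal([]byte(encoded), &info); err != nil {
		panic(err)
	}
	fmt.Printf("resume bucket=%q startAfter=%q\n", info.CurrentBucket, info.StartAfter)
}

On resume, determineResumePosition looks up current_bucket in the sorted bucket list, and scanBuckets passes start_after through as the ListObjectsV2 StartAfter for that bucket, so only objects whose keys sort after "test-dir/" are listed. That is why the test below expects 9638 chunks rather than the full 19787.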
+ sourceTotalChunkCount := 19787 + assert.Equal(t, 9638, count, "Should have processed all remaining data on resume") + assert.Less(t, count, sourceTotalChunkCount, "Should have processed less than total chunks on resume") +} + +func TestSourceChunksNoResumptionMultipleBuckets(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + s := Source{} + connection := &sourcespb.S3{ + Credential: &sourcespb.S3_Unauthenticated{}, + Buckets: []string{"integration-resumption-tests", "truffletestbucket"}, + } + conn, err := anypb.New(connection) + if err != nil { + t.Fatal(err) + } + + err = s.Init(ctx, "test name", 0, 0, false, conn, 1) + chunksCh := make(chan *sources.Chunk) + go func() { + defer close(chunksCh) + err = s.Chunks(ctx, chunksCh) + assert.Nil(t, err) + }() + + wantChunkCount := 19890 + got := 0 + + for range chunksCh { + got++ + } + assert.Equal(t, wantChunkCount, got) +} + +func TestSourceChunksResumptionMultipleBucketsIgnoredBucket(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + src := new(Source) + + // The bucket stored in EncodedResumeInfo is NOT in the list of buckets to scan. + // Therefore, resume from the other provided bucket (truffletestbucket). + src.Progress = sources.Progress{ + Message: "Bucket: integration-resumption-tests", + EncodedResumeInfo: "{\"current_bucket\":\"integration-resumption-tests\",\"start_after\":\"test-dir/\"}", + SectionsCompleted: 0, + SectionsRemaining: 1, + } + connection := &sourcespb.S3{ + Credential: &sourcespb.S3_Unauthenticated{}, + Buckets: []string{"truffletestbucket"}, + EnableResumption: true, + } + conn, err := anypb.New(connection) + require.NoError(t, err) + + err = src.Init(ctx, "test name", 0, 0, false, conn, 2) + require.NoError(t, err) + + chunksCh := make(chan *sources.Chunk) + var count int + + cancelCtx, ctxCancel := context.WithCancel(ctx) + defer ctxCancel() + + go func() { + defer close(chunksCh) + err = src.Chunks(cancelCtx, chunksCh) + assert.NoError(t, err, "Should not error during scan") + }() + + for range chunksCh { + count++ + } + + assert.Equal(t, 103, count, "Should have processed all remaining data on resume") +} diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index 1368bdac87d0..5f2f4aed75a4 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -10,12 +10,13 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" - "google.golang.org/protobuf/types/known/anypb" ) func TestSource_Init_IncludeAndIgnoreBucketsError(t *testing.T) {
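The ignored-bucket test relies on the nearest-position behavior documented on determineResumePosition: when the checkpointed bucket is no longer in the sorted scan list, slices.BinarySearch still reports the index where it would have been, so scanning resumes at the next available bucket instead of restarting from scratch. A small standalone illustration using the bucket names from these tests:

// Demonstrates the closest-position lookup behind determineResumePosition.
package main

import (
	"fmt"
	"slices"
)

func main() {
	// Scan list from the ignored-bucket test; the checkpointed bucket is absent.
	buckets := []string{"truffletestbucket"}
	checkpointed := "integration-resumption-tests"

	idx, found := slices.BinarySearch(buckets, checkpointed)
	fmt.Println(idx, found) // 0 false: not an exact match, resume at buckets[0]

	// If the checkpointed bucket is still present, the match is exact and the
	// saved StartAfter key is applied inside that bucket.
	buckets = []string{"integration-resumption-tests", "truffletestbucket"}
	idx, found = slices.BinarySearch(buckets, checkpointed)
	fmt.Println(idx, found) // 0 true
}

Because the checkpointed bucket never matches the bucket actually being scanned, scanBuckets does not apply StartAfter and scans truffletestbucket in full, which is consistent with the expected 103 chunks: the multi-bucket test totals 19890, i.e. 103 more than the 19787 produced by the resumption bucket alone.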