From 9b035d1ead7f2b017733e858a8a73e7901b68a79 Mon Sep 17 00:00:00 2001 From: James Kwon Date: Wed, 9 Aug 2023 21:00:29 -0400 Subject: [PATCH] Refactor s3 resource type --- aws/aws.go | 49 +--- aws/s3.go | 169 ++++------- aws/s3_test.go | 759 +++++++----------------------------------------- aws/s3_types.go | 2 +- 4 files changed, 179 insertions(+), 800 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 8347bea1..e9c1477d 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -1193,48 +1193,17 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(s3Buckets.ResourceName(), resourceTypes) { start := time.Now() - var bucketNamesPerRegion map[string][]*string - - // AWS S3 buckets list operation lists all buckets irrespective of regions. - // For each bucket we have to make a separate call to find the bucket region. - // Hence for x buckets and a total of y target regions - we need to make: - // (x + 1) * y calls i.e. 1 call to list all x buckets, x calls to find out - // each bucket's region and repeat the process for each of the y regions. - - // getAllS3Buckets returns a map of regions to buckets and we call it only once - - // thereby reducing total calls from (x + 1) * y to only (x + 1) for the first region - - // followed by a cache lookup for rest of the regions. - - // Cache lookup to check if we already obtained bucket names per region - bucketNamesPerRegion, ok := resourcesCache["S3"] - - if !ok { - bucketNamesPerRegion, err = getAllS3Buckets( - cloudNukeSession, - excludeAfter, - targetRegions, - "", - s3Buckets.MaxConcurrentGetSize(), - configObj, - ) - if err != nil { - ge := report.GeneralError{ - Error: err, - Description: "Unable to retrieve S3 buckets", - ResourceType: s3Buckets.ResourceName(), - } - report.RecordError(ge) - } - - resourcesCache["S3"] = make(map[string][]*string) - - for bucketRegion := range bucketNamesPerRegion { - resourcesCache["S3"][bucketRegion] = bucketNamesPerRegion[bucketRegion] + bucketNames, err := s3Buckets.getAll(configObj) + if err != nil { + ge := report.GeneralError{ + Error: err, + Description: "Unable to retrieve S3 buckets", + ResourceType: s3Buckets.ResourceName(), } + report.RecordError(ge) } - bucketNames, ok := resourcesCache["S3"][region] - + resourcesCache["S3"] = make(map[string][]*string) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Done Listing S3 Buckets", }, map[string]interface{}{ @@ -1242,7 +1211,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp "recordCount": len(bucketNames), "actionTime": time.Since(start).Seconds(), }) - if ok && len(bucketNames) > 0 { + if len(bucketNames) > 0 { s3Buckets.Names = aws.StringValueSlice(bucketNames) resourcesInRegion.Resources = append(resourcesInRegion.Resources, s3Buckets) } diff --git a/aws/s3.go b/aws/s3.go index d17d28b0..ea2efa28 100644 --- a/aws/s3.go +++ b/aws/s3.go @@ -13,7 +13,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/gruntwork-io/go-commons/errors" @@ -23,12 +22,12 @@ import ( ) // getS3BucketRegion returns S3 Bucket region. -func getS3BucketRegion(svc *s3.S3, bucketName string) (string, error) { +func (sb S3Buckets) getS3BucketRegion(bucketName string) (string, error) { input := &s3.GetBucketLocationInput{ Bucket: aws.String(bucketName), } - result, err := svc.GetBucketLocation(input) + result, err := sb.Client.GetBucketLocation(input) if err != nil { return "", err } @@ -42,7 +41,7 @@ func getS3BucketRegion(svc *s3.S3, bucketName string) (string, error) { } // getS3BucketTags returns S3 Bucket tags. -func getS3BucketTags(svc *s3.S3, bucketName string) ([]map[string]string, error) { +func (bucket S3Buckets) getS3BucketTags(bucketName string) ([]map[string]string, error) { input := &s3.GetBucketTaggingInput{ Bucket: aws.String(bucketName), } @@ -51,7 +50,7 @@ func getS3BucketTags(svc *s3.S3, bucketName string) ([]map[string]string, error) // Please note that svc argument should be created from a session object which is // in the same region as the bucket or GetBucketTagging will fail. - result, err := svc.GetBucketTagging(input) + result, err := bucket.Client.GetBucketTagging(input) if err != nil { if aerr, ok := err.(awserr.Error); ok { switch aerr.Code() { @@ -88,7 +87,6 @@ func hasValidTags(bucketTags []map[string]string) bool { type S3Bucket struct { Name string CreationDate time.Time - Region string Tags []map[string]string Error error IsValid bool @@ -96,31 +94,20 @@ type S3Bucket struct { } // getAllS3Buckets returns a map of per region AWS S3 buckets which were created before excludeAfter -func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, - targetRegions []string, bucketNameSubStr string, batchSize int, configObj config.Config, -) (map[string][]*string, error) { - if batchSize <= 0 { - return nil, fmt.Errorf("Invalid batchsize - %d - should be > 0", batchSize) - } - - svc := s3.New(awsSession) +func (sb S3Buckets) getAll(configObj config.Config) ([]*string, error) { input := &s3.ListBucketsInput{} - output, err := svc.ListBuckets(input) - if err != nil { - return nil, errors.WithStackTrace(err) - } - - regionClients, err := getRegionClients(targetRegions) + output, err := sb.Client.ListBuckets(input) if err != nil { return nil, errors.WithStackTrace(err) } - bucketNamesPerRegion := make(map[string][]*string) + var names []*string totalBuckets := len(output.Buckets) if totalBuckets == 0 { - return bucketNamesPerRegion, nil + return nil, nil } + batchSize := sb.MaxBatchSize() totalBatches := int(math.Ceil(float64(totalBuckets) / float64(batchSize))) batchCount := 1 @@ -129,53 +116,29 @@ func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, batchEnd := int(math.Min(float64(batchStart)+float64(batchSize), float64(totalBuckets))) logging.Logger.Debugf("Getting - %d-%d buckets of batch %d/%d", batchStart+1, batchEnd, batchCount, totalBatches) targetBuckets := output.Buckets[batchStart:batchEnd] - currBucketNamesPerRegion, err := getBucketNamesPerRegion(svc, targetBuckets, excludeAfter, regionClients, bucketNameSubStr, configObj) + bucketNames, err := sb.getBucketNames(targetBuckets, configObj) if err != nil { - return bucketNamesPerRegion, err + return nil, err } - for region, buckets := range currBucketNamesPerRegion { - if _, ok := bucketNamesPerRegion[region]; !ok { - bucketNamesPerRegion[region] = []*string{} - } - bucketNamesPerRegion[region] = append(bucketNamesPerRegion[region], buckets...) - } + names = append(names, bucketNames...) batchCount++ } - return bucketNamesPerRegion, nil -} - -// getRegions creates s3 clients for target regions -func getRegionClients(regions []string) (map[string]*s3.S3, error) { - regionClients := make(map[string]*s3.S3) - for _, region := range regions { - logging.Logger.Debugf("S3 - creating session - region %s", region) - awsSession := newSession(region) - - regionClients[region] = s3.New(awsSession) - } - return regionClients, nil + return names, nil } // getBucketNamesPerRegions gets valid bucket names concurrently from list of target buckets -func getBucketNamesPerRegion(svc *s3.S3, targetBuckets []*s3.Bucket, excludeAfter time.Time, regionClients map[string]*s3.S3, - bucketNameSubStr string, configObj config.Config, -) (map[string][]*string, error) { - bucketNamesPerRegion := make(map[string][]*string) +func (sb S3Buckets) getBucketNames(targetBuckets []*s3.Bucket, configObj config.Config) ([]*string, error) { + var bucketNames []*string bucketCh := make(chan *S3Bucket, len(targetBuckets)) var wg sync.WaitGroup for _, bucket := range targetBuckets { - if len(bucketNameSubStr) > 0 && !strings.Contains(*bucket.Name, bucketNameSubStr) { - logging.Logger.Debugf("Skipping - Bucket %s - failed substring filter - %s", *bucket.Name, bucketNameSubStr) - continue - } - wg.Add(1) go func(bucket *s3.Bucket) { defer wg.Done() - getBucketInfo(svc, bucket, excludeAfter, regionClients, bucketCh, configObj) + sb.getBucketInfo(bucket, bucketCh, configObj) }(bucket) } @@ -188,51 +151,42 @@ func getBucketNamesPerRegion(svc *s3.S3, targetBuckets []*s3.Bucket, excludeAfte // messages are shown to the user as soon as possible for bucketData := range bucketCh { if bucketData.Error != nil { - logging.Logger.Debugf("Skipping - Bucket %s - region - %s - error: %s", bucketData.Name, bucketData.Region, bucketData.Error) + logging.Logger.Debugf("Skipping - Bucket %s - region - %s - error: %s", bucketData.Name, sb.Region, bucketData.Error) continue } if !bucketData.IsValid { - logging.Logger.Debugf("Skipping - Bucket %s - region - %s - %s", bucketData.Name, bucketData.Region, bucketData.InvalidReason) + logging.Logger.Debugf("Skipping - Bucket %s - region - %s - %s", bucketData.Name, sb.Region, bucketData.InvalidReason) continue } - if _, ok := bucketNamesPerRegion[bucketData.Region]; !ok { - bucketNamesPerRegion[bucketData.Region] = []*string{} - } - bucketNamesPerRegion[bucketData.Region] = append(bucketNamesPerRegion[bucketData.Region], aws.String(bucketData.Name)) + + bucketNames = append(bucketNames, aws.String(bucketData.Name)) } - return bucketNamesPerRegion, nil + + return bucketNames, nil } // getBucketInfo populates the local S3Bucket struct for the passed AWS bucket -func getBucketInfo(svc *s3.S3, bucket *s3.Bucket, excludeAfter time.Time, regionClients map[string]*s3.S3, bucketCh chan<- *S3Bucket, configObj config.Config) { +func (sb S3Buckets) getBucketInfo(bucket *s3.Bucket, bucketCh chan<- *S3Bucket, configObj config.Config) { var bucketData S3Bucket bucketData.Name = aws.StringValue(bucket.Name) bucketData.CreationDate = aws.TimeValue(bucket.CreationDate) - bucketRegion, err := getS3BucketRegion(svc, bucketData.Name) + bucketRegion, err := sb.getS3BucketRegion(bucketData.Name) if err != nil { bucketData.Error = err bucketCh <- &bucketData return } - bucketData.Region = bucketRegion // Check if the bucket is in target region - matchedRegion := false - for region := range regionClients { - if region == bucketData.Region { - matchedRegion = true - break - } - } - if !matchedRegion { + if bucketRegion != sb.Region { bucketData.InvalidReason = "Not in target region" bucketCh <- &bucketData return } // Check if the bucket has valid tags - bucketTags, err := getS3BucketTags(regionClients[bucketData.Region], bucketData.Name) + bucketTags, err := sb.getS3BucketTags(bucketData.Name) if err != nil { bucketData.Error = err bucketCh <- &bucketData @@ -246,14 +200,14 @@ func getBucketInfo(svc *s3.S3, bucket *s3.Bucket, excludeAfter time.Time, region } // Check if the bucket is older than the required time - if !excludeAfter.After(bucketData.CreationDate) { + if !configObj.S3.ShouldInclude(config.ResourceValue{Time: &bucketData.CreationDate}) { bucketData.InvalidReason = "Matched CreationDate filter" bucketCh <- &bucketData return } // Check if the bucket matches config file rules - if !config.ShouldInclude(bucketData.Name, configObj.S3.IncludeRule.NamesRegExp, configObj.S3.ExcludeRule.NamesRegExp) { + if !configObj.S3.ShouldInclude(config.ResourceValue{Name: &bucketData.Name}) { bucketData.InvalidReason = "Filtered by config file rules" bucketCh <- &bucketData return @@ -269,7 +223,7 @@ func getBucketInfo(svc *s3.S3, bucket *s3.Bucket, excludeAfter time.Time, region // does not provide any API for getting the object count, and the only way to do that is to iterate through all the // objects. For memory and time efficiency, we opted to delete the objects as we retrieve each page, which means we // don't know how many are left until we complete all the operations. -func emptyBucket(svc *s3.S3, bucketName *string, isVersioned bool, batchSize int) error { +func (sb S3Buckets) emptyBucket(bucketName *string, isVersioned bool) error { // Since the error may happen in the inner function handler for the pager, we need a function scoped variable that // the inner function can set when there is an error. var errOut error @@ -277,14 +231,14 @@ func emptyBucket(svc *s3.S3, bucketName *string, isVersioned bool, batchSize int // Handle versioned buckets. if isVersioned { - err := svc.ListObjectVersionsPages( + err := sb.Client.ListObjectVersionsPages( &s3.ListObjectVersionsInput{ Bucket: bucketName, - MaxKeys: aws.Int64(int64(batchSize)), + MaxKeys: aws.Int64(int64(sb.MaxBatchSize())), }, func(page *s3.ListObjectVersionsOutput, lastPage bool) (shouldContinue bool) { logging.Logger.Debugf("Deleting page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName)) - if err := deleteObjectVersions(svc, bucketName, page.Versions); err != nil { + if err := sb.deleteObjectVersions(bucketName, page.Versions); err != nil { logging.Logger.Errorf("Error deleting objects versions for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err) errOut = err return false @@ -292,7 +246,7 @@ func emptyBucket(svc *s3.S3, bucketName *string, isVersioned bool, batchSize int logging.Logger.Debugf("[OK] - deleted page %d of object versions (%d objects) from bucket %s", pageId, len(page.Versions), aws.StringValue(bucketName)) logging.Logger.Debugf("Deleting page %d of deletion markers (%d deletion markers) from bucket %s", pageId, len(page.DeleteMarkers), aws.StringValue(bucketName)) - if err := deleteDeletionMarkers(svc, bucketName, page.DeleteMarkers); err != nil { + if err := sb.deleteDeletionMarkers(bucketName, page.DeleteMarkers); err != nil { logging.Logger.Debugf("Error deleting deletion markers for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err) errOut = err return false @@ -313,14 +267,14 @@ func emptyBucket(svc *s3.S3, bucketName *string, isVersioned bool, batchSize int } // Handle non versioned buckets. - err := svc.ListObjectsV2Pages( + err := sb.Client.ListObjectsV2Pages( &s3.ListObjectsV2Input{ Bucket: bucketName, - MaxKeys: aws.Int64(int64(batchSize)), + MaxKeys: aws.Int64(int64(sb.MaxBatchSize())), }, func(page *s3.ListObjectsV2Output, lastPage bool) (shouldContinue bool) { logging.Logger.Debugf("Deleting object page %d (%d objects) from bucket %s", pageId, len(page.Contents), aws.StringValue(bucketName)) - if err := deleteObjects(svc, bucketName, page.Contents); err != nil { + if err := sb.deleteObjects(bucketName, page.Contents); err != nil { logging.Logger.Errorf("Error deleting objects for page %d from bucket %s: %s", pageId, aws.StringValue(bucketName), err) errOut = err return false @@ -341,7 +295,7 @@ func emptyBucket(svc *s3.S3, bucketName *string, isVersioned bool, batchSize int } // deleteObjects will delete the provided objects (unversioned) from the specified bucket. -func deleteObjects(svc *s3.S3, bucketName *string, objects []*s3.Object) error { +func (sb S3Buckets) deleteObjects(bucketName *string, objects []*s3.Object) error { if len(objects) == 0 { logging.Logger.Debugf("No objects returned in page") return nil @@ -353,7 +307,7 @@ func deleteObjects(svc *s3.S3, bucketName *string, objects []*s3.Object) error { Key: obj.Key, }) } - _, err := svc.DeleteObjects( + _, err := sb.Client.DeleteObjects( &s3.DeleteObjectsInput{ Bucket: bucketName, Delete: &s3.Delete{ @@ -366,7 +320,7 @@ func deleteObjects(svc *s3.S3, bucketName *string, objects []*s3.Object) error { } // deleteObjectVersions will delete the provided object versions from the specified bucket. -func deleteObjectVersions(svc *s3.S3, bucketName *string, objectVersions []*s3.ObjectVersion) error { +func (sb S3Buckets) deleteObjectVersions(bucketName *string, objectVersions []*s3.ObjectVersion) error { if len(objectVersions) == 0 { logging.Logger.Debugf("No object versions returned in page") return nil @@ -379,7 +333,7 @@ func deleteObjectVersions(svc *s3.S3, bucketName *string, objectVersions []*s3.O VersionId: obj.VersionId, }) } - _, err := svc.DeleteObjects( + _, err := sb.Client.DeleteObjects( &s3.DeleteObjectsInput{ Bucket: bucketName, Delete: &s3.Delete{ @@ -392,7 +346,7 @@ func deleteObjectVersions(svc *s3.S3, bucketName *string, objectVersions []*s3.O } // deleteDeletionMarkers will delete the provided deletion markers from the specified bucket. -func deleteDeletionMarkers(svc *s3.S3, bucketName *string, objectDelMarkers []*s3.DeleteMarkerEntry) error { +func (sb S3Buckets) deleteDeletionMarkers(bucketName *string, objectDelMarkers []*s3.DeleteMarkerEntry) error { if len(objectDelMarkers) == 0 { logging.Logger.Debugf("No deletion markers returned in page") return nil @@ -405,7 +359,7 @@ func deleteDeletionMarkers(svc *s3.S3, bucketName *string, objectDelMarkers []*s VersionId: obj.VersionId, }) } - _, err := svc.DeleteObjects( + _, err := sb.Client.DeleteObjects( &s3.DeleteObjectsInput{ Bucket: bucketName, Delete: &s3.Delete{ @@ -418,8 +372,8 @@ func deleteDeletionMarkers(svc *s3.S3, bucketName *string, objectDelMarkers []*s } // nukeAllS3BucketObjects batch deletes all objects in an S3 bucket -func nukeAllS3BucketObjects(svc *s3.S3, bucketName *string, batchSize int) error { - versioningResult, err := svc.GetBucketVersioning(&s3.GetBucketVersioningInput{ +func (sb S3Buckets) nukeAllS3BucketObjects(bucketName *string) error { + versioningResult, err := sb.Client.GetBucketVersioning(&s3.GetBucketVersioningInput{ Bucket: bucketName, }) if err != nil { @@ -428,12 +382,12 @@ func nukeAllS3BucketObjects(svc *s3.S3, bucketName *string, batchSize int) error isVersioned := aws.StringValue(versioningResult.Status) == "Enabled" - if batchSize < 1 || batchSize > 1000 { - return fmt.Errorf("Invalid batchsize - %d - should be between %d and %d", batchSize, 1, 1000) + if sb.MaxBatchSize() < 1 || sb.MaxBatchSize() > 1000 { + return fmt.Errorf("Invalid batchsize - %d - should be between %d and %d", sb.MaxBatchSize(), 1, 1000) } logging.Logger.Debugf("Emptying bucket %s", aws.StringValue(bucketName)) - if err := emptyBucket(svc, bucketName, isVersioned, batchSize); err != nil { + if err := sb.emptyBucket(bucketName, isVersioned); err != nil { return err } logging.Logger.Debugf("[OK] - successfully emptied bucket %s", aws.StringValue(bucketName)) @@ -441,8 +395,8 @@ func nukeAllS3BucketObjects(svc *s3.S3, bucketName *string, batchSize int) error } // nukeEmptyS3Bucket deletes an empty S3 bucket -func nukeEmptyS3Bucket(svc *s3.S3, bucketName *string, verifyBucketDeletion bool) error { - _, err := svc.DeleteBucket(&s3.DeleteBucketInput{ +func (sb S3Buckets) nukeEmptyS3Bucket(bucketName *string, verifyBucketDeletion bool) error { + _, err := sb.Client.DeleteBucket(&s3.DeleteBucketInput{ Bucket: bucketName, }) if err != nil { @@ -458,7 +412,7 @@ func nukeEmptyS3Bucket(svc *s3.S3, bucketName *string, verifyBucketDeletion bool const maxRetries = 3 for i := 0; i < maxRetries; i++ { logging.Logger.Debugf("Waiting until bucket (%s) deletion is propagated (attempt %d / %d)", aws.StringValue(bucketName), i+1, maxRetries) - err = svc.WaitUntilBucketNotExists(&s3.HeadBucketInput{ + err = sb.Client.WaitUntilBucketNotExists(&s3.HeadBucketInput{ Bucket: bucketName, }) // Exit early if no error @@ -472,26 +426,25 @@ func nukeEmptyS3Bucket(svc *s3.S3, bucketName *string, verifyBucketDeletion bool return err } -func nukeS3BucketPolicy(svc *s3.S3, bucketName *string) error { - _, err := svc.DeleteBucketPolicy(&s3.DeleteBucketPolicyInput{ +func (sb S3Buckets) nukeS3BucketPolicy(bucketName *string) error { + _, err := sb.Client.DeleteBucketPolicy(&s3.DeleteBucketPolicyInput{ Bucket: aws.String(*bucketName), }) return err } // nukeAllS3Buckets deletes all S3 buckets passed as input -func nukeAllS3Buckets(awsSession *session.Session, bucketNames []*string, objectBatchSize int) (delCount int, err error) { - svc := s3.New(awsSession) +func (sb S3Buckets) nukeAll(bucketNames []*string) (delCount int, err error) { verifyBucketDeletion := true if len(bucketNames) == 0 { - logging.Logger.Debugf("No S3 Buckets to nuke in region %s", *awsSession.Config.Region) + logging.Logger.Debugf("No S3 Buckets to nuke in region %s", sb.Region) return 0, nil } totalCount := len(bucketNames) - logging.Logger.Debugf("Deleting - %d S3 Buckets in region %s", totalCount, *awsSession.Config.Region) + logging.Logger.Debugf("Deleting - %d S3 Buckets in region %s", totalCount, sb.Region) multiErr := new(multierror.Error) for bucketIndex := 0; bucketIndex < totalCount; bucketIndex++ { @@ -499,37 +452,37 @@ func nukeAllS3Buckets(awsSession *session.Session, bucketNames []*string, object bucketName := bucketNames[bucketIndex] logging.Logger.Debugf("Deleting - %d/%d - Bucket: %s", bucketIndex+1, totalCount, *bucketName) - err = nukeAllS3BucketObjects(svc, bucketName, objectBatchSize) + err = sb.nukeAllS3BucketObjects(bucketName) if err != nil { logging.Logger.Debugf("[Failed] - %d/%d - Bucket: %s - object deletion error - %s", bucketIndex+1, totalCount, *bucketName, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking S3 Bucket Objects", }, map[string]interface{}{ - "region": *awsSession.Config.Region, + "region": sb.Region, }) multierror.Append(multiErr, err) continue } - err = nukeS3BucketPolicy(svc, bucketName) + err = sb.nukeS3BucketPolicy(bucketName) if err != nil { logging.Logger.Debugf("[Failed] - %d/%d - Bucket: %s - bucket policy cleanup error - %s", bucketIndex+1, totalCount, *bucketName, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking S3 Bucket Polikcy", }, map[string]interface{}{ - "region": *awsSession.Config.Region, + "region": sb.Region, }) multierror.Append(multiErr, err) continue } - err = nukeEmptyS3Bucket(svc, bucketName, verifyBucketDeletion) + err = sb.nukeEmptyS3Bucket(bucketName, verifyBucketDeletion) if err != nil { logging.Logger.Debugf("[Failed] - %d/%d - Bucket: %s - bucket deletion error - %s", bucketIndex+1, totalCount, *bucketName, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking S3 Bucket", }, map[string]interface{}{ - "region": *awsSession.Config.Region, + "region": sb.Region, }) multierror.Append(multiErr, err) continue diff --git a/aws/s3_test.go b/aws/s3_test.go index 475e9c38..943cfcd3 100644 --- a/aws/s3_test.go +++ b/aws/s3_test.go @@ -1,706 +1,163 @@ package aws import ( - "encoding/json" - "fmt" - "github.com/gruntwork-io/cloud-nuke/telemetry" - "io/ioutil" - "os" - "path" - "strings" - "testing" - "time" - "github.com/aws/aws-sdk-go/aws" - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/stretchr/testify/require" + "regexp" + "testing" + "time" ) -func TestMain(m *testing.M) { - telemetry.InitTelemetry("cloud-nuke", "") - logLevel := os.Getenv("LOG_LEVEL") - if len(logLevel) > 0 { - parsedLogLevel, err := logrus.ParseLevel(logLevel) - if err != nil { - logging.Logger.Errorf("Invalid log level - %s - %s", logLevel, err) - os.Exit(1) - } - logging.Logger.Level = parsedLogLevel - } - exitVal := m.Run() - os.Exit(exitVal) +type mockedS3Buckets struct { + s3iface.S3API + ListBucketsOutput s3.ListBucketsOutput + GetBucketLocationOutput s3.GetBucketLocationOutput + GetBucketTaggingOutput s3.GetBucketTaggingOutput + GetBucketVersioningOutput s3.GetBucketVersioningOutput + ListObjectVersionsPagesOutput s3.ListObjectVersionsOutput + DeleteObjectsOutput s3.DeleteObjectsOutput + DeleteBucketPolicyOutput s3.DeleteBucketPolicyOutput + DeleteBucketOutput s3.DeleteBucketOutput } -// S3TestGenBucketName generates a test bucket name. -func S3TestGenBucketName() string { - return strings.ToLower("cloud-nuke-test-" + util.UniqueID() + util.UniqueID()) +func (m mockedS3Buckets) ListBuckets(*s3.ListBucketsInput) (*s3.ListBucketsOutput, error) { + return &m.ListBucketsOutput, nil } -// S3TestCreateNewAWSSession creates a new session for testing and returns it. -func S3TestCreateNewAWSSession(region string) (*session.Session, error) { - if region == "" { - var err error - region, err = getRandomRegion() - if err != nil { - return nil, err - } - logging.Logger.Debugf("Creating session in region - %s", region) - } - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - return session, err +func (m mockedS3Buckets) GetBucketLocation(*s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error) { + return &m.GetBucketLocationOutput, nil } -// S3TestAWSParams has AWS params info, -type S3TestAWSParams struct { - region string - awsSession *session.Session - svc *s3.S3 +func (m mockedS3Buckets) GetBucketTagging(*s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error) { + return &m.GetBucketTaggingOutput, nil } -// newS3TestAWSParams sets up common operations for nuke S3 tests. -func newS3TestAWSParams(region string) (S3TestAWSParams, error) { - var params S3TestAWSParams - - if region == "" { - var err error - region, err = getRandomRegion() - if err != nil { - return params, err - } - } - - params.region = region - - awsSession, err := S3TestCreateNewAWSSession(params.region) - if err != nil { - return params, err - } - params.awsSession = awsSession - - params.svc = s3.New(params.awsSession) - if err != nil { - return params, err - } - - return params, nil -} - -// S3TestCreateBucket creates a test bucket and optionally tags and versions it. -func S3TestCreateBucket(svc *s3.S3, bucketName string, tags []map[string]string, isVersioned bool) error { - logging.Logger.Debugf("Bucket: %s - creating", bucketName) - - _, err := svc.CreateBucket(&s3.CreateBucketInput{ - Bucket: aws.String(bucketName), - }) - if err != nil { - return err - } - - // Add default tag for testing - var awsTagSet []*s3.Tag - - for _, tagSet := range tags { - awsTagSet = append(awsTagSet, &s3.Tag{Key: aws.String(tagSet["Key"]), Value: aws.String(tagSet["Value"])}) - } - - if len(awsTagSet) > 0 { - input := &s3.PutBucketTaggingInput{ - Bucket: aws.String(bucketName), - Tagging: &s3.Tagging{ - TagSet: awsTagSet, - }, - } - _, err = svc.PutBucketTagging(input) - if err != nil { - return err - } - } - - if isVersioned { - input := &s3.PutBucketVersioningInput{ - Bucket: aws.String(bucketName), - VersioningConfiguration: &s3.VersioningConfiguration{ - Status: aws.String("Enabled"), - }, - } - _, err = svc.PutBucketVersioning(input) - if err != nil { - return err - } - } - - err = svc.WaitUntilBucketExists( - &s3.HeadBucketInput{ - Bucket: aws.String(bucketName), - }, - ) - if err != nil { - return err - } +func (m mockedS3Buckets) WaitUntilBucketNotExists(*s3.HeadBucketInput) error { return nil } -// S3TestBucketAddObject adds an object ot an S3 bucket. -func S3TestBucketAddObject(awsParams S3TestAWSParams, bucketName string, fileName string, fileBody string) error { - logging.Logger.Debugf("Bucket: %s - adding object: %s - content: %s", bucketName, fileName, fileBody) - - reader := strings.NewReader(fileBody) - uploader := s3manager.NewUploader(awsParams.awsSession) - - _, err := uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String(bucketName), - Key: aws.String(fileName), - Body: reader, - }) - if err != nil { - return err - } - return nil +func (m mockedS3Buckets) GetBucketVersioning(*s3.GetBucketVersioningInput) (*s3.GetBucketVersioningOutput, error) { + return &m.GetBucketVersioningOutput, nil } -// TestListS3Bucket represents arguments for TestListS3Bucket -type TestListS3BucketArgs struct { - bucketTags []map[string]string - batchSize int - shouldError bool - shouldMatch bool +func (m mockedS3Buckets) ListObjectVersionsPages(input *s3.ListObjectVersionsInput, fn func(*s3.ListObjectVersionsOutput, bool) bool) error { + fn(&m.ListObjectVersionsPagesOutput, true) + return nil } -// testListS3Bucket - helper function for TestListS3Bucket -func testListS3Bucket(t *testing.T, args TestListS3BucketArgs) { - awsParams, err := newS3TestAWSParams("") - require.NoError(t, err, "Failed to setup AWS params") - - bucketName := S3TestGenBucketName() - - awsSession, err := S3TestCreateNewAWSSession("") - require.NoError(t, err, "Failed to create random session") - - targetRegions := []string{awsParams.region} - - // Please note that we are passing the same session that was used to create the bucket - // This is required so that the defer cleanup call always gets the right bucket region - // to delete - defer func() { - _, err := nukeAllS3Buckets(awsParams.awsSession, []*string{aws.String(bucketName)}, 1000) - if args.shouldError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }() - - // Verify that - before creating bucket - it should not exist - // - // Please note that we are not reusing S3TestAWSParams.awsSession and creating a random session in a region other - // than the one in which the bucket is created - this is useful to test the scenario where the user has - // AWS_DEFAULT_REGION set to region x but the bucket is in region y. - bucketNamesPerRegion, err := getAllS3Buckets(awsSession, time.Now().Add(1*time.Hour*-1), targetRegions, bucketName, args.batchSize, config.Config{}) - if args.shouldError { - require.Error(t, err) - logging.Logger.Debugf("SUCCESS: Did not list buckets due to invalid batch size - %s - %s", bucketName, err.Error()) - return - } - - require.NoError(t, err, "Failed to list S3 Buckets") - - // Validate test bucket does not exist before creation - require.NotContains(t, bucketNamesPerRegion[awsParams.region], aws.String(bucketName)) - - // Create test bucket - var bucketTags []map[string]string - if args.bucketTags != nil && len(args.bucketTags) > 0 { - bucketTags = args.bucketTags - } - - err = S3TestCreateBucket(awsParams.svc, bucketName, bucketTags, false) - - require.NoError(t, err, "Failed to create test buckets") - - bucketNamesPerRegion, err = getAllS3Buckets(awsSession, time.Now().Add(1*time.Hour), targetRegions, bucketName, args.batchSize, config.Config{}) - require.NoError(t, err, "Failed to list S3 Buckets") - - if args.shouldMatch { - require.Contains(t, bucketNamesPerRegion[awsParams.region], aws.String(bucketName)) - logging.Logger.Debugf("SUCCESS: Matched bucket - %s", bucketName) - } else { - require.NotContains(t, bucketNamesPerRegion[awsParams.region], aws.String(bucketName)) - logging.Logger.Debugf("SUCCESS: Did not match bucket - %s", bucketName) - } +func (m mockedS3Buckets) DeleteObjects(*s3.DeleteObjectsInput) (*s3.DeleteObjectsOutput, error) { + return &m.DeleteObjectsOutput, nil } -// TestListS3Bucket tests listing S3 bucket operation -func TestListS3Bucket(t *testing.T) { - t.Parallel() - telemetry.InitTelemetry("cloud-nuke", "") - - var testCases = []struct { - name string - args TestListS3BucketArgs - }{ - { - "NoTags", - TestListS3BucketArgs{ - bucketTags: []map[string]string{}, - batchSize: 10, - shouldMatch: true, - shouldError: false, - }, - }, - { - "WithoutFilterTag", - TestListS3BucketArgs{ - bucketTags: []map[string]string{ - {"Key": "testKey", "Value": "testValue"}, - }, - batchSize: 10, - shouldMatch: true, - shouldError: false, - }, - }, - { - "WithFilterTag", - TestListS3BucketArgs{ - bucketTags: []map[string]string{ - {"Key": AwsResourceExclusionTagKey, "Value": "true"}, - }, - batchSize: 10, - shouldMatch: false, - shouldError: false, - }, - }, - { - "MultiCaseFilterTag", - TestListS3BucketArgs{ - bucketTags: []map[string]string{ - {"Key": "test-key-1", "Value": "test-value-1"}, - {"Key": "test-key-2", "Value": "test-value-2"}, - {"Key": strings.ToTitle(AwsResourceExclusionTagKey), "Value": "TrUe"}, - }, - batchSize: 10, - shouldMatch: false, - shouldError: false, - }, - }, - { - "InvalidBatchSize", - TestListS3BucketArgs{ - bucketTags: nil, - batchSize: -1, - shouldMatch: false, - shouldError: true, - }, - }, - } - for _, tc := range testCases { - // Capture the range variable as per https://blog.golang.org/subtests - // Not doing this will lead to tc being set to the last entry in the testCases - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - testListS3Bucket(t, tc.args) - }) - } +func (m mockedS3Buckets) DeleteBucketPolicy(*s3.DeleteBucketPolicyInput) (*s3.DeleteBucketPolicyOutput, error) { + return &m.DeleteBucketPolicyOutput, nil } -// TestNukeS3BucketArgs represents arguments for TestNukeS3Bucket -type TestNukeS3BucketArgs struct { - isVersioned bool - checkDeleteMarker bool - objectCount int - objectBatchsize int - shouldNuke bool - shouldError bool +func (m mockedS3Buckets) DeleteBucket(*s3.DeleteBucketInput) (*s3.DeleteBucketOutput, error) { + return &m.DeleteBucketOutput, nil } -// testNukeS3Bucket - generates the test function for TestNukeS3Bucket -func testNukeS3Bucket(t *testing.T, args TestNukeS3BucketArgs) { - awsParams, err := newS3TestAWSParams("eu-central-1") - require.NoError(t, err, "Failed to setup AWS params") - - // Create test bucket - bucketName := S3TestGenBucketName() - var bucketTags []map[string]string - - err = S3TestCreateBucket(awsParams.svc, bucketName, bucketTags, args.isVersioned) - require.NoError(t, err, "Failed to create test bucket") - - awsSession, err := S3TestCreateNewAWSSession("") - require.NoError(t, err, "Failed to create random session") - - if args.objectCount > 0 { - objectVersions := 1 - if args.isVersioned { - objectVersions = 3 - } - - // Add two more versions of the same file - for i := 0; i < objectVersions; i++ { - for j := 0; j < args.objectCount; j++ { - fileName := fmt.Sprintf("l1/l2/l3/f%d.txt", j) - fileBody := fmt.Sprintf("%d-%d", i, j) - err := S3TestBucketAddObject(awsParams, bucketName, fileName, fileBody) - require.NoError(t, err, "Failed to add object to test bucket") - } - } - - // Do a simple delete to create DeleteMarker object - if args.checkDeleteMarker { - targetObject := "l1/l2/l3/f0.txt" - logging.Logger.Debugf("Bucket: %s - doing simple delete on object: %s", bucketName, targetObject) - - _, err = awsParams.svc.DeleteObject(&s3.DeleteObjectInput{ - Bucket: aws.String(bucketName), - Key: aws.String("l1/l2/l3/f0.txt"), - }) - require.NoError(t, err, "Failed to create delete marker") - } - } - - // Don't remove this. - // It ensures that all S3 buckets created as part of this test will be nuked after the test has run. - // This is necessary, as some test cases are expected to fail & test that the buckets with invalid args are not nuked. - // For more details, look at Github issue-140: https://github.com/gruntwork-io/cloud-nuke/issues/140 - defer nukeAllS3Buckets(awsParams.awsSession, []*string{aws.String(bucketName)}, 1000) - - // Nuke the test bucket - delCount, err := nukeAllS3Buckets(awsParams.awsSession, []*string{aws.String(bucketName)}, args.objectBatchsize) - if args.shouldError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - // If we should not nuke the bucket then deleted bucket count should be 0 - if !args.shouldNuke { - if delCount > 0 { - require.Failf(t, "Should not nuke but got delCount > 0", "delCount: %d", delCount) - } - logging.Logger.Debugf("SUCCESS: Did not nuke bucket - %s", bucketName) - return - } - - var configObj *config.Config - configObj, err = config.GetConfig(readTemplate(t, "../config/mocks/s3_include_names.yaml", map[string]string{"__TESTID__": ""})) - require.NoError(t, err) - - // Verify that - after nuking test bucket - it should not exist - bucketNamesPerRegion, err := getAllS3Buckets(awsSession, time.Now().Add(1*time.Hour), []string{awsParams.region}, bucketName, 100, *configObj) - require.NoError(t, err, "Failed to list S3 Buckets") - require.NotContains(t, bucketNamesPerRegion[awsParams.region], aws.String(bucketName)) - logging.Logger.Debugf("SUCCESS: Nuked bucket - %s", bucketName) -} - -// TestNukeS3Bucket tests S3 bucket deletion -func TestNukeS3Bucket(t *testing.T) { - t.Parallel() - +func TestS3Bucket_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") - type testCaseStruct struct { - name string - args TestNukeS3BucketArgs - } - - var allTestCases []testCaseStruct + t.Parallel() - for _, bucketType := range []string{"NoVersioning", "Versioning"} { - isVersioned := bucketType == "Versioning" - testCases := []testCaseStruct{ - { - bucketType + "_EmptyBucket", - TestNukeS3BucketArgs{ - isVersioned: isVersioned, - checkDeleteMarker: false, - objectCount: 0, - objectBatchsize: 1, - shouldNuke: true, - shouldError: false, + testName1 := "test-bucket-1" + testName2 := "test-bucket-2" + now := time.Now() + sb := S3Buckets{ + Client: mockedS3Buckets{ + ListBucketsOutput: s3.ListBucketsOutput{ + Buckets: []*s3.Bucket{ + { + Name: aws.String(testName1), + CreationDate: aws.Time(now), + }, + { + Name: aws.String(testName2), + CreationDate: aws.Time(now.Add(1)), + }, }, }, - { - bucketType + "_AllObjects", - TestNukeS3BucketArgs{ - isVersioned: isVersioned, - checkDeleteMarker: false, - objectCount: 10, - objectBatchsize: 1000, - shouldNuke: true, - shouldError: false, - }, + GetBucketLocationOutput: s3.GetBucketLocationOutput{ + LocationConstraint: aws.String("us-east-1"), }, - { - bucketType + "_BatchObjects_ValidBatchSize", - TestNukeS3BucketArgs{ - isVersioned: isVersioned, - checkDeleteMarker: false, - objectCount: 30, - objectBatchsize: 5, - shouldNuke: true, - shouldError: false, - }, - }, - { - bucketType + "_BatchObjects_InvalidBatchSize_Over", - TestNukeS3BucketArgs{ - isVersioned: isVersioned, - checkDeleteMarker: false, - objectCount: 2, - objectBatchsize: 1001, - shouldNuke: false, - shouldError: true, - }, + GetBucketTaggingOutput: s3.GetBucketTaggingOutput{ + TagSet: []*s3.Tag{}, }, - { - bucketType + "_BatchObjects_InvalidBatchSize_Under", - TestNukeS3BucketArgs{ - isVersioned: isVersioned, - checkDeleteMarker: false, - objectCount: 2, - objectBatchsize: 0, - shouldNuke: false, - shouldError: true, - }, - }, - } - for _, tc := range testCases { - allTestCases = append(allTestCases, tc) - } - } - - allTestCases = append(allTestCases, testCaseStruct{ - "Versioning_DeleteMarker", - TestNukeS3BucketArgs{ - isVersioned: true, - checkDeleteMarker: true, - objectCount: 10, - objectBatchsize: 1000, - shouldNuke: true, }, - }) - for _, tc := range allTestCases { - // Capture the range variable as per https://blog.golang.org/subtests - // Not doing this will lead to tc being set to the last entry in the testCases - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - testNukeS3Bucket(t, tc.args) - }) - } -} - -// readTemplate - read and replace variables in template, return path to temporary file with processed template. -func readTemplate(t *testing.T, templatePath string, variables map[string]string) string { - _, name := path.Split(templatePath) - file, err := ioutil.TempFile(os.TempDir(), "*"+name) - require.NoError(t, err) - - defer file.Close() - content, err := ioutil.ReadFile(templatePath) - require.NoError(t, err) - - template := string(content) - - for key, value := range variables { - template = strings.Replace(template, key, value, -1) + Region: "us-east-1", } - _, err = file.WriteString(template) - require.NoError(t, err) - - return file.Name() -} - -// TestFilterS3BucketArgs represents arguments for TestFilterS3Bucket_Config -type TestFilterS3BucketArgs struct { - configFilePath string - exactMatch bool - matches []string -} - -func bucketNamesForConfigTests(id string) []string { - return []string{ - "alb-alb-123456-access-logs-" + id, - "alb-alb-234567-access-logs-" + id, - "tonico-prod-alb-access-logs-" + id, - "prod-alb-public-access-logs-" + id, - "stage-alb-internal-access-logs-" + id, - "stage-alb-public-access-logs-" + id, - "cloud-watch-logs-staging-" + id, - "something-else-logs-staging-" + id, - } -} - -// TestFilterS3Bucket_Config tests listing only S3 buckets that match config file -func TestFilterS3Bucket_Config(t *testing.T) { - t.Parallel() - - telemetry.InitTelemetry("cloud-nuke", "") - testId := S3TestGenBucketName() - logging.Logger.Debugf("Generated test id %v", testId) - - // Create AWS session in ca-central-1 - awsParams, err := newS3TestAWSParams("ca-central-1") - require.NoError(t, err, "Failed to setup AWS params") - - // Nuke all buckets in ca-central-1 first - // passing in a config that matches all buckets - var configObj *config.Config - configObj, err = config.GetConfig(readTemplate(t, "../config/mocks/s3_all.yaml", map[string]string{"__TESTID__": testId})) - - // Verify that only filtered buckets are listed - cleanupBuckets, err := getAllS3Buckets(awsParams.awsSession, time.Now().Add(1*time.Hour), []string{awsParams.region}, "", 100, *configObj) - require.NoError(t, err, "Failed to list S3 Buckets in ca-central-1") - - _, err = nukeAllS3Buckets(awsParams.awsSession, cleanupBuckets[awsParams.region], 1000) - require.NoError(t, err) - - // Create test buckets in ca-central-1 - var bucketTags []map[string]string - bucketNames := bucketNamesForConfigTests(testId) - for _, bucketName := range bucketNames { - err = S3TestCreateBucket(awsParams.svc, bucketName, bucketTags, false) - require.NoErrorf(t, err, "Failed to create test bucket - %s", bucketName) - } - - // Please note that we are not reusing awsParams.awsSession and creating a random session in a region other - // than the one in which the bucket is created - this is useful to test the scenario where the user has - // AWS_DEFAULT_REGION set to region x but the bucket is in region y. - awsSession, err := S3TestCreateNewAWSSession("") - require.NoError(t, err, "Failed to create session in random region") - - // Define test cases - type testCaseStruct struct { - name string - args TestFilterS3BucketArgs - } - - includeBuckets := []string{} - includeBuckets = append(includeBuckets, bucketNames[:4]...) - - excludeBuckets := []string{} - excludeBuckets = append(excludeBuckets, bucketNames[:3]...) - excludeBuckets = append(excludeBuckets, bucketNames[4]) - excludeBuckets = append(excludeBuckets, bucketNames[6:]...) - - filterBuckets := []string{} - filterBuckets = append(filterBuckets, bucketNames[:3]...) - - testCases := []testCaseStruct{ - { - "Include", - TestFilterS3BucketArgs{ - configFilePath: readTemplate(t, "../config/mocks/s3_include_names.yaml", map[string]string{"__TESTID__": testId}), - matches: includeBuckets, - }, + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testName1, testName2}, }, - { - "Exclude", - TestFilterS3BucketArgs{ - configFilePath: readTemplate(t, "../config/mocks/s3_exclude_names.yaml", map[string]string{"__TESTID__": testId}), - matches: excludeBuckets, - // exclude match may include multiple buckets than created during test - // https://github.com/gruntwork-io/cloud-nuke/issues/142 - exactMatch: false, - }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testName2}, }, - { - "IncludeAndExclude", - TestFilterS3BucketArgs{ - configFilePath: readTemplate(t, "../config/mocks/s3_filter_names.yaml", map[string]string{"__TESTID__": testId}), - matches: filterBuckets, - }, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }}, + expected: []string{testName1}, }, } - - // Clean up test buckets - defer func() { - _, err := nukeAllS3Buckets(awsParams.awsSession, aws.StringSlice(bucketNames), 1000) - assert.NoError(t, err) - }() - t.Run("config tests", func(t *testing.T) { - for _, tc := range testCases { - // Capture the range variable as per https://blog.golang.org/subtests - // Not doing this will lead to tc being set to the last entry in the testCases - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var configObj *config.Config - configObj, err = config.GetConfig(tc.args.configFilePath) - - // Verify that only filtered buckets are listed (use random region) - bucketNamesPerRegion, err := getAllS3Buckets(awsSession, time.Now().Add(1*time.Hour), []string{awsParams.region}, "", 100, *configObj) - - require.NoError(t, err, "Failed to list S3 Buckets") - if tc.args.exactMatch { - require.Equal(t, len(tc.args.matches), len(bucketNamesPerRegion[awsParams.region])) - } else { - // in case of not exact match, at least check if number of matched buckets are more or equal to arg count - require.True(t, len(bucketNamesPerRegion[awsParams.region]) >= len(tc.args.matches)) - } - require.Subset(t, aws.StringValueSlice(bucketNamesPerRegion[awsParams.region]), tc.args.matches) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := sb.getAll(config.Config{ + S3: tc.configObj, }) - } - }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) + }) + } } -// TestNukeS3BucketWithBucketPolicy tests deletion of S3 buckets with a policy that denies deletion -func TestNukeS3BucketWithBucketPolicy(t *testing.T) { +func TestS3Bucket_NukeAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") - awsParams, err := newS3TestAWSParams("") - require.NoError(t, err, "Failed to setup AWS params") - - // Create test bucket - bucketName := S3TestGenBucketName() - var bucketTags []map[string]string - - err = S3TestCreateBucket(awsParams.svc, bucketName, bucketTags, false) - require.NoError(t, err, "Failed to create test bucket") + t.Parallel() - policy, _ := json.Marshal(map[string]interface{}{ - "Version": "2012-10-17", - "Statement": []map[string]interface{}{ - { - "Effect": "Deny", - "Principal": "*", - "Action": []string{ - "s3:DeleteBucket", + sb := S3Buckets{ + Client: mockedS3Buckets{ + GetBucketVersioningOutput: s3.GetBucketVersioningOutput{ + Status: aws.String("Enabled"), + }, + ListObjectVersionsPagesOutput: s3.ListObjectVersionsOutput{ + Versions: []*s3.ObjectVersion{ + { + Key: aws.String("test-key"), + VersionId: aws.String("test-version-id"), + }, }, - "Resource": []string{ - fmt.Sprintf("arn:aws:s3:::%s", bucketName), + DeleteMarkers: []*s3.DeleteMarkerEntry{ + { + Key: aws.String("test-key"), + VersionId: aws.String("test-version-id"), + }, }, }, + DeleteObjectsOutput: s3.DeleteObjectsOutput{}, + DeleteBucketPolicyOutput: s3.DeleteBucketPolicyOutput{}, + DeleteBucketOutput: s3.DeleteBucketOutput{}, }, - }) - - _, err = awsParams.svc.PutBucketPolicy(&s3.PutBucketPolicyInput{ - Bucket: aws.String(bucketName), - Policy: aws.String(string(policy)), - }) - require.NoError(t, err) - - defer func() { - /* - If the policy was not removed, delete it manually and delete - the bucket to not leave any test data in the account - */ - awsParams.svc.DeleteBucketPolicy(&s3.DeleteBucketPolicyInput{ - Bucket: aws.String(bucketName), - }) - nukeAllS3Buckets(awsParams.awsSession, []*string{aws.String(bucketName)}, 1000) - }() + } - _, err = nukeAllS3Buckets(awsParams.awsSession, []*string{aws.String(bucketName)}, 1000) + count, err := sb.nukeAll([]*string{aws.String("test-bucket")}) require.NoError(t, err) - + require.Equal(t, 1, count) } diff --git a/aws/s3_types.go b/aws/s3_types.go index 5d5ed7e9..4801e622 100644 --- a/aws/s3_types.go +++ b/aws/s3_types.go @@ -45,7 +45,7 @@ func (bucket S3Buckets) ResourceIdentifiers() []string { // Nuke - nuke 'em all!!! func (bucket S3Buckets) Nuke(session *session.Session, identifiers []string) error { - delCount, err := nukeAllS3Buckets(session, aws.StringSlice(identifiers), bucket.ObjectMaxBatchSize()) + delCount, err := bucket.nukeAll(aws.StringSlice(identifiers)) totalCount := len(identifiers) if delCount > 0 {