diff --git a/aws/aws.go b/aws/aws.go index 9b705ffc..3f993732 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -454,7 +454,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(sqsQueue.ResourceName(), resourceTypes) { start := time.Now() - queueUrls, err := getAllSqsQueue(cloudNukeSession, region, excludeAfter) + queueUrls, err := sqsQueue.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/sqs.go b/aws/sqs.go index 8a8b0182..bdcac892 100644 --- a/aws/sqs.go +++ b/aws/sqs.go @@ -1,14 +1,13 @@ package aws import ( + "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/telemetry" commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" - "strconv" "time" "github.com/aws/aws-sdk-go/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" @@ -16,9 +15,7 @@ import ( ) // Returns a formatted string of SQS Queue URLs -func getAllSqsQueue(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) { - svc := sqs.New(session) - +func (sq SqsQueue) getAll(configObj config.Config) ([]*string, error) { result := []*string{} paginator := func(output *sqs.ListQueuesOutput, lastPage bool) bool { result = append(result, output.QueueUrls...) @@ -28,7 +25,7 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T param := &sqs.ListQueuesInput{ MaxResults: awsgo.Int64(10), } - err := svc.ListQueuesPages(param, paginator) + err := sq.Client.ListQueuesPages(param, paginator) if err != nil { return nil, errors.WithStackTrace(err) } @@ -40,20 +37,23 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T QueueUrl: queue, AttributeNames: awsgo.StringSlice([]string{"CreatedTimestamp"}), } - queueAttributes, err := svc.GetQueueAttributes(param) + queueAttributes, err := sq.Client.GetQueueAttributes(param) if err != nil { return nil, errors.WithStackTrace(err) } // Convert string timestamp to int64 createdAt := *queueAttributes.Attributes["CreatedTimestamp"] - createdAtInt, err := strconv.ParseInt(createdAt, 10, 64) + createdAtTime, err := time.Parse(time.RFC3339, createdAt) if err != nil { return nil, errors.WithStackTrace(err) } // Compare time as int64 - if excludeAfter.Unix() > createdAtInt { + if configObj.SQS.ShouldInclude(config.ResourceValue{ + Name: queue, + Time: &createdAtTime, + }) { urls = append(urls, queue) } } @@ -62,15 +62,13 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T } // Deletes all Elastic Load Balancers -func nukeAllSqsQueues(session *session.Session, urls []*string) error { - svc := sqs.New(session) - +func (sq SqsQueue) nukeAll(urls []*string) error { if len(urls) == 0 { - logging.Logger.Debugf("No SQS Queues to nuke in region %s", *session.Config.Region) + logging.Logger.Debugf("No SQS Queues to nuke in region %s", sq.Region) return nil } - logging.Logger.Debugf("Deleting all SQS Queues in region %s", *session.Config.Region) + logging.Logger.Debugf("Deleting all SQS Queues in region %s", sq.Region) var deletedUrls []*string for _, url := range urls { @@ -78,7 +76,7 @@ func nukeAllSqsQueues(session *session.Session, urls []*string) error { QueueUrl: url, } - _, err := svc.DeleteQueue(params) + _, err := sq.Client.DeleteQueue(params) // Record status of this resource e := report.Entry{ @@ -93,7 +91,7 @@ func nukeAllSqsQueues(session *session.Session, urls []*string) error { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking SQS Queue", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": sq.Region, }) } else { deletedUrls = append(deletedUrls, url) @@ -101,7 +99,7 @@ func nukeAllSqsQueues(session *session.Session, urls []*string) error { } } - logging.Logger.Debugf("[OK] %d SQS Queue(s) deleted in %s", len(deletedUrls), *session.Config.Region) + logging.Logger.Debugf("[OK] %d SQS Queue(s) deleted in %s", len(deletedUrls), sq.Region) return nil } diff --git a/aws/sqs_test.go b/aws/sqs_test.go index f131406c..7df36269 100644 --- a/aws/sqs_test.go +++ b/aws/sqs_test.go @@ -1,128 +1,117 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - "testing" - "time" - - "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/go-commons/retry" - + "github.com/aws/aws-sdk-go-v2/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/stretchr/testify/assert" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/stretchr/testify/require" + "regexp" + "testing" + "time" ) -func createTestQueue(t *testing.T, session *session.Session, name string) *string { - svc := sqs.New(session) +type mockedSqsQueue struct { + sqsiface.SQSAPI + DeleteQueueOutput sqs.DeleteQueueOutput + GetQueueAttributesOutput map[string]sqs.GetQueueAttributesOutput + ListQueuesOutput sqs.ListQueuesOutput +} - param := &sqs.CreateQueueInput{ - QueueName: awsgo.String(name), - Attributes: map[string]*string{ - "DelaySeconds": awsgo.String("60"), - "MessageRetentionPeriod": awsgo.String("86400"), - }, - } +func (m mockedSqsQueue) ListQueuesPages(input *sqs.ListQueuesInput, fn func(*sqs.ListQueuesOutput, bool) bool) error { + fn(&m.ListQueuesOutput, true) + return nil +} - result, err := svc.CreateQueue(param) - require.NoError(t, err) - require.True(t, len(awsgo.StringValue(result.QueueUrl)) > 0, "Can't create test Sqs Queue") +func (m mockedSqsQueue) GetQueueAttributes(input *sqs.GetQueueAttributesInput) (*sqs.GetQueueAttributesOutput, error) { + url := input.QueueUrl + resp := m.GetQueueAttributesOutput[*url] - err = retry.DoWithRetry( - logging.Logger, - "Check if queue is created", - 3, - 5*time.Second, - func() error { - _, err = svc.GetQueueUrl(&sqs.GetQueueUrlInput{QueueName: awsgo.String(name)}) - return err - }, - ) - require.NoError(t, err) - return result.QueueUrl + return &resp, nil } -func TestListSqsQueue(t *testing.T) { +func (m mockedSqsQueue) DeleteQueue(*sqs.DeleteQueueInput) (*sqs.DeleteQueueOutput, error) { + return &m.DeleteQueueOutput, nil +} + +func TestSqsQueue_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - require.NoError(t, err) - - // create 20 test queues, to validate pagination - queueList := []*string{} - for n := 0; n < 20; n++ { - queueName := "cloud-nuke-test-" + util.UniqueID() - queueUrl := createTestQueue(t, session, queueName) - require.NoError(t, err) - - queueList = append(queueList, queueUrl) + queue1 := "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue1" + queue2 := "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue2" + now := time.Now() + sq := SqsQueue{ + Client: mockedSqsQueue{ + ListQueuesOutput: sqs.ListQueuesOutput{ + QueueUrls: []*string{ + awsgo.String(queue1), + awsgo.String(queue2), + }, + }, + GetQueueAttributesOutput: map[string]sqs.GetQueueAttributesOutput{ + queue1: { + Attributes: map[string]*string{ + "CreatedTimestamp": awsgo.String(now.Format(time.RFC3339)), + }, + }, + queue2: { + Attributes: map[string]*string{ + "CreatedTimestamp": awsgo.String(now.Add(1).Format(time.RFC3339)), + }, + }, + }, + }, } - // clean up after this test - defer nukeAllSqsQueues(session, queueList) - - // timestamps to test - oneHourAgo := time.Now().Add(1 * time.Hour * -1) - oneHourFromNow := time.Now().Add(1 * time.Hour) - - urls, err := getAllSqsQueue(session, region, oneHourAgo) - require.NoError(t, err) - - for _, queue := range queueList { - assert.NotContains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queue)) + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{queue1, queue2}, + }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile("MyQueue1"), + }}}, + }, + expected: []string{queue2}, + }, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-1 * time.Hour)), + }}, + expected: []string{}, + }, } - - urls, err = getAllSqsQueue(session, region, oneHourFromNow) - require.NoError(t, err) - - for _, queue := range queueList { - assert.Contains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queue)) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := sq.getAll(config.Config{ + SQS: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, awsgo.StringValueSlice(names)) + }) } } -func TestNukeSqsQueue(t *testing.T) { +func TestSqsQueue_NukeAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - require.NoError(t, err) - - queueName := "cloud-nuke-test-" + util.UniqueID() - queueUrl := createTestQueue(t, session, queueName) - oneHourFromNow := time.Now().Add(1 * time.Hour) - - urls, err := getAllSqsQueue(session, region, oneHourFromNow) - require.NoError(t, err) - assert.Contains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queueUrl)) - - err = nukeAllSqsQueues(session, []*string{queueUrl}) - require.NoError(t, err) - - // SQS Queue deletion takes up to 60 seconds to be finished. See https://docs.aws.amazon.com/sdk-for-go/api/service/sqs/#SQS.DeleteQueue - for retry := 0; retry <= 6; retry++ { - urls, err = getAllSqsQueue(session, region, oneHourFromNow) - if err == nil { - break - } - - sleepMessage := "SQS Queue still available. Waiting 10 seconds to check again." - sleepFor := 10 * time.Second - sleepWithMessage(sleepFor, sleepMessage) + sq := SqsQueue{ + Client: mockedSqsQueue{ + DeleteQueueOutput: sqs.DeleteQueueOutput{}, + }, } + + err := sq.nukeAll([]*string{aws.String("test")}) require.NoError(t, err) - assert.NotContains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queueUrl)) } diff --git a/aws/sqs_types.go b/aws/sqs_types.go index 801a8932..eb0fe976 100644 --- a/aws/sqs_types.go +++ b/aws/sqs_types.go @@ -15,23 +15,23 @@ type SqsQueue struct { } // ResourceName - the simple name of the aws resource -func (queue SqsQueue) ResourceName() string { +func (sq SqsQueue) ResourceName() string { return "sqs" } -func (queue SqsQueue) MaxBatchSize() int { +func (sq SqsQueue) MaxBatchSize() int { // Tentative batch size to ensure AWS doesn't throttle return 49 } // ResourceIdentifiers - The arns of the sqs queues -func (queue SqsQueue) ResourceIdentifiers() []string { - return queue.QueueUrls +func (sq SqsQueue) ResourceIdentifiers() []string { + return sq.QueueUrls } // Nuke - nuke 'em all!!! -func (queue SqsQueue) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllSqsQueues(session, awsgo.StringSlice(identifiers)); err != nil { +func (sq SqsQueue) Nuke(session *session.Session, identifiers []string) error { + if err := sq.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) }