Skip to content

Commit

Permalink
Refactor SQS Resource Type (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 authored Jul 26, 2023
1 parent 74ab0b4 commit 0f69d82
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 124 deletions.
2 changes: 1 addition & 1 deletion aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
}
if IsNukeable(sqsQueue.ResourceName(), resourceTypes) {
start := time.Now()
queueUrls, err := getAllSqsQueue(cloudNukeSession, region, excludeAfter)
queueUrls, err := sqsQueue.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
32 changes: 15 additions & 17 deletions aws/sqs.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/go-commons/errors"
)

// Returns a formatted string of SQS Queue URLs
func getAllSqsQueue(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := sqs.New(session)

func (sq SqsQueue) getAll(configObj config.Config) ([]*string, error) {
result := []*string{}
paginator := func(output *sqs.ListQueuesOutput, lastPage bool) bool {
result = append(result, output.QueueUrls...)
Expand All @@ -28,7 +25,7 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T
param := &sqs.ListQueuesInput{
MaxResults: awsgo.Int64(10),
}
err := svc.ListQueuesPages(param, paginator)
err := sq.Client.ListQueuesPages(param, paginator)
if err != nil {
return nil, errors.WithStackTrace(err)
}
Expand All @@ -40,20 +37,23 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T
QueueUrl: queue,
AttributeNames: awsgo.StringSlice([]string{"CreatedTimestamp"}),
}
queueAttributes, err := svc.GetQueueAttributes(param)
queueAttributes, err := sq.Client.GetQueueAttributes(param)
if err != nil {
return nil, errors.WithStackTrace(err)
}

// Convert string timestamp to int64
createdAt := *queueAttributes.Attributes["CreatedTimestamp"]
createdAtInt, err := strconv.ParseInt(createdAt, 10, 64)
createdAtTime, err := time.Parse(time.RFC3339, createdAt)
if err != nil {
return nil, errors.WithStackTrace(err)
}

// Compare time as int64
if excludeAfter.Unix() > createdAtInt {
if configObj.SQS.ShouldInclude(config.ResourceValue{
Name: queue,
Time: &createdAtTime,
}) {
urls = append(urls, queue)
}
}
Expand All @@ -62,23 +62,21 @@ func getAllSqsQueue(session *session.Session, region string, excludeAfter time.T
}

// Deletes all Elastic Load Balancers
func nukeAllSqsQueues(session *session.Session, urls []*string) error {
svc := sqs.New(session)

func (sq SqsQueue) nukeAll(urls []*string) error {
if len(urls) == 0 {
logging.Logger.Debugf("No SQS Queues to nuke in region %s", *session.Config.Region)
logging.Logger.Debugf("No SQS Queues to nuke in region %s", sq.Region)
return nil
}

logging.Logger.Debugf("Deleting all SQS Queues in region %s", *session.Config.Region)
logging.Logger.Debugf("Deleting all SQS Queues in region %s", sq.Region)
var deletedUrls []*string

for _, url := range urls {
params := &sqs.DeleteQueueInput{
QueueUrl: url,
}

_, err := svc.DeleteQueue(params)
_, err := sq.Client.DeleteQueue(params)

// Record status of this resource
e := report.Entry{
Expand All @@ -93,15 +91,15 @@ func nukeAllSqsQueues(session *session.Session, urls []*string) error {
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking SQS Queue",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": sq.Region,
})
} else {
deletedUrls = append(deletedUrls, url)
logging.Logger.Debugf("Deleted SQS Queue: %s", *url)
}
}

logging.Logger.Debugf("[OK] %d SQS Queue(s) deleted in %s", len(deletedUrls), *session.Config.Region)
logging.Logger.Debugf("[OK] %d SQS Queue(s) deleted in %s", len(deletedUrls), sq.Region)

return nil
}
189 changes: 89 additions & 100 deletions aws/sqs_test.go
Original file line number Diff line number Diff line change
@@ -1,128 +1,117 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
"testing"
"time"

"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/go-commons/retry"

"github.com/aws/aws-sdk-go-v2/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/gruntwork-io/cloud-nuke/util"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/stretchr/testify/require"
"regexp"
"testing"
"time"
)

func createTestQueue(t *testing.T, session *session.Session, name string) *string {
svc := sqs.New(session)
type mockedSqsQueue struct {
sqsiface.SQSAPI
DeleteQueueOutput sqs.DeleteQueueOutput
GetQueueAttributesOutput map[string]sqs.GetQueueAttributesOutput
ListQueuesOutput sqs.ListQueuesOutput
}

param := &sqs.CreateQueueInput{
QueueName: awsgo.String(name),
Attributes: map[string]*string{
"DelaySeconds": awsgo.String("60"),
"MessageRetentionPeriod": awsgo.String("86400"),
},
}
func (m mockedSqsQueue) ListQueuesPages(input *sqs.ListQueuesInput, fn func(*sqs.ListQueuesOutput, bool) bool) error {
fn(&m.ListQueuesOutput, true)
return nil
}

result, err := svc.CreateQueue(param)
require.NoError(t, err)
require.True(t, len(awsgo.StringValue(result.QueueUrl)) > 0, "Can't create test Sqs Queue")
func (m mockedSqsQueue) GetQueueAttributes(input *sqs.GetQueueAttributesInput) (*sqs.GetQueueAttributesOutput, error) {
url := input.QueueUrl
resp := m.GetQueueAttributesOutput[*url]

err = retry.DoWithRetry(
logging.Logger,
"Check if queue is created",
3,
5*time.Second,
func() error {
_, err = svc.GetQueueUrl(&sqs.GetQueueUrlInput{QueueName: awsgo.String(name)})
return err
},
)
require.NoError(t, err)
return result.QueueUrl
return &resp, nil
}

func TestListSqsQueue(t *testing.T) {
func (m mockedSqsQueue) DeleteQueue(*sqs.DeleteQueueInput) (*sqs.DeleteQueueOutput, error) {
return &m.DeleteQueueOutput, nil
}

func TestSqsQueue_GetAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)

session, err := session.NewSession(&awsgo.Config{
Region: awsgo.String(region)},
)
require.NoError(t, err)

// create 20 test queues, to validate pagination
queueList := []*string{}
for n := 0; n < 20; n++ {
queueName := "cloud-nuke-test-" + util.UniqueID()
queueUrl := createTestQueue(t, session, queueName)
require.NoError(t, err)

queueList = append(queueList, queueUrl)
queue1 := "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue1"
queue2 := "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue2"
now := time.Now()
sq := SqsQueue{
Client: mockedSqsQueue{
ListQueuesOutput: sqs.ListQueuesOutput{
QueueUrls: []*string{
awsgo.String(queue1),
awsgo.String(queue2),
},
},
GetQueueAttributesOutput: map[string]sqs.GetQueueAttributesOutput{
queue1: {
Attributes: map[string]*string{
"CreatedTimestamp": awsgo.String(now.Format(time.RFC3339)),
},
},
queue2: {
Attributes: map[string]*string{
"CreatedTimestamp": awsgo.String(now.Add(1).Format(time.RFC3339)),
},
},
},
},
}

// clean up after this test
defer nukeAllSqsQueues(session, queueList)

// timestamps to test
oneHourAgo := time.Now().Add(1 * time.Hour * -1)
oneHourFromNow := time.Now().Add(1 * time.Hour)

urls, err := getAllSqsQueue(session, region, oneHourAgo)
require.NoError(t, err)

for _, queue := range queueList {
assert.NotContains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queue))
tests := map[string]struct {
configObj config.ResourceType
expected []string
}{
"emptyFilter": {
configObj: config.ResourceType{},
expected: []string{queue1, queue2},
},
"nameExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{{
RE: *regexp.MustCompile("MyQueue1"),
}}},
},
expected: []string{queue2},
},
"timeAfterExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(-1 * time.Hour)),
}},
expected: []string{},
},
}

urls, err = getAllSqsQueue(session, region, oneHourFromNow)
require.NoError(t, err)

for _, queue := range queueList {
assert.Contains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queue))
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := sq.getAll(config.Config{
SQS: tc.configObj,
})
require.NoError(t, err)
require.Equal(t, tc.expected, awsgo.StringValueSlice(names))
})
}
}

func TestNukeSqsQueue(t *testing.T) {
func TestSqsQueue_NukeAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)

session, err := session.NewSession(&awsgo.Config{
Region: awsgo.String(region)},
)
require.NoError(t, err)

queueName := "cloud-nuke-test-" + util.UniqueID()
queueUrl := createTestQueue(t, session, queueName)
oneHourFromNow := time.Now().Add(1 * time.Hour)

urls, err := getAllSqsQueue(session, region, oneHourFromNow)
require.NoError(t, err)
assert.Contains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queueUrl))

err = nukeAllSqsQueues(session, []*string{queueUrl})
require.NoError(t, err)

// SQS Queue deletion takes up to 60 seconds to be finished. See https://docs.aws.amazon.com/sdk-for-go/api/service/sqs/#SQS.DeleteQueue
for retry := 0; retry <= 6; retry++ {
urls, err = getAllSqsQueue(session, region, oneHourFromNow)
if err == nil {
break
}

sleepMessage := "SQS Queue still available. Waiting 10 seconds to check again."
sleepFor := 10 * time.Second
sleepWithMessage(sleepFor, sleepMessage)
sq := SqsQueue{
Client: mockedSqsQueue{
DeleteQueueOutput: sqs.DeleteQueueOutput{},
},
}

err := sq.nukeAll([]*string{aws.String("test")})
require.NoError(t, err)
assert.NotContains(t, awsgo.StringValueSlice(urls), awsgo.StringValue(queueUrl))
}
12 changes: 6 additions & 6 deletions aws/sqs_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ type SqsQueue struct {
}

// ResourceName - the simple name of the aws resource
func (queue SqsQueue) ResourceName() string {
func (sq SqsQueue) ResourceName() string {
return "sqs"
}

func (queue SqsQueue) MaxBatchSize() int {
func (sq SqsQueue) MaxBatchSize() int {
// Tentative batch size to ensure AWS doesn't throttle
return 49
}

// ResourceIdentifiers - The arns of the sqs queues
func (queue SqsQueue) ResourceIdentifiers() []string {
return queue.QueueUrls
func (sq SqsQueue) ResourceIdentifiers() []string {
return sq.QueueUrls
}

// Nuke - nuke 'em all!!!
func (queue SqsQueue) Nuke(session *session.Session, identifiers []string) error {
if err := nukeAllSqsQueues(session, awsgo.StringSlice(identifiers)); err != nil {
func (sq SqsQueue) Nuke(session *session.Session, identifiers []string) error {
if err := sq.nukeAll(awsgo.StringSlice(identifiers)); err != nil {
return errors.WithStackTrace(err)
}

Expand Down

0 comments on commit 0f69d82

Please sign in to comment.