Skip to content

Commit

Permalink
Refactor SNSTopic Resource Type (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 authored Jul 26, 2023
1 parent 65bd31b commit 4723121
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 364 deletions.
2 changes: 1 addition & 1 deletion aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
}
if IsNukeable(snsTopics.ResourceName(), resourceTypes) {
start := time.Now()
snsTopicArns, err := getAllSNSTopics(cloudNukeSession, excludeAfter, configObj)
snsTopicArns, err := snsTopics.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
101 changes: 45 additions & 56 deletions aws/sns.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
package aws

import (
"context"
"github.com/aws/aws-sdk-go/service/sns"
"strings"
"sync"
"time"

"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sns"
"github.com/aws/aws-sdk-go-v2/service/sns/types"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
Expand All @@ -24,60 +20,62 @@ import (
// getAllSNSTopics returns a list of all SNS topics in the region, filtering the name by the config
// The SQS APIs do not return a creation date, therefore we tag the resources with a first seen time when the topic first appears. We then
// use that tag to measure the excludeAfter time duration, and determine whether to nuke the resource based on that.
func getAllSNSTopics(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) {
ctx := context.TODO()
func (s SNSTopic) getAll(configObj config.Config) ([]*string, error) {

cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(aws.StringValue(session.Config.Region)))
if err != nil {
return []*string{}, errors.WithStackTrace(err)
}
svc := sns.NewFromConfig(cfg)

snsTopics := []*string{}

paginator := sns.NewListTopicsPaginator(svc, nil)

for paginator.HasMorePages() {
resp, err := paginator.NextPage(ctx)
if err != nil {
return []*string{}, errors.WithStackTrace(err)
}
for _, topic := range resp.Topics {
firstSeenTime, err := getFirstSeenSNSTopicTag(ctx, svc, *topic.TopicArn, firstSeenTagKey)
var snsTopics []*string
err := s.Client.ListTopicsPages(&sns.ListTopicsInput{}, func(page *sns.ListTopicsOutput, lastPage bool) bool {
for _, topic := range page.Topics {
firstSeenTime, err := s.getFirstSeenSNSTopicTag(*topic.TopicArn)
if err != nil {
logging.Logger.Errorf("Unable to retrieve tags for SNS Topic: %s, with error: %s", *topic.TopicArn, err)
return nil, err
logging.Logger.Errorf(
"Unable to retrieve tags for SNS Topic: %s, with error: %s", *topic.TopicArn, err)
continue
}

if firstSeenTime == nil {
now := time.Now().UTC()
firstSeenTime = &now
if err := setFirstSeenSNSTopicTag(ctx, svc, *topic.TopicArn, firstSeenTagKey, now); err != nil {
logging.Logger.Errorf("Unable to apply first seen tag SNS Topic: %s, with error: %s", *topic.TopicArn, err)
return nil, err
if err := s.setFirstSeenSNSTopicTag(*topic.TopicArn, now); err != nil {
logging.Logger.Errorf(
"Unable to apply first seen tag SNS Topic: %s, with error: %s", *topic.TopicArn, err)
continue
}
}

if shouldIncludeSNS(*topic.TopicArn, excludeAfter, *firstSeenTime, configObj) {
// a topic arn is of the form arn:aws:sns:us-east-1:123456789012:MyTopic
// so we can search for the index of the last colon, then slice the string to get the topic name
nameIndex := strings.LastIndex(*topic.TopicArn, ":")
topicName := (*topic.TopicArn)[nameIndex+1:]
if configObj.SNS.ShouldInclude(config.ResourceValue{
Time: firstSeenTime,
Name: &topicName,
}) {
snsTopics = append(snsTopics, topic.TopicArn)
}
}

return !lastPage
})

if err != nil {
return nil, errors.WithStackTrace(err)
}

return snsTopics, nil
}

// getFirstSeenSNSTopicTag will retrive the time that the topic was first seen, otherwise returning nil if the topic has not been
// seen before.
func getFirstSeenSNSTopicTag(ctx context.Context, svc *sns.Client, topicArn, key string) (*time.Time, error) {
response, err := svc.ListTagsForResource(ctx, &sns.ListTagsForResourceInput{
func (s SNSTopic) getFirstSeenSNSTopicTag(topicArn string) (*time.Time, error) {
response, err := s.Client.ListTagsForResource(&sns.ListTagsForResourceInput{
ResourceArn: &topicArn,
})
if err != nil {
return nil, err
}

for i := range response.Tags {
if *response.Tags[i].Key == key {
if *response.Tags[i].Key == firstSeenTagKey {
firstSeenTime, err := time.Parse(firstSeenTimeFormat, *response.Tags[i].Value)
if err != nil {
return nil, err
Expand All @@ -91,16 +89,15 @@ func getFirstSeenSNSTopicTag(ctx context.Context, svc *sns.Client, topicArn, key
}

// setFirstSeenSNSTopic will append a tag to the SNS Topic that details the first seen time.
func setFirstSeenSNSTopicTag(ctx context.Context, svc *sns.Client, topicArn, key string, value time.Time) error {
func (s SNSTopic) setFirstSeenSNSTopicTag(topicArn string, value time.Time) error {
timeValue := value.Format(firstSeenTimeFormat)

_, err := svc.TagResource(
ctx,
_, err := s.Client.TagResource(
&sns.TagResourceInput{
ResourceArn: &topicArn,
Tags: []types.Tag{
Tags: []*sns.Tag{
{
Key: &key,
Key: aws.String(firstSeenTagKey),
Value: &timeValue,
},
},
Expand All @@ -127,17 +124,9 @@ func shouldIncludeSNS(topicArn string, excludeAfter, firstSeenTime time.Time, co
return config.ShouldInclude(topicName, configObj.SNS.IncludeRule.NamesRegExp, configObj.SNS.ExcludeRule.NamesRegExp)
}

func nukeAllSNSTopics(session *session.Session, identifiers []*string) error {
region := aws.StringValue(session.Config.Region)

cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), awsconfig.WithRegion(aws.StringValue(session.Config.Region)))
if err != nil {
return errors.WithStackTrace(err)
}
svc := sns.NewFromConfig(cfg)

func (s SNSTopic) nukeAll(identifiers []*string) error {
if len(identifiers) == 0 {
logging.Logger.Debugf("No SNS Topics to nuke in region %s", region)
logging.Logger.Debugf("No SNS Topics to nuke in region %s", s.Region)
}

if len(identifiers) > 100 {
Expand All @@ -146,13 +135,13 @@ func nukeAllSNSTopics(session *session.Session, identifiers []*string) error {
}

// There is no bulk delete SNS API, so we delete the batch of SNS Topics concurrently using goroutines
logging.Logger.Debugf("Deleting SNS Topics in region %s", region)
logging.Logger.Debugf("Deleting SNS Topics in region %s", s.Region)
wg := new(sync.WaitGroup)
wg.Add(len(identifiers))
errChans := make([]chan error, len(identifiers))
for i, topicArn := range identifiers {
errChans[i] = make(chan error, 1)
go deleteSNSTopicAsync(wg, errChans[i], svc, topicArn, region)
go s.deleteAsync(wg, errChans[i], topicArn)
}
wg.Wait()

Expand All @@ -164,7 +153,7 @@ func nukeAllSNSTopics(session *session.Session, identifiers []*string) error {
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking SNS Topic",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": s.Region,
})
}
}
Expand All @@ -175,16 +164,16 @@ func nukeAllSNSTopics(session *session.Session, identifiers []*string) error {
return nil
}

func deleteSNSTopicAsync(wg *sync.WaitGroup, errChan chan error, svc *sns.Client, topicArn *string, region string) {
func (s SNSTopic) deleteAsync(wg *sync.WaitGroup, errChan chan error, topicArn *string) {
defer wg.Done()

deleteParam := &sns.DeleteTopicInput{
TopicArn: topicArn,
}

logging.Logger.Debugf("Deleting SNS Topic (arn=%s) in region: %s", aws.StringValue(topicArn), region)
logging.Logger.Debugf("Deleting SNS Topic (arn=%s) in region: %s", aws.StringValue(topicArn), s.Region)

_, err := svc.DeleteTopic(context.TODO(), deleteParam)
_, err := s.Client.DeleteTopic(deleteParam)

errChan <- err

Expand All @@ -197,8 +186,8 @@ func deleteSNSTopicAsync(wg *sync.WaitGroup, errChan chan error, svc *sns.Client
report.Record(e)

if err == nil {
logging.Logger.Debugf("[OK] Deleted SNS Topic (arn=%s) in region: %s", aws.StringValue(topicArn), region)
logging.Logger.Debugf("[OK] Deleted SNS Topic (arn=%s) in region: %s", aws.StringValue(topicArn), s.Region)
} else {
logging.Logger.Debugf("[Failed] Error deleting SNS Topic (arn=%s) in %s", aws.StringValue(topicArn), region)
logging.Logger.Debugf("[Failed] Error deleting SNS Topic (arn=%s) in %s", aws.StringValue(topicArn), s.Region)
}
}
Loading

0 comments on commit 4723121

Please sign in to comment.