Skip to content

Commit

Permalink
Refactor Transite Gateway Resource Types (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 authored Jul 26, 2023
1 parent 1c22113 commit 74ab0b4
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 327 deletions.
21 changes: 6 additions & 15 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,9 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
Client: ec2.New(cloudNukeSession),
Region: region,
}
transitGatewayIsAvailable, err := tgIsAvailableInRegion(cloudNukeSession, region)
if err != nil {
ge := report.GeneralError{
Error: err,
Description: "Unable to retrieve Transit Gateways",
ResourceType: transitGatewayVpcAttachments.ResourceName(),
}
report.RecordError(ge)
}
if IsNukeable(transitGatewayVpcAttachments.ResourceName(), resourceTypes) && transitGatewayIsAvailable {
if IsNukeable(transitGatewayVpcAttachments.ResourceName(), resourceTypes) {
start := time.Now()
transitGatewayVpcAttachmentIds, err := getAllTransitGatewayVpcAttachments(cloudNukeSession, region, excludeAfter)
transitGatewayVpcAttachmentIds, err := transitGatewayVpcAttachments.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down Expand Up @@ -521,9 +512,9 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
Client: ec2.New(cloudNukeSession),
Region: region,
}
if IsNukeable(transitGatewayRouteTables.ResourceName(), resourceTypes) && transitGatewayIsAvailable {
if IsNukeable(transitGatewayRouteTables.ResourceName(), resourceTypes) {
start := time.Now()
transitGatewayRouteTableIds, err := getAllTransitGatewayRouteTables(cloudNukeSession, region, excludeAfter)
transitGatewayRouteTableIds, err := transitGatewayRouteTables.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down Expand Up @@ -551,9 +542,9 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
Client: ec2.New(cloudNukeSession),
Region: region,
}
if IsNukeable(transitGateways.ResourceName(), resourceTypes) && transitGatewayIsAvailable {
if IsNukeable(transitGateways.ResourceName(), resourceTypes) {
start := time.Now()
transitGatewayIds, err := getAllTransitGatewayInstances(cloudNukeSession, region, excludeAfter)
transitGatewayIds, err := transitGateways.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
94 changes: 31 additions & 63 deletions aws/transit_gateway.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/gruntwork-io/go-commons/errors"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
)

func sleepWithMessage(duration time.Duration, whySleepMessage string) {
logging.Logger.Debugf("Sleeping %v: %s", duration, whySleepMessage)
time.Sleep(duration)
}

// Returns a formatted string of TransitGateway IDs
func getAllTransitGatewayInstances(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)
result, err := svc.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{})
func (tgw TransitGateways) getAll(configObj config.Config) ([]*string, error) {
result, err := tgw.Client.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, transitGateway := range result.TransitGateways {
if excludeAfter.After(*transitGateway.CreationTime) && awsgo.StringValue(transitGateway.State) != "deleted" && awsgo.StringValue(transitGateway.State) != "deleting" {
if configObj.TransitGateway.ShouldInclude(config.ResourceValue{Time: transitGateway.CreationTime}) &&
awsgo.StringValue(transitGateway.State) != "deleted" && awsgo.StringValue(transitGateway.State) != "deleting" {
ids = append(ids, transitGateway.TransitGatewayId)
}
}
Expand All @@ -39,23 +31,21 @@ func getAllTransitGatewayInstances(session *session.Session, region string, excl
}

// Delete all TransitGateways
func nukeAllTransitGatewayInstances(session *session.Session, ids []*string) error {
svc := ec2.New(session)

func (tgw TransitGateways) nukeAll(ids []*string) error {
if len(ids) == 0 {
logging.Logger.Debugf("No Transit Gateways to nuke in region %s", *session.Config.Region)
logging.Logger.Debugf("No Transit Gateways to nuke in region %s", tgw.Region)
return nil
}

logging.Logger.Debugf("Deleting all Transit Gateways in region %s", *session.Config.Region)
logging.Logger.Debugf("Deleting all Transit Gateways in region %s", tgw.Region)
var deletedIds []*string

for _, id := range ids {
params := &ec2.DeleteTransitGatewayInput{
TransitGatewayId: id,
}

_, err := svc.DeleteTransitGateway(params)
_, err := tgw.Client.DeleteTransitGateway(params)

// Record status of this resource
e := report.Entry{
Expand All @@ -70,22 +60,20 @@ func nukeAllTransitGatewayInstances(session *session.Session, ids []*string) err
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking Transit Gateway Instance",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": tgw.Region,
})
} else {
deletedIds = append(deletedIds, id)
logging.Logger.Debugf("Deleted Transit Gateway: %s", *id)
}
}

logging.Logger.Debugf("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), *session.Config.Region)
logging.Logger.Debugf("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), tgw.Region)
return nil
}

// Returns a formatted string of TranstGatewayRouteTable IDs
func getAllTransitGatewayRouteTables(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)

func (tgw TransitGatewaysRouteTables) getAll(configObj config.Config) ([]*string, error) {
// Remove defalt route table, that will be deleted along with its TransitGateway
param := &ec2.DescribeTransitGatewayRouteTablesInput{
Filters: []*ec2.Filter{
Expand All @@ -98,14 +86,15 @@ func getAllTransitGatewayRouteTables(session *session.Session, region string, ex
},
}

result, err := svc.DescribeTransitGatewayRouteTables(param)
result, err := tgw.Client.DescribeTransitGatewayRouteTables(param)
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, transitGatewayRouteTable := range result.TransitGatewayRouteTables {
if excludeAfter.After(*transitGatewayRouteTable.CreationTime) && awsgo.StringValue(transitGatewayRouteTable.State) != "deleted" && awsgo.StringValue(transitGatewayRouteTable.State) != "deleting" {
if configObj.TransitGatewayRouteTable.ShouldInclude(config.ResourceValue{Time: transitGatewayRouteTable.CreationTime}) &&
awsgo.StringValue(transitGatewayRouteTable.State) != "deleted" && awsgo.StringValue(transitGatewayRouteTable.State) != "deleting" {
ids = append(ids, transitGatewayRouteTable.TransitGatewayRouteTableId)
}
}
Expand All @@ -114,23 +103,21 @@ func getAllTransitGatewayRouteTables(session *session.Session, region string, ex
}

// Delete all TransitGatewayRouteTables
func nukeAllTransitGatewayRouteTables(session *session.Session, ids []*string) error {
svc := ec2.New(session)

func (tgw TransitGatewaysRouteTables) nukeAll(ids []*string) error {
if len(ids) == 0 {
logging.Logger.Debugf("No Transit Gateway Route Tables to nuke in region %s", *session.Config.Region)
logging.Logger.Debugf("No Transit Gateway Route Tables to nuke in region %s", tgw.Region)
return nil
}

logging.Logger.Debugf("Deleting all Transit Gateway Route Tables in region %s", *session.Config.Region)
logging.Logger.Debugf("Deleting all Transit Gateway Route Tables in region %s", tgw.Region)
var deletedIds []*string

for _, id := range ids {
param := &ec2.DeleteTransitGatewayRouteTableInput{
TransitGatewayRouteTableId: id,
}

_, err := svc.DeleteTransitGatewayRouteTable(param)
_, err := tgw.Client.DeleteTransitGatewayRouteTable(param)
if err != nil {
logging.Logger.Debugf("[Failed] %s", err)
} else {
Expand All @@ -139,21 +126,21 @@ func nukeAllTransitGatewayRouteTables(session *session.Session, ids []*string) e
}
}

logging.Logger.Debugf("[OK] %d Transit Gateway Route Table(s) deleted in %s", len(deletedIds), *session.Config.Region)
logging.Logger.Debugf("[OK] %d Transit Gateway Route Table(s) deleted in %s", len(deletedIds), tgw.Region)
return nil
}

// Returns a formated string of TransitGatewayVpcAttachment IDs
func getAllTransitGatewayVpcAttachments(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)
result, err := svc.DescribeTransitGatewayVpcAttachments(&ec2.DescribeTransitGatewayVpcAttachmentsInput{})
func (tgw TransitGatewaysVpcAttachment) getAll(configObj config.Config) ([]*string, error) {
result, err := tgw.Client.DescribeTransitGatewayVpcAttachments(&ec2.DescribeTransitGatewayVpcAttachmentsInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, tgwVpcAttachment := range result.TransitGatewayVpcAttachments {
if excludeAfter.After(*tgwVpcAttachment.CreationTime) && awsgo.StringValue(tgwVpcAttachment.State) != "deleted" && awsgo.StringValue(tgwVpcAttachment.State) != "deleting" {
if configObj.TransitGatewaysVpcAttachment.ShouldInclude(config.ResourceValue{Time: tgwVpcAttachment.CreationTime}) &&
awsgo.StringValue(tgwVpcAttachment.State) != "deleted" && awsgo.StringValue(tgwVpcAttachment.State) != "deleting" {
ids = append(ids, tgwVpcAttachment.TransitGatewayAttachmentId)
}
}
Expand All @@ -162,23 +149,21 @@ func getAllTransitGatewayVpcAttachments(session *session.Session, region string,
}

// Delete all TransitGatewayVpcAttachments
func nukeAllTransitGatewayVpcAttachments(session *session.Session, ids []*string) error {
svc := ec2.New(session)

func (tgw TransitGatewaysVpcAttachment) nukeAll(ids []*string) error {
if len(ids) == 0 {
logging.Logger.Debugf("No Transit Gateway Vpc Attachments to nuke in region %s", *session.Config.Region)
logging.Logger.Debugf("No Transit Gateway Vpc Attachments to nuke in region %s", tgw.Region)
return nil
}

logging.Logger.Debugf("Deleting all Transit Gateway Vpc Attachments in region %s", *session.Config.Region)
logging.Logger.Debugf("Deleting all Transit Gateway Vpc Attachments in region %s", tgw.Region)
var deletedIds []*string

for _, id := range ids {
param := &ec2.DeleteTransitGatewayVpcAttachmentInput{
TransitGatewayAttachmentId: id,
}

_, err := svc.DeleteTransitGatewayVpcAttachment(param)
_, err := tgw.Client.DeleteTransitGatewayVpcAttachment(param)

// Record status of this resource
e := report.Entry{
Expand All @@ -196,23 +181,6 @@ func nukeAllTransitGatewayVpcAttachments(session *session.Session, ids []*string
}
}

sleepMessage := "TransitGateway Vpc Attachments takes some time to create, and since there is no waiter available, we sleep instead."
sleepFor := 180 * time.Second
sleepWithMessage(sleepFor, sleepMessage)

logging.Logger.Debugf(("[OK] %d Transit Gateway Vpc Attachment(s) deleted in %s"), len(deletedIds), *session.Config.Region)
logging.Logger.Debugf(("[OK] %d Transit Gateway Vpc Attachment(s) deleted in %s"), len(deletedIds), tgw.Region)
return nil
}

func tgIsAvailableInRegion(session *session.Session, region string) (bool, error) {
svc := ec2.New(session)
_, err := svc.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{})
if err != nil {
if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == "InvalidAction" {
return false, nil
} else {
return false, errors.WithStackTrace(err)
}
}
return true, nil
}
Loading

0 comments on commit 74ab0b4

Please sign in to comment.