Skip to content

Commit

Permalink
Refactor nat gateway (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 authored Aug 1, 2023
1 parent f56fe75 commit f612dd3
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 224 deletions.
2 changes: 1 addition & 1 deletion aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
}
if IsNukeable(natGateways.ResourceName(), resourceTypes) {
start := time.Now()
ngwIDs, err := getAllNatGateways(cloudNukeSession, excludeAfter, configObj)
ngwIDs, err := natGateways.getAll(configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
68 changes: 26 additions & 42 deletions aws/nat_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,23 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/go-commons/errors"
"github.com/gruntwork-io/go-commons/retry"
multierror "github.com/hashicorp/go-multierror"

"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/go-commons/errors"
"github.com/gruntwork-io/go-commons/retry"
"github.com/hashicorp/go-multierror"
)

func getAllNatGateways(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) {
svc := ec2.New(session)

func (ngw NatGateways) getAll(configObj config.Config) ([]*string, error) {
allNatGateways := []*string{}
input := &ec2.DescribeNatGatewaysInput{}
err := svc.DescribeNatGatewaysPages(
err := ngw.Client.DescribeNatGatewaysPages(
input,
func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool {
for _, ngw := range page.NatGateways {
if shouldIncludeNatGateway(ngw, excludeAfter, configObj) {
if shouldIncludeNatGateway(ngw, configObj) {
allNatGateways = append(allNatGateways, ngw.NatGatewayId)
}
}
Expand All @@ -40,7 +36,7 @@ func getAllNatGateways(session *session.Session, excludeAfter time.Time, configO
return allNatGateways, errors.WithStackTrace(err)
}

func shouldIncludeNatGateway(ngw *ec2.NatGateway, excludeAfter time.Time, configObj config.Config) bool {
func shouldIncludeNatGateway(ngw *ec2.NatGateway, configObj config.Config) bool {
if ngw == nil {
return false
}
Expand All @@ -50,37 +46,25 @@ func shouldIncludeNatGateway(ngw *ec2.NatGateway, excludeAfter time.Time, config
return false
}

if ngw.CreateTime != nil && excludeAfter.Before(*ngw.CreateTime) {
return false
}

if aws.StringValue(ngw.State) == ec2.NatGatewayStateDeleted {
return false
}

return config.ShouldInclude(
getNatGatewayName(ngw),
configObj.NatGateway.IncludeRule.NamesRegExp,
configObj.NatGateway.ExcludeRule.NamesRegExp,
)
return configObj.NatGateway.ShouldInclude(config.ResourceValue{
Time: ngw.CreateTime,
Name: getNatGatewayName(ngw),
})
}

func getNatGatewayName(ngw *ec2.NatGateway) string {
func getNatGatewayName(ngw *ec2.NatGateway) *string {
for _, tag := range ngw.Tags {
if aws.StringValue(tag.Key) == "Name" {
return aws.StringValue(tag.Value)
return tag.Value
}
}
return ""
}

func nukeAllNatGateways(session *session.Session, identifiers []*string) error {
region := aws.StringValue(session.Config.Region)

svc := ec2.New(session)
return nil
}

func (ngw NatGateways) nukeAll(identifiers []*string) error {
if len(identifiers) == 0 {
logging.Logger.Debugf("No Nat Gateways to nuke in region %s", region)
logging.Logger.Debugf("No Nat Gateways to nuke in region %s", ngw.Region)
return nil
}

Expand All @@ -94,13 +78,13 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error {
}

// There is no bulk delete nat gateway API, so we delete the batch of nat gateways concurrently using go routines.
logging.Logger.Debugf("Deleting Nat Gateways in region %s", region)
logging.Logger.Debugf("Deleting Nat Gateways in region %s", ngw.Region)
wg := new(sync.WaitGroup)
wg.Add(len(identifiers))
errChans := make([]chan error, len(identifiers))
for i, ngwID := range identifiers {
errChans[i] = make(chan error, 1)
go deleteNatGatewayAsync(wg, errChans[i], svc, ngwID)
go ngw.deleteAsync(wg, errChans[i], ngwID)
}
wg.Wait()

Expand All @@ -113,7 +97,7 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error {
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking NAT Gateway",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": ngw.Region,
})
}
}
Expand All @@ -129,7 +113,7 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error {
// Wait a maximum of 5 minutes: 10 seconds in between, up to 30 times
30, 10*time.Second,
func() error {
areDeleted, err := areAllNatGatewaysDeleted(svc, identifiers)
areDeleted, err := ngw.areAllNatGatewaysDeleted(identifiers)
if err != nil {
return errors.WithStackTrace(retry.FatalError{Underlying: err})
}
Expand All @@ -143,18 +127,18 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error {
return errors.WithStackTrace(err)
}
for _, ngwID := range identifiers {
logging.Logger.Debugf("[OK] NAT Gateway %s was deleted in %s", aws.StringValue(ngwID), region)
logging.Logger.Debugf("[OK] NAT Gateway %s was deleted in %s", aws.StringValue(ngwID), ngw.Region)
}
return nil
}

// areAllNatGatewaysDeleted returns true if all the requested NAT gateways have been deleted. This is determined by
// querying for the statuses of all the NAT gateways, and checking if AWS knows about them (if not, the NAT gateway was
// deleted and rolled off AWS DB) or if the status was updated to deleted.
func areAllNatGatewaysDeleted(svc *ec2.EC2, identifiers []*string) (bool, error) {
func (ngw NatGateways) areAllNatGatewaysDeleted(identifiers []*string) (bool, error) {
// NOTE: we don't need to do pagination here, because the pagination is handled by the caller to this function,
// based on NatGateways.MaxBatchSize.
resp, err := svc.DescribeNatGateways(&ec2.DescribeNatGatewaysInput{NatGatewayIds: identifiers})
resp, err := ngw.Client.DescribeNatGateways(&ec2.DescribeNatGatewaysInput{NatGatewayIds: identifiers})
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NatGatewayNotFound" {
return true, nil
Expand All @@ -179,11 +163,11 @@ func areAllNatGatewaysDeleted(svc *ec2.EC2, identifiers []*string) (bool, error)

// deleteNatGatewaysAsync deletes the provided NAT Gateway asynchronously in a goroutine, using wait groups for
// concurrency control and a return channel for errors.
func deleteNatGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *ec2.EC2, ngwID *string) {
func (ngw NatGateways) deleteAsync(wg *sync.WaitGroup, errChan chan error, ngwID *string) {
defer wg.Done()

input := &ec2.DeleteNatGatewayInput{NatGatewayId: ngwID}
_, err := svc.DeleteNatGateway(input)
_, err := ngw.Client.DeleteNatGateway(input)

// Record status of this resource
e := report.Entry{
Expand Down
Loading

0 comments on commit f612dd3

Please sign in to comment.