diff --git a/aws/aws.go b/aws/aws.go index 6f2f9836..1f1dc9d7 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -578,7 +578,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(natGateways.ResourceName(), resourceTypes) { start := time.Now() - ngwIDs, err := getAllNatGateways(cloudNukeSession, excludeAfter, configObj) + ngwIDs, err := natGateways.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/nat_gateway.go b/aws/nat_gateway.go index 5ac5277e..de8e0873 100644 --- a/aws/nat_gateway.go +++ b/aws/nat_gateway.go @@ -9,27 +9,23 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" - "github.com/gruntwork-io/go-commons/errors" - "github.com/gruntwork-io/go-commons/retry" - multierror "github.com/hashicorp/go-multierror" - "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/go-commons/errors" + "github.com/gruntwork-io/go-commons/retry" + "github.com/hashicorp/go-multierror" ) -func getAllNatGateways(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - svc := ec2.New(session) - +func (ngw NatGateways) getAll(configObj config.Config) ([]*string, error) { allNatGateways := []*string{} input := &ec2.DescribeNatGatewaysInput{} - err := svc.DescribeNatGatewaysPages( + err := ngw.Client.DescribeNatGatewaysPages( input, func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool { for _, ngw := range page.NatGateways { - if shouldIncludeNatGateway(ngw, excludeAfter, configObj) { + if shouldIncludeNatGateway(ngw, configObj) { allNatGateways = append(allNatGateways, ngw.NatGatewayId) } } @@ -40,7 +36,7 @@ func getAllNatGateways(session *session.Session, excludeAfter time.Time, configO return allNatGateways, errors.WithStackTrace(err) } -func shouldIncludeNatGateway(ngw *ec2.NatGateway, excludeAfter time.Time, configObj config.Config) bool { +func shouldIncludeNatGateway(ngw *ec2.NatGateway, configObj config.Config) bool { if ngw == nil { return false } @@ -50,37 +46,25 @@ func shouldIncludeNatGateway(ngw *ec2.NatGateway, excludeAfter time.Time, config return false } - if ngw.CreateTime != nil && excludeAfter.Before(*ngw.CreateTime) { - return false - } - - if aws.StringValue(ngw.State) == ec2.NatGatewayStateDeleted { - return false - } - - return config.ShouldInclude( - getNatGatewayName(ngw), - configObj.NatGateway.IncludeRule.NamesRegExp, - configObj.NatGateway.ExcludeRule.NamesRegExp, - ) + return configObj.NatGateway.ShouldInclude(config.ResourceValue{ + Time: ngw.CreateTime, + Name: getNatGatewayName(ngw), + }) } -func getNatGatewayName(ngw *ec2.NatGateway) string { +func getNatGatewayName(ngw *ec2.NatGateway) *string { for _, tag := range ngw.Tags { if aws.StringValue(tag.Key) == "Name" { - return aws.StringValue(tag.Value) + return tag.Value } } - return "" -} - -func nukeAllNatGateways(session *session.Session, identifiers []*string) error { - region := aws.StringValue(session.Config.Region) - svc := ec2.New(session) + return nil +} +func (ngw NatGateways) nukeAll(identifiers []*string) error { if len(identifiers) == 0 { - logging.Logger.Debugf("No Nat Gateways to nuke in region %s", region) + logging.Logger.Debugf("No Nat Gateways to nuke in region %s", ngw.Region) return nil } @@ -94,13 +78,13 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error { } // There is no bulk delete nat gateway API, so we delete the batch of nat gateways concurrently using go routines. - logging.Logger.Debugf("Deleting Nat Gateways in region %s", region) + logging.Logger.Debugf("Deleting Nat Gateways in region %s", ngw.Region) wg := new(sync.WaitGroup) wg.Add(len(identifiers)) errChans := make([]chan error, len(identifiers)) for i, ngwID := range identifiers { errChans[i] = make(chan error, 1) - go deleteNatGatewayAsync(wg, errChans[i], svc, ngwID) + go ngw.deleteAsync(wg, errChans[i], ngwID) } wg.Wait() @@ -113,7 +97,7 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking NAT Gateway", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": ngw.Region, }) } } @@ -129,7 +113,7 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error { // Wait a maximum of 5 minutes: 10 seconds in between, up to 30 times 30, 10*time.Second, func() error { - areDeleted, err := areAllNatGatewaysDeleted(svc, identifiers) + areDeleted, err := ngw.areAllNatGatewaysDeleted(identifiers) if err != nil { return errors.WithStackTrace(retry.FatalError{Underlying: err}) } @@ -143,7 +127,7 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error { return errors.WithStackTrace(err) } for _, ngwID := range identifiers { - logging.Logger.Debugf("[OK] NAT Gateway %s was deleted in %s", aws.StringValue(ngwID), region) + logging.Logger.Debugf("[OK] NAT Gateway %s was deleted in %s", aws.StringValue(ngwID), ngw.Region) } return nil } @@ -151,10 +135,10 @@ func nukeAllNatGateways(session *session.Session, identifiers []*string) error { // areAllNatGatewaysDeleted returns true if all the requested NAT gateways have been deleted. This is determined by // querying for the statuses of all the NAT gateways, and checking if AWS knows about them (if not, the NAT gateway was // deleted and rolled off AWS DB) or if the status was updated to deleted. -func areAllNatGatewaysDeleted(svc *ec2.EC2, identifiers []*string) (bool, error) { +func (ngw NatGateways) areAllNatGatewaysDeleted(identifiers []*string) (bool, error) { // NOTE: we don't need to do pagination here, because the pagination is handled by the caller to this function, // based on NatGateways.MaxBatchSize. - resp, err := svc.DescribeNatGateways(&ec2.DescribeNatGatewaysInput{NatGatewayIds: identifiers}) + resp, err := ngw.Client.DescribeNatGateways(&ec2.DescribeNatGatewaysInput{NatGatewayIds: identifiers}) if err != nil { if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NatGatewayNotFound" { return true, nil @@ -179,11 +163,11 @@ func areAllNatGatewaysDeleted(svc *ec2.EC2, identifiers []*string) (bool, error) // deleteNatGatewaysAsync deletes the provided NAT Gateway asynchronously in a goroutine, using wait groups for // concurrency control and a return channel for errors. -func deleteNatGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *ec2.EC2, ngwID *string) { +func (ngw NatGateways) deleteAsync(wg *sync.WaitGroup, errChan chan error, ngwID *string) { defer wg.Done() input := &ec2.DeleteNatGatewayInput{NatGatewayId: ngwID} - _, err := svc.DeleteNatGateway(input) + _, err := ngw.Client.DeleteNatGateway(input) // Record status of this resource e := report.Entry{ diff --git a/aws/nat_gateway_test.go b/aws/nat_gateway_test.go index 0bc711d9..e544d42a 100644 --- a/aws/nat_gateway_test.go +++ b/aws/nat_gateway_test.go @@ -1,215 +1,124 @@ package aws import ( + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/gruntwork-io/cloud-nuke/telemetry" "regexp" "testing" "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/go-commons/errors" - terraws "github.com/gruntwork-io/terratest/modules/aws" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestListNatGateways(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - svc := ec2.New(session) - - ngwID := createNatGateway(t, svc, region) - defer deleteNatGateway(t, svc, ngwID, true) - - natGatewayIDs, err := getAllNatGateways(session, time.Now(), config.Config{}) - require.NoError(t, err) - assert.Contains(t, aws.StringValueSlice(natGatewayIDs), aws.StringValue(ngwID)) +type mockedNatGateway struct { + ec2iface.EC2API + DeleteNatGatewayOutput ec2.DeleteNatGatewayOutput + DescribeNatGatewaysOutput ec2.DescribeNatGatewaysOutput + DescribeNatGatewaysError error } -func createNatGatewayWithName(t *testing.T, svc *ec2.EC2, region string, name string) *string { - ngwID := createNatGateway(t, svc, region) +func (m mockedNatGateway) DeleteNatGateway(input *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) { + return &m.DeleteNatGatewayOutput, nil +} - err := setTagsToResource(t, svc, ngwID, []*ec2.Tag{ - { - Key: aws.String("Name"), - Value: aws.String(name), - }, - }) - require.NoError(t, err) +func (m mockedNatGateway) DescribeNatGatewaysPages(input *ec2.DescribeNatGatewaysInput, fn func(*ec2.DescribeNatGatewaysOutput, bool) bool) error { + fn(&m.DescribeNatGatewaysOutput, true) + return nil +} - return ngwID +func (m mockedNatGateway) DescribeNatGateways(input *ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) { + return &m.DescribeNatGatewaysOutput, m.DescribeNatGatewaysError } -func TestListNatGatewaysWithConfigFile(t *testing.T) { +func TestNatGateway_GetAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - svc := ec2.New(session) - - includedNatGatewayName := "cloud-nuke-test-include-" + util.UniqueID() - excludedNatGatewayName := "cloud-nuke-test-" + util.UniqueID() - includedNatGatewayID := createNatGatewayWithName(t, svc, region, includedNatGatewayName) - excludedNatGatewayID := createNatGatewayWithName(t, svc, region, excludedNatGatewayName) - defer nukeAllNatGateways(session, []*string{includedNatGatewayID, excludedNatGatewayID}) - - natGatewayIds, err := getAllNatGateways(session, time.Now().Add(1*time.Hour), config.Config{ - NatGateway: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - {RE: *regexp.MustCompile("^cloud-nuke-test-include-.*")}, + testId1 := "test-nat-gateway-id1" + testId2 := "test-nat-gateway-id2" + testName1 := "test-nat-gateway-1" + testName2 := "test-nat-gateway-2" + now := time.Now() + ng := NatGateways{ + Client: mockedNatGateway{ + DescribeNatGatewaysOutput: ec2.DescribeNatGatewaysOutput{ + NatGateways: []*ec2.NatGateway{ + { + NatGatewayId: aws.String(testId1), + Tags: []*ec2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName1), + }, + }, + CreateTime: aws.Time(now), + }, + { + NatGatewayId: aws.String(testId2), + Tags: []*ec2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(testName2), + }, + }, + CreateTime: aws.Time(now.Add(1)), + }, }, }, }, - }) - - require.NoError(t, err) - require.Equal(t, 1, len(natGatewayIds)) - require.Equal(t, aws.StringValue(includedNatGatewayID), aws.StringValue(natGatewayIds[0])) -} - -func TestTimeFilterExclusionNewlyCreatedNatGateway(t *testing.T) { - t.Parallel() - - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - svc := ec2.New(session) - - // Creates a NGW - ngwID := createNatGateway(t, svc, region) - defer deleteNatGateway(t, svc, ngwID, true) - - // Assert NGW is picked up without filters - natGatewayIDsNewer, err := getAllNatGateways(session, time.Now(), config.Config{}) - require.NoError(t, err) - assert.Contains(t, aws.StringValueSlice(natGatewayIDsNewer), aws.StringValue(ngwID)) - - // Assert user doesn't appear when we look at users older than 1 Hour - olderThan := time.Now().Add(-1 * time.Hour) - natGatewayIDsOlder, err := getAllNatGateways(session, olderThan, config.Config{}) - require.NoError(t, err) - assert.NotContains(t, aws.StringValueSlice(natGatewayIDsOlder), aws.StringValue(ngwID)) -} - -func TestNukeNatGatewayOne(t *testing.T) { - t.Parallel() - - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - svc := ec2.New(session) - - // We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. - ngwID := createNatGateway(t, svc, region) - defer deleteNatGateway(t, svc, ngwID, false) - identifiers := []*string{ngwID} - - require.NoError( - t, - nukeAllNatGateways(session, identifiers), - ) - - // Make sure the NAT gateway is deleted. - assertNatGatewaysDeleted(t, svc, identifiers) -} - -func TestNukeNatGatewayMoreThanOne(t *testing.T) { - t.Parallel() - - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - svc := ec2.New(session) - - natGateways := []*string{} - for i := 0; i < 3; i++ { - // We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. - ngwID := createNatGateway(t, svc, region) - defer deleteNatGateway(t, svc, ngwID, false) - natGateways = append(natGateways, ngwID) } - require.NoError( - t, - nukeAllNatGateways(session, natGateways), - ) - - // Make sure the NAT Gateway is deleted. - assertNatGatewaysDeleted(t, svc, natGateways) -} - -// Helper functions for driving the NAT gateway tests -func setTagsToResource(t *testing.T, svc *ec2.EC2, resourceId *string, tags []*ec2.Tag) error { - _, err := svc.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{resourceId}, - Tags: tags, - }) - return err -} - -// createNatGateway will create a new NAT gateway in the default VPC -func createNatGateway(t *testing.T, svc *ec2.EC2, region string) *string { - defaultVpc := terraws.GetDefaultVpc(t, region) - subnet := defaultVpc.Subnets[0] - - resp, err := svc.CreateNatGateway(&ec2.CreateNatGatewayInput{ - SubnetId: aws.String(subnet.Id), - ConnectivityType: aws.String(ec2.ConnectivityTypePrivate), - }) - if err != nil { - assert.Failf(t, "Could not create test NAT gateways", errors.WithStackTrace(err).Error()) + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testId1, testId2}, + }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testId2}, + }, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }}, + expected: []string{testId1}, + }, } - if resp.NatGateway == nil { - t.Fatalf("Impossible error: AWS returned nil NAT gateway") + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := ng.getAll(config.Config{ + NatGateway: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) + }) } - - return resp.NatGateway.NatGatewayId } -// deleteNatGateway is a function to delete the given NAT gateway. -func deleteNatGateway(t *testing.T, svc *ec2.EC2, ngwID *string, checkErr bool) { - input := &ec2.DeleteNatGatewayInput{NatGatewayId: ngwID} - _, err := svc.DeleteNatGateway(input) - if checkErr { - require.NoError(t, err) +func TestNatGateway_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + ngw := NatGateways{ + Client: mockedNatGateway{ + DeleteNatGatewayOutput: ec2.DeleteNatGatewayOutput{}, + DescribeNatGatewaysError: awserr.New("NatGatewayNotFound", "", nil), + }, } -} -func assertNatGatewaysDeleted(t *testing.T, svc *ec2.EC2, identifiers []*string) { - resp, err := svc.DescribeNatGateways(&ec2.DescribeNatGatewaysInput{NatGatewayIds: identifiers}) + err := ngw.nukeAll([]*string{aws.String("test")}) require.NoError(t, err) - if len(resp.NatGateways) == 0 { - return - } - if len(resp.NatGateways) > len(identifiers) { - t.Fatalf("More than expected %d NAT gateway (found %d) for query", len(identifiers), len(resp.NatGateways)) - } - for _, ngw := range resp.NatGateways { - if ngw == nil { - continue - } - if aws.StringValue(ngw.State) != ec2.NatGatewayStateDeleted { - t.Fatalf("NAT Gateway not deleted by nuke operation") - } - } } diff --git a/aws/nat_gateway_types.go b/aws/nat_gateway_types.go index d8d62059..72ac7438 100644 --- a/aws/nat_gateway_types.go +++ b/aws/nat_gateway_types.go @@ -33,7 +33,7 @@ func (secret NatGateways) MaxBatchSize() int { // Nuke - nuke 'em all!!! func (ngw NatGateways) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllNatGateways(session, awsgo.StringSlice(identifiers)); err != nil { + if err := ngw.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) }