From 86026035226fcde765ffeafdf97df0d02dae333c Mon Sep 17 00:00:00 2001 From: James Kwon Date: Wed, 9 Aug 2023 18:24:19 -0400 Subject: [PATCH] Refactor EC2 instance, clusters, and services --- aws/aws.go | 29 +-- aws/ec2.go | 54 ++--- aws/ec2_test.go | 423 ++++++++++---------------------------- aws/ec2_types.go | 12 +- aws/ec2_unit_test.go | 323 ----------------------------- aws/ec2_utils_for_test.go | 42 ---- aws/ecs_cluster.go | 96 +++++---- aws/ecs_cluster_test.go | 266 +++++++++--------------- aws/ecs_cluster_types.go | 2 +- aws/ecs_service.go | 110 +++++----- aws/ecs_service_test.go | 384 ++++++++++------------------------ aws/ecs_service_types.go | 2 +- aws/ecs_utils_for_test.go | 395 ----------------------------------- 13 files changed, 440 insertions(+), 1698 deletions(-) delete mode 100644 aws/ec2_unit_test.go delete mode 100644 aws/ec2_utils_for_test.go delete mode 100644 aws/ecs_utils_for_test.go diff --git a/aws/aws.go b/aws/aws.go index 8347bea1..4b487484 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -638,7 +638,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(ec2Instances.ResourceName(), resourceTypes) { start := time.Now() - instanceIds, err := getAllEc2Instances(cloudNukeSession, region, excludeAfter, configObj) + instanceIds, err := ec2Instances.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, @@ -819,29 +819,20 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(ecsServices.ResourceName(), resourceTypes) { start := time.Now() - clusterArns, err := getAllEcsClusters(cloudNukeSession) + + serviceArns, err := ecsServices.getAll(configObj) if err != nil { - ge := report.GeneralError{ - Error: err, - Description: "Unable to retrieve ECS clusters", - ResourceType: ecsServices.ResourceName(), - } - report.RecordError(ge) - } - if len(clusterArns) > 0 { - serviceArns, serviceClusterMap, err := getAllEcsServices(cloudNukeSession, clusterArns, excludeAfter, configObj) - if err != nil { - return nil, errors.WithStackTrace(err) - } - ecsServices.Services = awsgo.StringValueSlice(serviceArns) - ecsServices.ServiceClusterMap = serviceClusterMap - resourcesInRegion.Resources = append(resourcesInRegion.Resources, ecsServices) + return nil, errors.WithStackTrace(err) } + + ecsServices.Services = awsgo.StringValueSlice(serviceArns) + resourcesInRegion.Resources = append(resourcesInRegion.Resources, ecsServices) + telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Done Listing ECS Services", }, map[string]interface{}{ "region": region, - "recordCount": len(clusterArns), + "recordCount": len(serviceArns), "actionTime": time.Since(start).Seconds(), }) } @@ -852,7 +843,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(ecsClusters.ResourceName(), resourceTypes) { start := time.Now() - ecsClusterArns, err := getAllEcsClustersOlderThan(cloudNukeSession, excludeAfter, configObj) + ecsClusterArns, err := ecsClusters.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/ec2.go b/aws/ec2.go index 150371fd..258b524c 100644 --- a/aws/ec2.go +++ b/aws/ec2.go @@ -11,7 +11,6 @@ import ( awsgo "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/gruntwork-io/cloud-nuke/config" @@ -21,20 +20,21 @@ import ( ) // returns only instance Ids of unprotected ec2 instances -func filterOutProtectedInstances(svc *ec2.EC2, output *ec2.DescribeInstancesOutput, excludeAfter time.Time, configObj config.Config) ([]*string, error) { +func (ei EC2Instances) filterOutProtectedInstances(output *ec2.DescribeInstancesOutput, configObj config.Config) ([]*string, error) { var filteredIds []*string for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { instanceID := *instance.InstanceId - attr, err := svc.DescribeInstanceAttribute(&ec2.DescribeInstanceAttributeInput{ + attr, err := ei.Client.DescribeInstanceAttribute(&ec2.DescribeInstanceAttributeInput{ Attribute: awsgo.String("disableApiTermination"), InstanceId: awsgo.String(instanceID), }) if err != nil { return nil, errors.WithStackTrace(err) } - if shouldIncludeInstanceId(instance, excludeAfter, *attr.DisableApiTermination.Value, configObj) { + + if shouldIncludeInstanceId(instance, *attr.DisableApiTermination.Value, configObj) { filteredIds = append(filteredIds, &instanceID) } } @@ -44,9 +44,7 @@ func filterOutProtectedInstances(svc *ec2.EC2, output *ec2.DescribeInstancesOutp } // Returns a formatted string of EC2 instance ids -func getAllEc2Instances(session *session.Session, region string, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - svc := ec2.New(session) - +func (ei EC2Instances) getAll(configObj config.Config) ([]*string, error) { params := &ec2.DescribeInstancesInput{ Filters: []*ec2.Filter{ { @@ -59,12 +57,12 @@ func getAllEc2Instances(session *session.Session, region string, excludeAfter ti }, } - output, err := svc.DescribeInstances(params) + output, err := ei.Client.DescribeInstances(params) if err != nil { return nil, errors.WithStackTrace(err) } - instanceIds, err := filterOutProtectedInstances(svc, output, excludeAfter, configObj) + instanceIds, err := ei.filterOutProtectedInstances(output, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -72,15 +70,7 @@ func getAllEc2Instances(session *session.Session, region string, excludeAfter ti return instanceIds, nil } -func shouldIncludeInstanceId(instance *ec2.Instance, excludeAfter time.Time, protected bool, configObj config.Config) bool { - if instance == nil { - return false - } - - if excludeAfter.Before(*instance.LaunchTime) { - return false - } - +func shouldIncludeInstanceId(instance *ec2.Instance, protected bool, configObj config.Config) bool { if protected { return false } @@ -88,41 +78,37 @@ func shouldIncludeInstanceId(instance *ec2.Instance, excludeAfter time.Time, pro // If Name is unset, GetEC2ResourceNameTagValue returns error and zero value string // Ignore this error and pass empty string to config.ShouldInclude instanceName := GetEC2ResourceNameTagValue(instance.Tags) - - return config.ShouldInclude( - *instanceName, - configObj.EC2.IncludeRule.NamesRegExp, - configObj.EC2.ExcludeRule.NamesRegExp, - ) + return configObj.EC2.ShouldInclude(config.ResourceValue{ + Name: instanceName, + Time: instance.LaunchTime, + }) } // Deletes all non protected EC2 instances -func nukeAllEc2Instances(session *session.Session, instanceIds []*string) error { - svc := ec2.New(session) - +func (ei EC2Instances) nukeAll(instanceIds []*string) error { if len(instanceIds) == 0 { - logging.Logger.Debugf("No EC2 instances to nuke in region %s", *session.Config.Region) + logging.Logger.Debugf("No EC2 instances to nuke in region %s", ei.Region) return nil } - logging.Logger.Debugf("Terminating all EC2 instances in region %s", *session.Config.Region) + logging.Logger.Debugf("Terminating all EC2 instances in region %s", ei.Region) params := &ec2.TerminateInstancesInput{ InstanceIds: instanceIds, } - _, err := svc.TerminateInstances(params) + _, err := ei.Client.TerminateInstances(params) if err != nil { logging.Logger.Debugf("[Failed] %s", err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking EC2 Instance", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": ei.Region, }) return errors.WithStackTrace(err) } - err = svc.WaitUntilInstanceTerminated(&ec2.DescribeInstancesInput{ + err = ei.Client.WaitUntilInstanceTerminated(&ec2.DescribeInstancesInput{ Filters: []*ec2.Filter{ { Name: awsgo.String("instance-id"), @@ -140,12 +126,12 @@ func nukeAllEc2Instances(session *session.Session, instanceIds []*string) error telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking EC2 Instance", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": ei.Region, }) return errors.WithStackTrace(err) } - logging.Logger.Debugf("[OK] %d instance(s) terminated in %s", len(instanceIds), *session.Config.Region) + logging.Logger.Debugf("[OK] %d instance(s) terminated in %s", len(instanceIds), ei.Region) return nil } diff --git a/aws/ec2_test.go b/aws/ec2_test.go index 60e64dd1..124dccd0 100644 --- a/aws/ec2_test.go +++ b/aws/ec2_test.go @@ -1,361 +1,144 @@ package aws import ( - "errors" - "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + awsgo "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/telemetry" + "github.com/stretchr/testify/require" "regexp" "testing" "time" - - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/util" - gruntworkerrors "github.com/gruntwork-io/go-commons/errors" - "github.com/stretchr/testify/assert" ) -const ( - ExampleId = "a1b2c3d4e5f601345" - ExampleIdTwo = "a1b2c3d4e5f654321" - ExampleIdThree = "a1b2c3d4e5f632154" - ExampleVpcId = "vpc-" + ExampleId - ExampleVpcIdTwo = "vpc-" + ExampleIdTwo - ExampleVpcIdThree = "vpc-" + ExampleIdThree - ExampleSubnetId = "subnet-" + ExampleId - ExampleSubnetIdTwo = "subnet-" + ExampleIdTwo - ExampleSubnetIdThree = "subnet-" + ExampleIdThree - ExampleRouteTableId = "rtb-" + ExampleId - ExampleNetworkAclId = "acl-" + ExampleId - ExampleSecurityGroupId = "sg-" + ExampleId - ExampleSecurityGroupIdTwo = "sg-" + ExampleIdTwo - ExampleSecurityGroupIdThree = "sg-" + ExampleIdThree - ExampleSecurityGroupRuleId = "sgr-" + ExampleId - ExampleInternetGatewayId = "igw-" + ExampleId - ExampleEndpointId = "vpce-" + ExampleId -) - -// getAMIIdByName - Retrieves an AMI ImageId given the name of the Id. Used for -// retrieving a standard AMI across AWS regions. -func getAMIIdByName(svc *ec2.EC2, name string) (string, error) { - imagesResult, err := svc.DescribeImages(&ec2.DescribeImagesInput{ - Owners: []*string{awsgo.String("self"), awsgo.String("amazon")}, - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: awsgo.String("name"), - Values: []*string{awsgo.String(name)}, - }, - }, - }) - - if err != nil { - return "", gruntworkerrors.WithStackTrace(err) - } - - if len(imagesResult.Images) == 0 { - return "", gruntworkerrors.WithStackTrace(fmt.Errorf("No images found with name %s", name)) - } - - image := imagesResult.Images[0] - return awsgo.StringValue(image.ImageId), nil +type mockedEC2Instances struct { + ec2iface.EC2API + DescribeInstancesOutput ec2.DescribeInstancesOutput + DescribeInstanceAttributeOutput map[string]ec2.DescribeInstanceAttributeOutput + TerminateInstancesOutput ec2.TerminateInstancesOutput } -// runAndWaitForInstance - Given a preconstructed ec2.RunInstancesInput object, -// make the API call to run the instance and then wait for the instance to be -// up and running before returning. -func runAndWaitForInstance(svc *ec2.EC2, name string, params *ec2.RunInstancesInput) (ec2.Instance, error) { - runResult, err := svc.RunInstances(params) - if err != nil { - return ec2.Instance{}, gruntworkerrors.WithStackTrace(err) - } - - if len(runResult.Instances) == 0 { - err := errors.New("Could not create test EC2 instance") - return ec2.Instance{}, gruntworkerrors.WithStackTrace(err) - } - - err = svc.WaitUntilInstanceExists(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: awsgo.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, - }, - }, - }) - - if err != nil { - return ec2.Instance{}, gruntworkerrors.WithStackTrace(err) - } - - // Add test tag to the created instance - _, err = svc.CreateTags(&ec2.CreateTagsInput{ - Resources: []*string{runResult.Instances[0].InstanceId}, - Tags: []*ec2.Tag{ - { - Key: awsgo.String("Name"), - Value: awsgo.String(name), - }, - }, - }) - - if err != nil { - return ec2.Instance{}, gruntworkerrors.WithStackTrace(err) - } - - // EC2 Instance must be in a running before this function returns - err = svc.WaitUntilInstanceRunning(&ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: awsgo.String("instance-id"), - Values: []*string{runResult.Instances[0].InstanceId}, - }, - }, - }) - - if err != nil { - return ec2.Instance{}, gruntworkerrors.WithStackTrace(err) - } - - return *runResult.Instances[0], nil - -} - -func createTestEC2Instance(t *testing.T, session *session.Session, name string, protected bool) ec2.Instance { - svc := ec2.New(session) - - imageID, err := getAMIIdByName(svc, "amzn-ami-hvm-2018.03.0.20220315.0-x86_64-gp2") - if err != nil { - assert.Fail(t, err.Error()) - } - - params := &ec2.RunInstancesInput{ - ImageId: awsgo.String(imageID), - InstanceType: awsgo.String("t3.micro"), - MinCount: awsgo.Int64(1), - MaxCount: awsgo.Int64(1), - DisableApiTermination: awsgo.Bool(protected), - } - instance, err := runAndWaitForInstance(svc, name, params) - if err != nil { - assert.Fail(t, err.Error()) - } - return instance +func (m mockedEC2Instances) DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { + return &m.DescribeInstancesOutput, nil } -func removeEC2InstanceProtection(svc *ec2.EC2, instance *ec2.Instance) error { - // make instance unprotected so it can be cleaned up - _, err := svc.ModifyInstanceAttribute(&ec2.ModifyInstanceAttributeInput{ - DisableApiTermination: &ec2.AttributeBooleanValue{ - Value: awsgo.Bool(false), - }, - InstanceId: instance.InstanceId, - }) +func (m mockedEC2Instances) DescribeInstanceAttribute(input *ec2.DescribeInstanceAttributeInput) (*ec2.DescribeInstanceAttributeOutput, error) { + id := input.InstanceId + output := m.DescribeInstanceAttributeOutput[*id] - return err + return &output, nil } -func findEC2InstancesByNameTag(t *testing.T, session *session.Session, name string) []*string { - output, err := ec2.New(session).DescribeInstances(&ec2.DescribeInstancesInput{}) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - var instanceIds []*string - for _, reservation := range output.Reservations { - for _, instance := range reservation.Instances { - instanceID := *instance.InstanceId - - // Retrive only IDs of instances with the unique test tag - for _, tag := range instance.Tags { - if *tag.Key == "Name" { - if *tag.Value == name { - instanceIds = append(instanceIds, &instanceID) - } - } - } - - } - } - - return instanceIds +func (m mockedEC2Instances) TerminateInstances(*ec2.TerminateInstancesInput) (*ec2.TerminateInstancesOutput, error) { + return &m.TerminateInstancesOutput, nil } -func TestListInstances(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region, err := getRandomRegion() - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - instance := createTestEC2Instance(t, session, uniqueTestID, false) - protectedInstance := createTestEC2Instance(t, session, uniqueTestID, true) - // clean up after this test - defer nukeAllEc2Instances(session, []*string{instance.InstanceId, protectedInstance.InstanceId}) - - instanceIds, err := getAllEc2Instances(session, region, time.Now().Add(1*time.Hour*-1), config.Config{}) - if err != nil { - assert.Fail(t, "Unable to fetch list of EC2 Instances") - } - - assert.NotContains(t, instanceIds, instance.InstanceId) - assert.NotContains(t, instanceIds, protectedInstance.InstanceId) - - instanceIds, err = getAllEc2Instances(session, region, time.Now().Add(1*time.Hour), config.Config{}) - if err != nil { - assert.Fail(t, "Unable to fetch list of EC2 Instances") - } - - assert.Contains(t, instanceIds, instance.InstanceId) - assert.NotContains(t, instanceIds, protectedInstance.InstanceId) - - if err = removeEC2InstanceProtection(ec2.New(session), &protectedInstance); err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } +func (m mockedEC2Instances) WaitUntilInstanceTerminated(*ec2.DescribeInstancesInput) error { + return nil } -func TestNukeInstances(t *testing.T) { +func TestEc2Instances_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - createTestEC2Instance(t, session, uniqueTestID, false) - - instanceIds := findEC2InstancesByNameTag(t, session, uniqueTestID) - - if err := nukeAllEc2Instances(session, instanceIds); err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - instances, err := getAllEc2Instances(session, region, time.Now().Add(1*time.Hour), config.Config{}) - - if err != nil { - assert.Fail(t, "Unable to fetch list of EC2 Instances") - } - - for _, instanceID := range instanceIds { - assert.NotContains(t, instances, *instanceID) - } -} - -// Test config file filtering works as expected -func TestShouldIncludeInstanceId(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - - mockInstance := &ec2.Instance{ - LaunchTime: awsgo.Time(time.Now()), - Tags: []*ec2.Tag{ - { - Key: awsgo.String("Name"), - Value: awsgo.String("cloud-nuke-test"), - }, - { - Key: awsgo.String("Foo"), - Value: awsgo.String("Bar"), - }, - }, - } - - mockExpression, err := regexp.Compile("^cloud-nuke-*") - if err != nil { - logging.Logger.Fatalf("There was an error compiling regex expression %v", err) - } - - mockExcludeConfig := config.Config{ - EC2: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ + testId1 := "testId1" + testId2 := "testId2" + testName1 := "testName1" + testName2 := "testName2" + now := time.Now() + ei := EC2Instances{ + Client: mockedEC2Instances{ + DescribeInstancesOutput: ec2.DescribeInstancesOutput{ + Reservations: []*ec2.Reservation{ { - RE: *mockExpression, + Instances: []*ec2.Instance{ + { + InstanceId: awsgo.String(testId1), + Tags: []*ec2.Tag{ + { + Key: awsgo.String("Name"), + Value: awsgo.String(testName1), + }, + }, + LaunchTime: awsgo.Time(now), + }, + { + InstanceId: awsgo.String(testId2), + Tags: []*ec2.Tag{ + { + Key: awsgo.String("Name"), + Value: awsgo.String(testName2), + }, + }, + LaunchTime: awsgo.Time(now.Add(1)), + }, + }, }, }, }, - }, - } - - mockIncludeConfig := config.Config{ - EC2: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *mockExpression, + DescribeInstanceAttributeOutput: map[string]ec2.DescribeInstanceAttributeOutput{ + testId1: { + DisableApiTermination: &ec2.AttributeBooleanValue{ + Value: awsgo.Bool(false), + }, + }, + testId2: { + DisableApiTermination: &ec2.AttributeBooleanValue{ + Value: awsgo.Bool(false), }, }, }, }, } - cases := []struct { - Name string - Instance *ec2.Instance - Config config.Config - ExcludeAfter time.Time - Protected bool - Expected bool + tests := map[string]struct { + configObj config.ResourceType + expected []string }{ - { - Name: "ConfigExclude", - Instance: mockInstance, - Config: mockExcludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Protected: false, - Expected: false, + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testId1, testId2}, }, - { - Name: "ConfigInclude", - Instance: mockInstance, - Config: mockIncludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Protected: false, - Expected: true, - }, - { - Name: "NotOlderThan", - Instance: mockInstance, - Config: config.Config{}, - ExcludeAfter: time.Now().Add(1 * time.Hour * -1), - Protected: false, - Expected: false, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testId2}, }, - { - Name: "Protected", - Instance: mockInstance, - Config: config.Config{}, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Protected: true, - Expected: false, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }}, + expected: []string{testId1}, }, } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - result := shouldIncludeInstanceId(c.Instance, c.ExcludeAfter, c.Protected, c.Config) - assert.Equal(t, c.Expected, result) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := ei.getAll(config.Config{ + EC2: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, awsgo.StringValueSlice(names)) }) } } + +func TestEc2Instances_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + ei := EC2Instances{ + Client: mockedEC2Instances{ + TerminateInstancesOutput: ec2.TerminateInstancesOutput{}, + }, + } + + err := ei.nukeAll([]*string{awsgo.String("testId1")}) + require.NoError(t, err) +} diff --git a/aws/ec2_types.go b/aws/ec2_types.go index 06f546a3..39d8a172 100644 --- a/aws/ec2_types.go +++ b/aws/ec2_types.go @@ -15,23 +15,23 @@ type EC2Instances struct { } // ResourceName - the simple name of the aws resource -func (instance EC2Instances) ResourceName() string { +func (ei EC2Instances) ResourceName() string { return "ec2" } // ResourceIdentifiers - The instance ids of the ec2 instances -func (instance EC2Instances) ResourceIdentifiers() []string { - return instance.InstanceIds +func (ei EC2Instances) ResourceIdentifiers() []string { + return ei.InstanceIds } -func (instance EC2Instances) MaxBatchSize() int { +func (ei EC2Instances) MaxBatchSize() int { // Tentative batch size to ensure AWS doesn't throttle return 49 } // Nuke - nuke 'em all!!! -func (instance EC2Instances) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllEc2Instances(session, awsgo.StringSlice(identifiers)); err != nil { +func (ei EC2Instances) Nuke(session *session.Session, identifiers []string) error { + if err := ei.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) } diff --git a/aws/ec2_unit_test.go b/aws/ec2_unit_test.go deleted file mode 100644 index 7add736d..00000000 --- a/aws/ec2_unit_test.go +++ /dev/null @@ -1,323 +0,0 @@ -// These tests use GoMock and the ec2iface to provide a mock framework for testing the EC2 API -// Unlike other tests in cloud-nuke, nuking the default VPCs and security groups is not an option. -// Other tests within cloud-nuke depend on the default VPCs/SGs to function, and other projects -// may be using the same AWS account at the same time. Deleting the default VPCs would break things. -// Therefore, the default VPC/SG nuke testing is mocked as unit tests. -// To generate the EC2API mock, install https://github.com/golang/mock, then use the following: -// mockgen -source vendor/github.com/aws/aws-sdk-go/service/ec2/ec2iface/interface.go -destination aws/mocks/EC2API.go - -package aws - -import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - "testing" - - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/golang/mock/gomock" - mock_ec2iface "github.com/gruntwork-io/cloud-nuke/aws/mocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func getTestVpcs(mockEC2 *mock_ec2iface.MockEC2API) []Vpc { - return []Vpc{ - { - Region: "ap-southeast-1", - svc: mockEC2, - }, - { - Region: "eu-west-3", - svc: mockEC2, - }, - { - Region: "ca-central-1", - svc: mockEC2, - }, - } -} - -func TestGetDefaultVpcs(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2 := mock_ec2iface.NewMockEC2API(mockCtrl) - - vpcs := getTestVpcs(mockEC2) - - describeVpcsInput := getDefaultDescribeVpcsInput() - describeVpcsOutputOne := &ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - {VpcId: awsgo.String(ExampleVpcId)}, - }, - } - describeVpcsFunc := func(input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return describeVpcsOutputOne, nil - } - describeVpcsOutputTwo := &ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - {VpcId: awsgo.String(ExampleVpcIdTwo)}, - }, - } - describeVpcsFuncTwo := func(input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return describeVpcsOutputTwo, nil - } - describeVpcsOutputThree := &ec2.DescribeVpcsOutput{Vpcs: []*ec2.Vpc{}} - describeVpcsFuncThree := func(input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) { - return describeVpcsOutputThree, nil - } - gomock.InOrder( - mockEC2.EXPECT().DescribeVpcs(describeVpcsInput).DoAndReturn(describeVpcsFunc), - mockEC2.EXPECT().DescribeVpcs(describeVpcsInput).DoAndReturn(describeVpcsFuncTwo), - mockEC2.EXPECT().DescribeVpcs(describeVpcsInput).DoAndReturn(describeVpcsFuncThree), - ) - - vpcs, err := GetDefaultVpcs(vpcs) - require.NoError(t, err) - assert.Len(t, vpcs, 2, "There should be two default VPCs") -} - -func getTestVpcsWithIds(mockEC2 *mock_ec2iface.MockEC2API) []Vpc { - return []Vpc{ - { - Region: "ap-southeast-1", - VpcId: ExampleVpcId, - svc: mockEC2, - }, - { - Region: "eu-west-3", - VpcId: ExampleVpcIdTwo, - svc: mockEC2, - }, - { - Region: "ca-central-1", - VpcId: ExampleVpcIdThree, - svc: mockEC2, - }, - } -} - -func TestNukeMockVpcs(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockEC2 := mock_ec2iface.NewMockEC2API(mockCtrl) - - vpcs := getTestVpcsWithIds(mockEC2) - for _, vpc := range vpcs { - describeInternetGatewaysInput := getDescribeInternetGatewaysInput(vpc.VpcId) - describeInternetGatewaysOutput := getDescribeInternetGatewaysOutput(ExampleInternetGatewayId) - describeInternetGatewaysFunc := func(input *ec2.DescribeInternetGatewaysInput) (*ec2.DescribeInternetGatewaysOutput, error) { - return describeInternetGatewaysOutput, nil - } - detachInternetGatewayInput := getDetachInternetGatewayInput(vpc.VpcId, ExampleInternetGatewayId) - deleteInternetGatewayInput := getDeleteInternetGatewayInput(ExampleInternetGatewayId) - egressOnlyInternetGatewaysInput := getDescribeEgressOnlyInternetGatewaysInput() - describeNetworkInterfacesInput := getDescribeNetworkInterfacesInput(vpc.VpcId) - describeEndpointsInput := getDescribeEndpointsInput(vpc.VpcId) - describeEndpointsOutput := getDescribeEndpointsOutput([]string{ExampleEndpointId}) - describeEndpointsFunc := func(input *ec2.DescribeVpcEndpointsInput) (*ec2.DescribeVpcEndpointsOutput, error) { - return describeEndpointsOutput, nil - } - deleteEndpointInput := getDeleteEndpointInput(ExampleEndpointId) - - describeEndpointsWaitForDeletionInput := getDescribeEndpointsWaitForDeletionInput(vpc.VpcId) - describeEndpointsWaitForDeletionOutput := getDescribeEndpointsOutput(nil) - describeEndpointsWaitForDeletionFunc := func(input *ec2.DescribeVpcEndpointsInput) (*ec2.DescribeVpcEndpointsOutput, error) { - return describeEndpointsWaitForDeletionOutput, nil - } - - describeSubnetsInput := getDescribeSubnetsInput(vpc.VpcId) - describeSubnetsOutput := getDescribeSubnetsOutput([]string{ExampleSubnetId, ExampleSubnetIdTwo, ExampleSubnetIdThree}) - describeSubnetsFunc := func(input *ec2.DescribeSubnetsInput) (*ec2.DescribeSubnetsOutput, error) { - return describeSubnetsOutput, nil - } - deleteSubnetInputOne := getDeleteSubnetInput(ExampleSubnetId) - deleteSubnetInputTwo := getDeleteSubnetInput(ExampleSubnetIdTwo) - deleteSubnetInputThree := getDeleteSubnetInput(ExampleSubnetIdThree) - - describeRouteTablesInput := getDescribeRouteTablesInput(vpc.VpcId) - describeRouteTablesOutput := getDescribeRouteTablesOutput([]string{ExampleRouteTableId}) - describeRouteTablesFunc := func(input *ec2.DescribeRouteTablesInput) (*ec2.DescribeRouteTablesOutput, error) { - return describeRouteTablesOutput, nil - } - deleteRouteTableInput := getDeleteRouteTableInput(ExampleRouteTableId) - - describeNetworkAclsInput := getDescribeNetworkAclsInput(vpc.VpcId) - describeNetworkAclsOutput := getDescribeNetworkAclsOutput([]string{ExampleNetworkAclId}) - describeNetworkAclsFunc := func(input *ec2.DescribeNetworkAclsInput) (*ec2.DescribeNetworkAclsOutput, error) { - return describeNetworkAclsOutput, nil - } - deleteNetworkAclInput := getDeleteNetworkAclInput(ExampleNetworkAclId) - - describeSecurityGroupRulesInput := getDescribeSecurityGroupRulesInput(ExampleSecurityGroupId) - describeSecurityGroupRulesOutput := getDescribeSecurityGroupRulesOutput([]string{ExampleSecurityGroupRuleId}) - describeSecurityGroupRulesFunc := func(input *ec2.DescribeSecurityGroupRulesInput) (*ec2.DescribeSecurityGroupRulesOutput, error) { - return describeSecurityGroupRulesOutput, nil - } - revokeSecurityGroupEgressInput := getRevokeSecurityGroupEgressInput(ExampleSecurityGroupId, ExampleSecurityGroupRuleId) - associateDhcpOptionsInput := getAssociateDhcpOptionsInput(vpc.VpcId) - revokeSecurityGroupIngressInput := getRevokeSecurityGroupIngressInput(ExampleSecurityGroupId, ExampleSecurityGroupRuleId) - - describeSecurityGroupsInput := getDescribeSecurityGroupsInput(vpc.VpcId) - describeSecurityGroupsOutput := getDescribeSecurityGroupsOutput([]string{ExampleSecurityGroupId}) - describeSecurityGroupsFunc := func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { - return describeSecurityGroupsOutput, nil - } - deleteSecurityGroupInput := getDeleteSecurityGroupInput(ExampleSecurityGroupId) - - deleteVpcInput := getDeleteVpcInput(vpc.VpcId) - - gomock.InOrder( - mockEC2.EXPECT().DescribeInternetGateways(describeInternetGatewaysInput).DoAndReturn(describeInternetGatewaysFunc), - mockEC2.EXPECT().DetachInternetGateway(detachInternetGatewayInput), - mockEC2.EXPECT().DeleteInternetGateway(deleteInternetGatewayInput), - mockEC2.EXPECT().DescribeEgressOnlyInternetGatewaysPages(egressOnlyInternetGatewaysInput, gomock.Any()), - mockEC2.EXPECT().DescribeVpcEndpoints(describeEndpointsInput).DoAndReturn(describeEndpointsFunc), - mockEC2.EXPECT().DeleteVpcEndpoints(deleteEndpointInput), - mockEC2.EXPECT().DescribeVpcEndpoints(describeEndpointsWaitForDeletionInput).DoAndReturn(describeEndpointsWaitForDeletionFunc), - mockEC2.EXPECT().DescribeNetworkInterfacesPages(describeNetworkInterfacesInput, gomock.Any()), - mockEC2.EXPECT().DescribeSubnets(describeSubnetsInput).DoAndReturn(describeSubnetsFunc), - mockEC2.EXPECT().DeleteSubnet(deleteSubnetInputOne), - mockEC2.EXPECT().DeleteSubnet(deleteSubnetInputTwo), - mockEC2.EXPECT().DeleteSubnet(deleteSubnetInputThree), - mockEC2.EXPECT().DescribeRouteTables(describeRouteTablesInput).DoAndReturn(describeRouteTablesFunc), - mockEC2.EXPECT().DeleteRouteTable(deleteRouteTableInput), - mockEC2.EXPECT().DescribeNetworkAcls(describeNetworkAclsInput).DoAndReturn(describeNetworkAclsFunc), - mockEC2.EXPECT().DeleteNetworkAcl(deleteNetworkAclInput), - mockEC2.EXPECT().DescribeSecurityGroups(describeSecurityGroupsInput).DoAndReturn(describeSecurityGroupsFunc), - mockEC2.EXPECT().DescribeSecurityGroupRules(describeSecurityGroupRulesInput).DoAndReturn(describeSecurityGroupRulesFunc), - mockEC2.EXPECT().RevokeSecurityGroupEgress(revokeSecurityGroupEgressInput), - mockEC2.EXPECT().RevokeSecurityGroupIngress(revokeSecurityGroupIngressInput), - mockEC2.EXPECT().DeleteSecurityGroup(deleteSecurityGroupInput), - mockEC2.EXPECT().AssociateDhcpOptions(associateDhcpOptionsInput), - mockEC2.EXPECT().DeleteVpc(deleteVpcInput), - ) - } - - err := NukeVpcs(vpcs) - require.NoError(t, err) -} - -func TestNukeDefaultSecurityGroups(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockEC2 := mock_ec2iface.NewMockEC2API(mockCtrl) - - regions := []string{ - "ap-southeast-1", - "eu-west-3", - } - - groups := []DefaultSecurityGroup{ - { - Region: regions[0], - GroupName: "default", - GroupId: ExampleSecurityGroupId, - svc: mockEC2, - }, - { - Region: regions[0], - GroupName: "default", - GroupId: ExampleSecurityGroupIdTwo, - svc: mockEC2, - }, - { - Region: regions[1], - GroupName: "default", - GroupId: ExampleSecurityGroupIdThree, - svc: mockEC2, - }, - } - describeSecurityGroupsInput := getDescribeSecurityGroupsInputEmpty() - describeSecurityGroupsOutputOne := getDescribeDefaultSecurityGroupsOutput(groups[0:2]) - describeSecurityGroupsFuncOne := func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { - return describeSecurityGroupsOutputOne, nil - } - describeSecurityGroupsOutputTwo := getDescribeDefaultSecurityGroupsOutput(groups[2:]) - describeSecurityGroupsFuncTwo := func(input *ec2.DescribeSecurityGroupsInput) (*ec2.DescribeSecurityGroupsOutput, error) { - return describeSecurityGroupsOutputTwo, nil - } - - gomock.InOrder( - mockEC2.EXPECT().DescribeSecurityGroups(describeSecurityGroupsInput).DoAndReturn(describeSecurityGroupsFuncOne), - mockEC2.EXPECT().DescribeSecurityGroups(describeSecurityGroupsInput).DoAndReturn(describeSecurityGroupsFuncTwo), - mockEC2.EXPECT().RevokeSecurityGroupIngress(groups[0].getDefaultSecurityGroupIngressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[0].getDefaultSecurityGroupEgressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[0].getDefaultSecurityGroupIPv6EgressRule()), - mockEC2.EXPECT().RevokeSecurityGroupIngress(groups[1].getDefaultSecurityGroupIngressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[1].getDefaultSecurityGroupEgressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[1].getDefaultSecurityGroupIPv6EgressRule()), - mockEC2.EXPECT().RevokeSecurityGroupIngress(groups[2].getDefaultSecurityGroupIngressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[2].getDefaultSecurityGroupEgressRule()), - mockEC2.EXPECT().RevokeSecurityGroupEgress(groups[2].getDefaultSecurityGroupIPv6EgressRule()), - ) - - for range regions { - _, err := DescribeDefaultSecurityGroups(mockEC2) - require.NoError(t, err) - } - - err := NukeDefaultSecurityGroupRules(groups) - require.NoError(t, err) -} - -// ********************************************************************************** -// The test methodology below deletes default VPCs for reals which breaks other tests -// and hence is commented out in favor of the mock testing approach above -// ********************************************************************************** -// func createRandomDefaultVpc(t *testing.T, region string) Vpc { -// svc := ec2.New(newSession(region)) -// defaultVpc, err := getDefaultVpc(region) -// require.NoError(t, err) -// if defaultVpc == (Vpc{}) { -// vpc, err := svc.CreateDefaultVpc(nil) -// require.NoError(t, err) -// defaultVpc.Region = region -// defaultVpc.VpcId = awsgo.StringValue(vpc.Vpc.VpcId) -// defaultVpc.svc = svc -// } -// return defaultVpc -// } -// -// func getRandomDefaultVpcs(t *testing.T, howMany int) []Vpc { -// var defaultVpcs []Vpc -// -// for i := 0; i < howMany; i++ { -// region := getRandomRegion() -// defaultVpcs = append(defaultVpcs, createRandomDefaultVpc(t, region)) -// } -// return defaultVpcs -// } -// -// -// func TestNukeDefaultVpcs(t *testing.T) { -// t.Parallel() -// -// // How many default VPCs to nuke for this test -// count := 3 -// -// defaultVpcs := getRandomDefaultVpcs(t, count) -// -// err := NukeVpcs(defaultVpcs) -// require.NoError(t, err) -// -// for _, vpc := range defaultVpcs { -// input := &ec2.DescribeVpcsInput{ -// Filters: []*ec2.Filter{ -// { -// Name: awsgo.String("vpc-id"), -// Values: []*string{awsgo.String(vpc.VpcId)}, -// }, -// }, -// } -// result, err := vpc.svc.DescribeVpcs(input) -// require.NoError(t, err) -// assert.Len(t, result.Vpcs, 0) -// } -// } diff --git a/aws/ec2_utils_for_test.go b/aws/ec2_utils_for_test.go deleted file mode 100644 index 016f4d95..00000000 --- a/aws/ec2_utils_for_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package aws - -import ( - "testing" - - "github.com/aws/aws-sdk-go/aws" - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/stretchr/testify/require" -) - -func getVpcSubnetsDistinctByAz(t *testing.T, session *session.Session, vpcId string) []string { - result := getVpcSubnets(t, session, vpcId) - var subnetsByAz = make(map[string]string) - // collect subnetworks distinct by AZ - for _, v := range result.Subnets { - subnetsByAz[*v.AvailabilityZone] = awsgo.StringValue(v.SubnetId) - } - var subnets []string - for _, subnet := range subnetsByAz { - subnets = append(subnets, subnet) - } - return subnets -} - -func getVpcSubnets(t *testing.T, session *session.Session, vpcId string) *ec2.DescribeSubnetsOutput { - svc := ec2.New(session) - - param := &ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{ - { - Name: aws.String("vpc-id"), - Values: aws.StringSlice([]string{vpcId}), - }, - }, - } - - result, err := svc.DescribeSubnets(param) - require.NoError(t, err) - return result -} diff --git a/aws/ecs_cluster.go b/aws/ecs_cluster.go index 5a0fa3a9..0544b451 100644 --- a/aws/ecs_cluster.go +++ b/aws/ecs_cluster.go @@ -1,6 +1,7 @@ package aws import ( + "github.com/gruntwork-io/cloud-nuke/util" "time" "github.com/gruntwork-io/cloud-nuke/telemetry" @@ -8,7 +9,6 @@ import ( "github.com/aws/aws-sdk-go/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecs" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" @@ -24,11 +24,31 @@ const activeEcsClusterStatus string = "ACTIVE" // For more details on this, please read here: https://docs.aws.amazon.com/cli/latest/reference/ecs/describe-clusters.html#options const describeClustersRequestBatchSize = 100 -// Filter all active ecs clusters -func getAllActiveEcsClusterArns(awsSession *session.Session, configObj config.Config) ([]*string, error) { - svc := ecs.New(awsSession) +// getAllEcsClusters - Returns a string of ECS Cluster ARNs, which uniquely identifies the cluster. +// We need to get all clusters before we can get all services. +func (clusters ECSClusters) getAllEcsClusters() ([]*string, error) { + clusterArns := []*string{} + result, err := clusters.Client.ListClusters(&ecs.ListClustersInput{}) + if err != nil { + return nil, errors.WithStackTrace(err) + } + clusterArns = append(clusterArns, result.ClusterArns...) + + // Handle pagination: continuously pull the next page if nextToken is set + for awsgo.StringValue(result.NextToken) != "" { + result, err = clusters.Client.ListClusters(&ecs.ListClustersInput{NextToken: result.NextToken}) + if err != nil { + return nil, errors.WithStackTrace(err) + } + clusterArns = append(clusterArns, result.ClusterArns...) + } + + return clusterArns, nil +} - allClusters, err := getAllEcsClusters(awsSession) +// Filter all active ecs clusters +func (clusters ECSClusters) getAllActiveEcsClusterArns(configObj config.Config) ([]*string, error) { + allClusters, err := clusters.getAllEcsClusters() if err != nil { logging.Logger.Debug("Error getting all ECS clusters") return nil, errors.WithStackTrace(err) @@ -42,7 +62,7 @@ func getAllActiveEcsClusterArns(awsSession *session.Session, configObj config.Co Clusters: awsgo.StringSlice(batch), } - describedClusters, describeErr := svc.DescribeClusters(input) + describedClusters, describeErr := clusters.Client.DescribeClusters(input) if describeErr != nil { logging.Logger.Debugf("Error describing ECS clusters from input %s: ", input) return nil, errors.WithStackTrace(describeErr) @@ -71,15 +91,11 @@ func shouldIncludeECSCluster(cluster *ecs.Cluster, configObj config.Config) bool return false } - return config.ShouldInclude( - awsgo.StringValue(cluster.ClusterName), - configObj.ECSCluster.IncludeRule.NamesRegExp, - configObj.ECSCluster.ExcludeRule.NamesRegExp, - ) + return configObj.ECSCluster.ShouldInclude(config.ResourceValue{Name: cluster.ClusterName}) } -func getAllEcsClustersOlderThan(awsSession *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - clusterArns, err := getAllActiveEcsClusterArns(awsSession, configObj) +func (clusters ECSClusters) getAll(configObj config.Config) ([]*string, error) { + clusterArns, err := clusters.getAllActiveEcsClusterArns(configObj) if err != nil { logging.Logger.Debugf("Error getting all ECS clusters with `ACTIVE` status") return nil, errors.WithStackTrace(err) @@ -88,43 +104,41 @@ func getAllEcsClustersOlderThan(awsSession *session.Session, excludeAfter time.T var filteredEcsClusters []*string for _, clusterArn := range clusterArns { - firstSeenTime, err := getFirstSeenEcsClusterTag(awsSession, clusterArn) + firstSeenTime, err := clusters.getFirstSeenTag(clusterArn) if err != nil { logging.Logger.Debugf("Error getting the `cloud-nuke-first-seen` tag for ECS cluster with ARN %s", aws.StringValue(clusterArn)) return nil, errors.WithStackTrace(err) } if firstSeenTime.IsZero() { - err := tagEcsClusterWhenFirstSeen(awsSession, clusterArn, time.Now().UTC()) + err := clusters.setFirstSeenTag(clusterArn, time.Now().UTC()) if err != nil { logging.Logger.Debugf("Error tagging the ECS cluster with ARN %s", aws.StringValue(clusterArn)) return nil, errors.WithStackTrace(err) } - } else if excludeAfter.After(firstSeenTime) { + } else if configObj.ECSCluster.ShouldInclude(config.ResourceValue{Time: firstSeenTime}) { filteredEcsClusters = append(filteredEcsClusters, clusterArn) } } return filteredEcsClusters, nil } -func nukeEcsClusters(awsSession *session.Session, ecsClusterArns []*string) error { - svc := ecs.New(awsSession) - +func (clusters ECSClusters) nukeAll(ecsClusterArns []*string) error { numNuking := len(ecsClusterArns) if numNuking == 0 { - logging.Logger.Debugf("No ECS clusters to nuke in region %s", aws.StringValue(awsSession.Config.Region)) + logging.Logger.Debugf("No ECS clusters to nuke in region %s", clusters.Region) return nil } - logging.Logger.Debugf("Deleting %d ECS clusters in region %s", numNuking, aws.StringValue(awsSession.Config.Region)) + logging.Logger.Debugf("Deleting %d ECS clusters in region %s", numNuking, clusters.Region) var nukedEcsClusters []*string for _, clusterArn := range ecsClusterArns { params := &ecs.DeleteClusterInput{ Cluster: clusterArn, } - _, err := svc.DeleteCluster(params) + _, err := clusters.Client.DeleteCluster(params) // Record status of this resource e := report.Entry{ @@ -139,7 +153,7 @@ func nukeEcsClusters(awsSession *session.Session, ecsClusterArns []*string) erro telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Cluster", }, map[string]interface{}{ - "region": *awsSession.Config.Region, + "region": clusters.Region, }) return errors.WithStackTrace(err) } @@ -149,16 +163,14 @@ func nukeEcsClusters(awsSession *session.Session, ecsClusterArns []*string) erro } numNuked := len(nukedEcsClusters) - logging.Logger.Debugf("[OK] %d of %d ECS cluster(s) deleted in %s", numNuked, numNuking, aws.StringValue(awsSession.Config.Region)) + logging.Logger.Debugf("[OK] %d of %d ECS cluster(s) deleted in %s", numNuked, numNuking, clusters.Region) return nil } // Tag an ECS cluster identified by the given cluster ARN when it's first seen by cloud-nuke -func tagEcsClusterWhenFirstSeen(awsSession *session.Session, clusterArn *string, timestamp time.Time) error { - svc := ecs.New(awsSession) - - firstSeenTime := formatTimestampTag(timestamp) +func (clusters ECSClusters) setFirstSeenTag(clusterArn *string, timestamp time.Time) error { + firstSeenTime := util.FormatTimestampTag(timestamp) input := &ecs.TagResourceInput{ ResourceArn: clusterArn, @@ -170,7 +182,7 @@ func tagEcsClusterWhenFirstSeen(awsSession *session.Session, clusterArn *string, }, } - _, err := svc.TagResource(input) + _, err := clusters.Client.TagResource(input) if err != nil { return errors.WithStackTrace(err) } @@ -179,24 +191,23 @@ func tagEcsClusterWhenFirstSeen(awsSession *session.Session, clusterArn *string, } // Get the `cloud-nuke-first-seen` tag value for a given ECS cluster -func getFirstSeenEcsClusterTag(awsSession *session.Session, clusterArn *string) (time.Time, error) { - var firstSeenTime time.Time +func (clusters ECSClusters) getFirstSeenTag(clusterArn *string) (*time.Time, error) { + var firstSeenTime *time.Time - svc := ecs.New(awsSession) input := &ecs.ListTagsForResourceInput{ ResourceArn: clusterArn, } - clusterTags, err := svc.ListTagsForResource(input) + clusterTags, err := clusters.Client.ListTagsForResource(input) if err != nil { logging.Logger.Debugf("Error getting the tags for ECS cluster with ARN %s", aws.StringValue(clusterArn)) return firstSeenTime, errors.WithStackTrace(err) } for _, tag := range clusterTags.Tags { - if aws.StringValue(tag.Key) == firstSeenTagKey { + if util.IsFirstSeenTag(tag.Key) { - firstSeenTime, err := parseTimestampTag(aws.StringValue(tag.Value)) + firstSeenTime, err := util.ParseTimestampTag(tag.Value) if err != nil { logging.Logger.Debugf("Error parsing the `cloud-nuke-first-seen` tag for ECS cluster with ARN %s", aws.StringValue(clusterArn)) return firstSeenTime, errors.WithStackTrace(err) @@ -205,19 +216,6 @@ func getFirstSeenEcsClusterTag(awsSession *session.Session, clusterArn *string) return firstSeenTime, nil } } - return firstSeenTime, nil -} -func parseTimestampTag(timestamp string) (time.Time, error) { - parsed, err := time.Parse(firstSeenTimeFormat, timestamp) - if err != nil { - logging.Logger.Debugf("Error parsing the timestamp into a `RFC3339` Time format") - return parsed, errors.WithStackTrace(err) - - } - return parsed, nil -} - -func formatTimestampTag(timestamp time.Time) string { - return timestamp.Format(firstSeenTimeFormat) + return firstSeenTime, nil } diff --git a/aws/ecs_cluster_test.go b/aws/ecs_cluster_test.go index 0cb9cc9e..2b3fcc82 100644 --- a/aws/ecs_cluster_test.go +++ b/aws/ecs_cluster_test.go @@ -1,214 +1,138 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - "regexp" - "testing" - "time" - - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go/service/ecs/ecsiface" "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/gruntwork-io/cloud-nuke/util" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "regexp" + "testing" + "time" ) -// Test we can create a cluster, tag it, and then find the tag -func TestCanTagEcsClusters(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - region := getRandomFargateSupportedRegion() - - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - require.NoError(t, err) - - cluster := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster) - - tagValue := time.Now().UTC() - - tagErr := tagEcsClusterWhenFirstSeen(awsSession, cluster.ClusterArn, tagValue) - require.NoError(t, tagErr) - - returnedTag, err := getFirstSeenEcsClusterTag(awsSession, cluster.ClusterArn) - require.NoError(t, err) - - parsedTagValue, parseErr1 := parseTimestampTag(formatTimestampTag(tagValue)) - require.NoError(t, parseErr1) - - parsedReturnValue, parseErr2 := parseTimestampTag(formatTimestampTag(returnedTag)) - require.NoError(t, parseErr2) - - //compare that the tags' Time values after formatting are equal - assert.Equal(t, parsedTagValue, parsedReturnValue) +type mockedEC2Cluster struct { + ecsiface.ECSAPI + ListClustersOutput ecs.ListClustersOutput + DescribeClustersOutput ecs.DescribeClustersOutput + TagResourceOutput ecs.TagResourceOutput + ListTagsForResourceOutput ecs.ListTagsForResourceOutput + DeleteClusterOutput ecs.DeleteClusterOutput } -// Test we can get all ECS clusters younger than < X time based on tags -func TestCanListAllEcsClustersOlderThan24hours(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - region := getRandomFargateSupportedRegion() - - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - require.NoError(t, err) - - cluster1 := createEcsFargateCluster(t, awsSession, util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster1) - cluster2 := createEcsFargateCluster(t, awsSession, util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster2) +func (m mockedEC2Cluster) ListClusters(*ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { + return &m.ListClustersOutput, nil +} - now := time.Now().UTC() - var olderClusterTagValue = now.Add(time.Hour * time.Duration(-48)) - var youngerClusterTagValue = now.Add(time.Hour * time.Duration(-23)) +func (m mockedEC2Cluster) DescribeClusters(*ecs.DescribeClustersInput) (*ecs.DescribeClustersOutput, error) { + return &m.DescribeClustersOutput, nil +} - err1 := tagEcsClusterWhenFirstSeen(awsSession, cluster1.ClusterArn, olderClusterTagValue) - require.NoError(t, err1) - err2 := tagEcsClusterWhenFirstSeen(awsSession, cluster2.ClusterArn, youngerClusterTagValue) - require.NoError(t, err2) +func (m mockedEC2Cluster) TagResource(*ecs.TagResourceInput) (*ecs.TagResourceOutput, error) { + return &m.TagResourceOutput, nil +} - last24Hours := now.Add(time.Hour * time.Duration(-24)) - filteredClusterArns, err := getAllEcsClustersOlderThan(awsSession, last24Hours, config.Config{}) - require.NoError(t, err) +func (m mockedEC2Cluster) ListTagsForResource(*ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error) { + return &m.ListTagsForResourceOutput, nil +} - assert.Contains(t, awsgo.StringValueSlice(filteredClusterArns), awsgo.StringValue(cluster1.ClusterArn)) +func (m mockedEC2Cluster) DeleteCluster(*ecs.DeleteClusterInput) (*ecs.DeleteClusterOutput, error) { + return &m.DeleteClusterOutput, nil } -// Test we can nuke all ECS clusters older than 24hrs -func TestCanNukeAllEcsClustersOlderThan24Hours(t *testing.T) { +func TestEC2Cluster_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region := getRandomFargateSupportedRegion() - - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - require.NoError(t, err) - cluster1 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-24-"+util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster1) - cluster2 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-24-"+util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster2) - cluster3 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-24-"+util.UniqueID()) - defer deleteEcsCluster(awsSession, cluster3) - - now := time.Now().UTC() - var oldClusterTagValue1 = now.Add(time.Hour * time.Duration(-48)) - var youngClusterTagValue = now - var oldClusterTagValue2 = now.Add(time.Hour * time.Duration(-27)) - - err1 := tagEcsClusterWhenFirstSeen(awsSession, cluster1.ClusterArn, oldClusterTagValue1) - require.NoError(t, err1) - err2 := tagEcsClusterWhenFirstSeen(awsSession, cluster2.ClusterArn, youngClusterTagValue) - require.NoError(t, err2) - err3 := tagEcsClusterWhenFirstSeen(awsSession, cluster3.ClusterArn, oldClusterTagValue2) - require.NoError(t, err3) - - // expression to match created clusters - clusterMatchExpression, err := regexp.Compile("^cloud-nuke-test-24*") - assert.NoError(t, err) - - last24Hours := now.Add(time.Hour * time.Duration(-24)) - filteredClusterArns, err := getAllEcsClustersOlderThan(awsSession, last24Hours, config.Config{ - ECSCluster: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *clusterMatchExpression, - }, + testArn1 := "arn:aws:ecs:us-east-1:123456789012:cluster/cluster1" + testArn2 := "arn:aws:ecs:us-east-1:123456789012:cluster/cluster2" + testName1 := "cluster1" + testName2 := "cluster2" + now := time.Now() + ec := ECSClusters{ + Client: mockedEC2Cluster{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []*string{ + aws.String(testArn1), + aws.String(testArn2), }, }, - }, - }) - require.NoError(t, err) - - nukeErr := nukeEcsClusters(awsSession, filteredClusterArns) - require.NoError(t, nukeErr) - - allLeftClusterArns, err := getAllEcsClusters(awsSession) - require.NoError(t, err) - - assert.Contains(t, awsgo.StringValueSlice(allLeftClusterArns), awsgo.StringValue(cluster2.ClusterArn)) -} - -// Test the config file filtering works as expected -func TestShouldIncludeECSCluster(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - mockCluster := &ecs.Cluster{ - ClusterName: awsgo.String("cloud-nuke-test"), - Status: awsgo.String("ACTIVE"), - } - - mockClusterInactive := &ecs.Cluster{ - ClusterName: awsgo.String("cloud-nuke-test"), - Status: awsgo.String("INACTIVE"), - } - - mockExpression, err := regexp.Compile("^cloud-nuke-*") - if err != nil { - logging.Logger.Fatalf("There was an error compiling regex expression %v", err) - } - mockExcludeConfig := config.Config{ - ECSCluster: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ + DescribeClustersOutput: ecs.DescribeClustersOutput{ + Clusters: []*ecs.Cluster{ { - RE: *mockExpression, + ClusterArn: aws.String(testArn1), + Status: aws.String("ACTIVE"), + ClusterName: aws.String(testName1), + }, + { + ClusterArn: aws.String(testArn2), + Status: aws.String("ACTIVE"), + ClusterName: aws.String(testName2), }, }, }, - }, - } - mockIncludeConfig := config.Config{ - ECSCluster: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ + ListTagsForResourceOutput: ecs.ListTagsForResourceOutput{ + Tags: []*ecs.Tag{ { - RE: *mockExpression, + Key: aws.String(util.FirstSeenTagKey), + Value: aws.String(util.FormatTimestampTag(now)), }, }, }, }, } - cases := []struct { - Name string - Cluster *ecs.Cluster - Config config.Config - Expected bool + tests := map[string]struct { + configObj config.ResourceType + expected []string }{ - { - Name: "ConfigExclude", - Cluster: mockCluster, - Config: mockExcludeConfig, - Expected: false, + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testArn1, testArn2}, }, - { - Name: "ConfigInclude", - Cluster: mockCluster, - Config: mockIncludeConfig, - Expected: true, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testArn2}, }, - { - Name: "ConfigIncludeInactive", - Cluster: mockClusterInactive, - Config: mockIncludeConfig, - Expected: false, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-1 * time.Hour)), + }}, + expected: []string{}, }, } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - result := shouldIncludeECSCluster(c.Cluster, c.Config) - assert.Equal(t, c.Expected, result) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := ec.getAll(config.Config{ + ECSCluster: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) }) } + +} + +func TestEC2Cluster_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + ec := ECSClusters{ + Client: mockedEC2Cluster{ + DeleteClusterOutput: ecs.DeleteClusterOutput{}, + }, + } + + err := ec.nukeAll([]*string{aws.String("arn:aws:ecs:us-east-1:123456789012:cluster/cluster1")}) + require.NoError(t, err) } diff --git a/aws/ecs_cluster_types.go b/aws/ecs_cluster_types.go index cd7a5d91..ef688172 100644 --- a/aws/ecs_cluster_types.go +++ b/aws/ecs_cluster_types.go @@ -37,7 +37,7 @@ func (clusters ECSClusters) MaxBatchSize() int { // Nuke - nuke all ECS Cluster resources func (clusters ECSClusters) Nuke(awsSession *session.Session, identifiers []string) error { - if err := nukeEcsClusters(awsSession, awsgo.StringSlice(identifiers)); err != nil { + if err := clusters.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) } return nil diff --git a/aws/ecs_service.go b/aws/ecs_service.go index 04599f08..26d05b66 100644 --- a/aws/ecs_service.go +++ b/aws/ecs_service.go @@ -1,14 +1,11 @@ package aws import ( - "time" - "github.com/gruntwork-io/cloud-nuke/telemetry" commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" "github.com/aws/aws-sdk-go/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecs" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" @@ -18,10 +15,9 @@ import ( // getAllEcsClusters - Returns a string of ECS Cluster ARNs, which uniquely identifies the cluster. // We need to get all clusters before we can get all services. -func getAllEcsClusters(awsSession *session.Session) ([]*string, error) { - svc := ecs.New(awsSession) +func (services ECSServices) getAllEcsClusters() ([]*string, error) { clusterArns := []*string{} - result, err := svc.ListClusters(&ecs.ListClustersInput{}) + result, err := services.Client.ListClusters(&ecs.ListClustersInput{}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -29,7 +25,7 @@ func getAllEcsClusters(awsSession *session.Session) ([]*string, error) { // Handle pagination: continuously pull the next page if nextToken is set for awsgo.StringValue(result.NextToken) != "" { - result, err = svc.ListClusters(&ecs.ListClustersInput{NextToken: result.NextToken}) + result, err = services.Client.ListClusters(&ecs.ListClustersInput{NextToken: result.NextToken}) if err != nil { return nil, errors.WithStackTrace(err) } @@ -42,7 +38,7 @@ func getAllEcsClusters(awsSession *session.Session) ([]*string, error) { // filterOutRecentServices - Given a list of services and an excludeAfter // timestamp, filter out any services that were created after `excludeAfter. // Additionally, filter based on Config file patterns. -func filterOutRecentServices(svc *ecs.ECS, clusterArn *string, ecsServiceArns []string, excludeAfter time.Time, configObj config.Config) ([]*string, error) { +func (services ECSServices) filterOutRecentServices(clusterArn *string, ecsServiceArns []string, configObj config.Config) ([]*string, error) { // Fetch descriptions in batches of 10, which is the max that AWS // accepts for describe service. var filteredEcsServiceArns []*string @@ -52,12 +48,15 @@ func filterOutRecentServices(svc *ecs.ECS, clusterArn *string, ecsServiceArns [] Cluster: clusterArn, Services: awsgo.StringSlice(batch), } - describeResult, err := svc.DescribeServices(params) + describeResult, err := services.Client.DescribeServices(params) if err != nil { return nil, errors.WithStackTrace(err) } for _, service := range describeResult.Services { - if shouldIncludeECSService(service, excludeAfter, configObj) { + if configObj.ECSService.ShouldInclude(config.ResourceValue{ + Name: service.ServiceName, + Time: service.CreatedAt, + }) { filteredEcsServiceArns = append(filteredEcsServiceArns, service.ServiceArn) } } @@ -65,42 +64,30 @@ func filterOutRecentServices(svc *ecs.ECS, clusterArn *string, ecsServiceArns [] return filteredEcsServiceArns, nil } -func shouldIncludeECSService(service *ecs.Service, excludeAfter time.Time, configObj config.Config) bool { - if service == nil { - return false - } - - if service.CreatedAt != nil && excludeAfter.Before(*service.CreatedAt) { - return false - } - - return config.ShouldInclude( - awsgo.StringValue(service.ServiceName), - configObj.ECSService.IncludeRule.NamesRegExp, - configObj.ECSService.ExcludeRule.NamesRegExp, - ) -} - // getAllEcsServices - Returns a formatted string of ECS Service ARNs, which // uniquely identifies the service, in addition to a mapping of services to // clusters. For ECS, need to track ECS clusters of services as all service // level API endpoints require providing the corresponding cluster. // Note that this looks up services by ECS cluster ARNs. -func getAllEcsServices(awsSession *session.Session, ecsClusterArns []*string, excludeAfter time.Time, configObj config.Config) ([]*string, map[string]string, error) { +func (services ECSServices) getAll(configObj config.Config) ([]*string, error) { + ecsClusterArns, err := services.getAllEcsClusters() + if err != nil { + return nil, errors.WithStackTrace(err) + } + ecsServiceClusterMap := map[string]string{} - svc := ecs.New(awsSession) // For each cluster, fetch all services, filtering out recently created // ones. var ecsServiceArns []*string for _, clusterArn := range ecsClusterArns { - result, err := svc.ListServices(&ecs.ListServicesInput{Cluster: clusterArn}) + result, err := services.Client.ListServices(&ecs.ListServicesInput{Cluster: clusterArn}) if err != nil { - return nil, nil, errors.WithStackTrace(err) + return nil, errors.WithStackTrace(err) } - filteredServiceArns, err := filterOutRecentServices(svc, clusterArn, awsgo.StringValueSlice(result.ServiceArns), excludeAfter, configObj) + filteredServiceArns, err := services.filterOutRecentServices(clusterArn, awsgo.StringValueSlice(result.ServiceArns), configObj) if err != nil { - return nil, nil, errors.WithStackTrace(err) + return nil, errors.WithStackTrace(err) } // Update mapping to be used later in nuking for _, serviceArn := range filteredServiceArns { @@ -109,27 +96,28 @@ func getAllEcsServices(awsSession *session.Session, ecsClusterArns []*string, ex ecsServiceArns = append(ecsServiceArns, filteredServiceArns...) } - return ecsServiceArns, ecsServiceClusterMap, nil + services.ServiceClusterMap = ecsServiceClusterMap + return ecsServiceArns, nil } // drainEcsServices - Drain all tasks from all services requested. This will // return a list of service ARNs that have been successfully requested to be // drained. -func drainEcsServices(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsServiceArns []*string) []*string { +func (services ECSServices) drainEcsServices(ecsServiceArns []*string) []*string { var requestedDrains []*string for _, ecsServiceArn := range ecsServiceArns { describeParams := &ecs.DescribeServicesInput{ - Cluster: awsgo.String(ecsServiceClusterMap[*ecsServiceArn]), + Cluster: awsgo.String(services.ServiceClusterMap[*ecsServiceArn]), Services: []*string{ecsServiceArn}, } - describeServicesOutput, err := svc.DescribeServices(describeParams) + describeServicesOutput, err := services.Client.DescribeServices(describeParams) if err != nil { logging.Logger.Errorf("[Failed] Failed to describe service %s: %s", *ecsServiceArn, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Service", }, map[string]interface{}{ - "region": *svc.Config.Region, + "region": services.Region, "reason": "Unable to describe", }) } else { @@ -139,17 +127,17 @@ func drainEcsServices(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsS requestedDrains = append(requestedDrains, ecsServiceArn) } else { params := &ecs.UpdateServiceInput{ - Cluster: awsgo.String(ecsServiceClusterMap[*ecsServiceArn]), + Cluster: awsgo.String(services.ServiceClusterMap[*ecsServiceArn]), Service: ecsServiceArn, DesiredCount: awsgo.Int64(0), } - _, err = svc.UpdateService(params) + _, err = services.Client.UpdateService(params) if err != nil { logging.Logger.Errorf("[Failed] Failed to drain service %s: %s", *ecsServiceArn, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Service", }, map[string]interface{}{ - "region": *svc.Config.Region, + "region": services.Region, "reason": "Unable to drain", }) } else { @@ -165,20 +153,20 @@ func drainEcsServices(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsS // given list of services, by waiting for stability which is defined as // desiredCount == runningCount. This will return a list of service ARNs that // have successfully been drained. -func waitUntilServicesDrained(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsServiceArns []*string) []*string { +func (services ECSServices) waitUntilServicesDrained(ecsServiceArns []*string) []*string { var successfullyDrained []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DescribeServicesInput{ - Cluster: awsgo.String(ecsServiceClusterMap[*ecsServiceArn]), + Cluster: awsgo.String(services.ServiceClusterMap[*ecsServiceArn]), Services: []*string{ecsServiceArn}, } - err := svc.WaitUntilServicesStable(params) + err := services.Client.WaitUntilServicesStable(params) if err != nil { logging.Logger.Debugf("[Failed] Failed waiting for service to be stable %s: %s", *ecsServiceArn, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Service", }, map[string]interface{}{ - "region": *svc.Config.Region, + "region": services.Region, "reason": "Failed Waiting for Drain", }) } else { @@ -191,20 +179,20 @@ func waitUntilServicesDrained(svc *ecs.ECS, ecsServiceClusterMap map[string]stri // deleteEcsServices - Deletes all services requested. Returns a list of // service ARNs that have been accepted by AWS for deletion. -func deleteEcsServices(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsServiceArns []*string) []*string { +func (services ECSServices) deleteEcsServices(ecsServiceArns []*string) []*string { var requestedDeletes []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DeleteServiceInput{ - Cluster: awsgo.String(ecsServiceClusterMap[*ecsServiceArn]), + Cluster: awsgo.String(services.ServiceClusterMap[*ecsServiceArn]), Service: ecsServiceArn, } - _, err := svc.DeleteService(params) + _, err := services.Client.DeleteService(params) if err != nil { logging.Logger.Debugf("[Failed] Failed deleting service %s: %s", *ecsServiceArn, err) telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Service", }, map[string]interface{}{ - "region": *svc.Config.Region, + "region": services.Region, "reason": "Unable to Delete", }) } else { @@ -217,14 +205,14 @@ func deleteEcsServices(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecs // waitUntilServicesDeleted - Waits until the service has been actually deleted // from AWS. Returns a list of service ARNs that have been successfully // deleted. -func waitUntilServicesDeleted(svc *ecs.ECS, ecsServiceClusterMap map[string]string, ecsServiceArns []*string) []*string { +func (services ECSServices) waitUntilServicesDeleted(ecsServiceArns []*string) []*string { var successfullyDeleted []*string for _, ecsServiceArn := range ecsServiceArns { params := &ecs.DescribeServicesInput{ - Cluster: awsgo.String(ecsServiceClusterMap[*ecsServiceArn]), + Cluster: awsgo.String(services.ServiceClusterMap[*ecsServiceArn]), Services: []*string{ecsServiceArn}, } - err := svc.WaitUntilServicesInactive(params) + err := services.Client.WaitUntilServicesInactive(params) // Record status of this resource e := report.Entry{ @@ -239,7 +227,7 @@ func waitUntilServicesDeleted(svc *ecs.ECS, ecsServiceClusterMap map[string]stri telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking ECS Service", }, map[string]interface{}{ - "region": *svc.Config.Region, + "region": services.Region, "reason": "Failed Waiting for Delete", }) } else { @@ -258,16 +246,14 @@ func waitUntilServicesDeleted(svc *ecs.ECS, ecsServiceClusterMap map[string]stri // 2.) Delete service object once no tasks are running. // Note that this will swallow failed deletes and continue along, logging the // service ARN so that we can find it later. -func nukeAllEcsServices(awsSession *session.Session, ecsServiceClusterMap map[string]string, ecsServiceArns []*string) error { +func (services ECSServices) nukeAll(ecsServiceArns []*string) error { numNuking := len(ecsServiceArns) - svc := ecs.New(awsSession) - if numNuking == 0 { - logging.Logger.Debugf("No ECS services to nuke in region %s", *awsSession.Config.Region) + logging.Logger.Debugf("No ECS services to nuke in region %s", services.Region) return nil } - logging.Logger.Debugf("Deleting %d ECS services in region %s", numNuking, *awsSession.Config.Region) + logging.Logger.Debugf("Deleting %d ECS services in region %s", numNuking, services.Region) // First, drain all the services to 0. You can't delete a // service that is running tasks. @@ -275,12 +261,12 @@ func nukeAllEcsServices(awsSession *session.Session, ecsServiceClusterMap map[st // wait for them in a separate loop because it will take a // while to drain the services. // Then, we delete the services that have been successfully drained. - requestedDrains := drainEcsServices(svc, ecsServiceClusterMap, ecsServiceArns) - successfullyDrained := waitUntilServicesDrained(svc, ecsServiceClusterMap, requestedDrains) - requestedDeletes := deleteEcsServices(svc, ecsServiceClusterMap, successfullyDrained) - successfullyDeleted := waitUntilServicesDeleted(svc, ecsServiceClusterMap, requestedDeletes) + requestedDrains := services.drainEcsServices(ecsServiceArns) + successfullyDrained := services.waitUntilServicesDrained(requestedDrains) + requestedDeletes := services.deleteEcsServices(successfullyDrained) + successfullyDeleted := services.waitUntilServicesDeleted(requestedDeletes) numNuked := len(successfullyDeleted) - logging.Logger.Debugf("[OK] %d of %d ECS service(s) deleted in %s", numNuked, numNuking, *awsSession.Config.Region) + logging.Logger.Debugf("[OK] %d of %d ECS service(s) deleted in %s", numNuked, numNuking, services.Client) return nil } diff --git a/aws/ecs_service_test.go b/aws/ecs_service_test.go index 0fab1860..c8d5eaa8 100644 --- a/aws/ecs_service_test.go +++ b/aws/ecs_service_test.go @@ -1,314 +1,148 @@ package aws import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/telemetry" + "github.com/stretchr/testify/require" "regexp" "testing" "time" - - "github.com/gruntwork-io/cloud-nuke/telemetry" - - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/logging" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/go-commons/errors" - "github.com/stretchr/testify/assert" ) -// Test that we can find ECS services that are running Fargate tasks -func TestListECSFargateServices(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region := getRandomFargateSupportedRegion() - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - - ecsServiceClusterMap := map[string]string{} - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - clusterName := uniqueTestID + "-cluster" - serviceName := uniqueTestID + "-service" - taskFamilyName := uniqueTestID + "-task" - - cluster := createEcsFargateCluster(t, awsSession, clusterName) - defer deleteEcsCluster(awsSession, cluster) - - taskDefinition := createEcsTaskDefinition(t, awsSession, taskFamilyName, "FARGATE") - defer deleteEcsTaskDefinition(awsSession, taskDefinition) - - service := createEcsService(t, awsSession, serviceName, cluster, "FARGATE", taskDefinition, "REPLICA") - ecsServiceClusterMap[*service.ServiceArn] = *cluster.ClusterArn - defer nukeAllEcsServices(awsSession, ecsServiceClusterMap, []*string{service.ServiceArn}) - - ecsServiceArns, newEcsServiceClusterMap, err := getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour*-1), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.NotContains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) - _, exists := newEcsServiceClusterMap[*service.ServiceArn] - assert.False(t, exists) - - ecsServiceArns, newEcsServiceClusterMap, err = getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.Contains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) - _, exists = newEcsServiceClusterMap[*service.ServiceArn] - assert.True(t, exists) +type mockedEC2Service struct { + ecsiface.ECSAPI + ListClustersOutput ecs.ListClustersOutput + DescribeServicesOutput ecs.DescribeServicesOutput + ListServicesOutput ecs.ListServicesOutput + UpdateServiceOutput ecs.UpdateServiceOutput + DeleteServiceOutput ecs.DeleteServiceOutput } -// Test that we can successfully nuke ECS services running Fargate tasks -func TestNukeECSFargateServices(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region := getRandomFargateSupportedRegion() - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - clusterName := uniqueTestID + "-cluster" - serviceName := uniqueTestID + "-service" - taskFamilyName := uniqueTestID + "-task" - - cluster := createEcsFargateCluster(t, awsSession, clusterName) - defer deleteEcsCluster(awsSession, cluster) - - taskDefinition := createEcsTaskDefinition(t, awsSession, taskFamilyName, "FARGATE") - defer deleteEcsTaskDefinition(awsSession, taskDefinition) - - service := createEcsService(t, awsSession, serviceName, cluster, "FARGATE", taskDefinition, "REPLICA") - - ecsServiceClusterMap := map[string]string{} - ecsServiceClusterMap[*service.ServiceArn] = *cluster.ClusterArn - err = nukeAllEcsServices(awsSession, ecsServiceClusterMap, []*string{service.ServiceArn}) - if err != nil { - assert.Fail(t, err.Error()) - } - - ecsServiceArns, _, err := getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.NotContains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) +func (m mockedEC2Service) ListClusters(*ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { + return &m.ListClustersOutput, nil } -// Test that we can find ECS services running EC2 tasks -func TestListECSEC2Services(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region := getRandomFargateSupportedRegion() - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - - ecsServiceClusterMap := map[string]string{} - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - clusterName := uniqueTestID + "-cluster" - serviceName := uniqueTestID + "-service" - taskFamilyName := uniqueTestID + "-task" - roleName := uniqueTestID + "-role" - instanceProfileName := uniqueTestID + "-instance-profile" - - // Prepare resources - // Create the IAM roles for ECS EC2 container instances - role := createEcsRole(t, awsSession, roleName) - defer deleteRole(awsSession, role) - - instanceProfile := createEcsInstanceProfile(t, awsSession, instanceProfileName, role) - defer deleteInstanceProfile(awsSession, instanceProfile) - - // IAM resources are slow to propagate, so give it some - // time - time.Sleep(15 * time.Second) - - // Provision a cluster with ec2 container instances, not - // forgetting to schedule deletion - cluster, instance := createEcsEC2Cluster(t, awsSession, clusterName, instanceProfile) - defer deleteEcsCluster(awsSession, cluster) - defer nukeAllEc2Instances(awsSession, []*string{instance.InstanceId}) - - // Finally, define the task and service - taskDefinition := createEcsTaskDefinition(t, awsSession, taskFamilyName, "EC2") - defer deleteEcsTaskDefinition(awsSession, taskDefinition) - - service := createEcsService(t, awsSession, serviceName, cluster, "EC2", taskDefinition, "REPLICA") - ecsServiceClusterMap[*service.ServiceArn] = *cluster.ClusterArn - defer nukeAllEcsServices(awsSession, ecsServiceClusterMap, []*string{service.ServiceArn}) - // END prepare resources - - ecsServiceArns, newEcsServiceClusterMap, err := getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour*-1), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.NotContains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) - _, exists := newEcsServiceClusterMap[*service.ServiceArn] - assert.False(t, exists) - - ecsServiceArns, newEcsServiceClusterMap, err = getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.Contains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) - _, exists = newEcsServiceClusterMap[*service.ServiceArn] - assert.True(t, exists) +func (m mockedEC2Service) DescribeServices(*ecs.DescribeServicesInput) (*ecs.DescribeServicesOutput, error) { + return &m.DescribeServicesOutput, nil } -// Test that we can successfully nuke ECS services running EC2 tasks -func TestNukeECSEC2Services(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region := getRandomFargateSupportedRegion() - awsSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }) - - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - - ecsServiceClusterMap := map[string]string{} - uniqueTestID := "cloud-nuke-test-" + util.UniqueID() - clusterName := uniqueTestID + "-cluster" - serviceName := uniqueTestID + "-service" - daemonServiceName := uniqueTestID + "-daemon-service" - taskFamilyName := uniqueTestID + "-task" - taskDaemonFamilyName := uniqueTestID + "-daemon-task" - roleName := uniqueTestID + "-role" - instanceProfileName := uniqueTestID + "-instance-profile" - - // Prepare resources - // Create the IAM roles for ECS EC2 container instances - role := createEcsRole(t, awsSession, roleName) - defer deleteRole(awsSession, role) - - instanceProfile := createEcsInstanceProfile(t, awsSession, instanceProfileName, role) - defer deleteInstanceProfile(awsSession, instanceProfile) - - // IAM resources are slow to propagate, so give it some - // time - time.Sleep(15 * time.Second) - - // Provision a cluster with ec2 container instances, not - // forgetting to schedule deletion - cluster, instance := createEcsEC2Cluster(t, awsSession, clusterName, instanceProfile) - defer deleteEcsCluster(awsSession, cluster) - defer nukeAllEc2Instances(awsSession, []*string{instance.InstanceId}) - - // Finally, define the task and service - taskDefinition := createEcsTaskDefinition(t, awsSession, taskFamilyName, "EC2") - defer deleteEcsTaskDefinition(awsSession, taskDefinition) - - service := createEcsService(t, awsSession, serviceName, cluster, "EC2", taskDefinition, "REPLICA") - ecsServiceClusterMap[*service.ServiceArn] = *cluster.ClusterArn - - taskDaemonDefinition := createEcsTaskDefinition(t, awsSession, taskDaemonFamilyName, "EC2") - defer deleteEcsTaskDefinition(awsSession, taskDaemonDefinition) +func (m mockedEC2Service) ListServices(*ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { + return &m.ListServicesOutput, nil +} - daemonService := createEcsService(t, awsSession, daemonServiceName, cluster, "EC2", taskDaemonDefinition, "DAEMON") - ecsServiceClusterMap[*daemonService.ServiceArn] = *cluster.ClusterArn +func (m mockedEC2Service) UpdateService(*ecs.UpdateServiceInput) (*ecs.UpdateServiceOutput, error) { + return &m.UpdateServiceOutput, nil +} - // END prepare resources +func (m mockedEC2Service) WaitUntilServicesStable(*ecs.DescribeServicesInput) error { + return nil +} - err = nukeAllEcsServices(awsSession, ecsServiceClusterMap, []*string{service.ServiceArn, daemonService.ServiceArn}) +func (m mockedEC2Service) DeleteService(*ecs.DeleteServiceInput) (*ecs.DeleteServiceOutput, error) { + return &m.DeleteServiceOutput, nil +} - ecsServiceArns, _, err := getAllEcsServices(awsSession, []*string{cluster.ClusterArn}, time.Now().Add(1*time.Hour), config.Config{}) - if err != nil { - assert.Failf(t, "Unable to fetch list of services: %s", err.Error()) - } - assert.NotContains(t, awsgo.StringValueSlice(ecsServiceArns), *service.ServiceArn) - assert.NotContains(t, awsgo.StringValueSlice(ecsServiceArns), *daemonService.ServiceArn) +func (m mockedEC2Service) WaitUntilServicesInactive(*ecs.DescribeServicesInput) error { + return nil } -// Test the config file filtering works as expected -func TestShouldIncludeECSService(t *testing.T) { +func TestEC2Service_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") - mockService := &ecs.Service{ - ServiceName: awsgo.String("cloud-nuke-test"), - CreatedAt: awsgo.Time(time.Now()), - } - - mockExpression, err := regexp.Compile("^cloud-nuke-*") - if err != nil { - logging.Logger.Fatalf("There was an error compiling regex expression %v", err) - } + t.Parallel() - mockExcludeConfig := config.Config{ - ECSService: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *mockExpression, - }, + testArn1 := "testArn1" + testArn2 := "testArn2" + testName1 := "testService1" + testName2 := "testService2" + now := time.Now() + es := ECSServices{ + Client: mockedEC2Service{ + ListClustersOutput: ecs.ListClustersOutput{ + ClusterArns: []*string{ + aws.String(testArn1), }, }, - }, - } - - mockIncludeConfig := config.Config{ - ECSService: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ + ListServicesOutput: ecs.ListServicesOutput{ + ServiceArns: []*string{ + aws.String(testArn1), + }, + }, + DescribeServicesOutput: ecs.DescribeServicesOutput{ + Services: []*ecs.Service{ + { + ServiceArn: aws.String(testArn1), + ServiceName: aws.String(testName1), + CreatedAt: aws.Time(now), + }, { - RE: *mockExpression, + ServiceArn: aws.String(testArn2), + ServiceName: aws.String(testName2), + CreatedAt: aws.Time(now.Add(1)), }, }, }, }, } - cases := []struct { - Name string - Service *ecs.Service - Config config.Config - ExcludeAfter time.Time - Expected bool + tests := map[string]struct { + configObj config.ResourceType + expected []string }{ - { - Name: "ConfigExclude", - Service: mockService, - Config: mockExcludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Expected: false, + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testArn1, testArn2}, }, - { - Name: "ConfigInclude", - Service: mockService, - Config: mockIncludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Expected: true, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testArn2}, }, - { - Name: "NotOlderThan", - Service: mockService, - Config: config.Config{}, - ExcludeAfter: time.Now().Add(1 * time.Hour * -1), - Expected: false, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }}, + expected: []string{testArn1}, }, } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - result := shouldIncludeECSService(c.Service, c.ExcludeAfter, c.Config) - assert.Equal(t, c.Expected, result) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := es.getAll(config.Config{ + ECSService: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) }) } + +} + +func TestEC2Service_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() + + es := ECSServices{ + Client: mockedEC2Service{ + DescribeServicesOutput: ecs.DescribeServicesOutput{ + Services: []*ecs.Service{ + { + ServiceArn: aws.String("testArn1"), + SchedulingStrategy: aws.String(ecs.SchedulingStrategyDaemon), + }, + }, + }, + UpdateServiceOutput: ecs.UpdateServiceOutput{}, + DeleteServiceOutput: ecs.DeleteServiceOutput{}, + }, + } + + err := es.nukeAll([]*string{aws.String("testArn1")}) + require.NoError(t, err) } diff --git a/aws/ecs_service_types.go b/aws/ecs_service_types.go index a9d7c8d4..c1d77112 100644 --- a/aws/ecs_service_types.go +++ b/aws/ecs_service_types.go @@ -31,7 +31,7 @@ func (services ECSServices) MaxBatchSize() int { // Nuke - nuke all ECS service resources func (services ECSServices) Nuke(awsSession *session.Session, identifiers []string) error { - if err := nukeAllEcsServices(awsSession, services.ServiceClusterMap, awsgo.StringSlice(identifiers)); err != nil { + if err := services.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) } return nil diff --git a/aws/ecs_utils_for_test.go b/aws/ecs_utils_for_test.go deleted file mode 100644 index 51e8835d..00000000 --- a/aws/ecs_utils_for_test.go +++ /dev/null @@ -1,395 +0,0 @@ -package aws - -import ( - "encoding/base64" - "errors" - "fmt" - "math/rand" - "testing" - "time" - - "github.com/gruntwork-io/cloud-nuke/config" - - awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/gruntwork-io/go-commons/collections" - gruntworkerrors "github.com/gruntwork-io/go-commons/errors" - "github.com/gruntwork-io/terratest/modules/retry" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/gruntwork-io/cloud-nuke/logging" -) - -// We black list us-east-1e because this zone is frequently out of capacity -var AvailabilityZoneBlackList = []string{"us-east-1e"} - -// getRandomFargateSupportedRegion - Returns a random AWS -// region that supports Fargate. -// Refer to https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/ -func getRandomFargateSupportedRegion() string { - supportedRegions := []string{ - "us-east-1", "us-east-2", "us-west-2", - "eu-central-1", "eu-west-1", - "ap-southeast-1", "ap-southeast-2", "ap-northeast-1", - } - rand.Seed(time.Now().UnixNano()) - randIndex := rand.Intn(len(supportedRegions)) - return supportedRegions[randIndex] -} - -func createEcsFargateCluster(t *testing.T, awsSession *session.Session, name string) ecs.Cluster { - logging.Logger.Infof("Creating ECS cluster %s in region %s", name, *awsSession.Config.Region) - - svc := ecs.New(awsSession) - result, err := svc.CreateCluster(&ecs.CreateClusterInput{ClusterName: awsgo.String(name)}) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - _, err = retry.DoWithRetryE( - t, - "Wait for ECS cluster to be created", - 4, - 15*time.Second, - func() (string, error) { - ecsClusters, err := getAllActiveEcsClusterArns(awsSession, config.Config{}) - if err != nil { - return "", retry.FatalError{Underlying: err} - } - - for _, value := range awsgo.StringValueSlice(ecsClusters) { - if value == awsgo.StringValue(result.Cluster.ClusterArn) { - return "", nil - } - } - return "", errors.New("cluster not found") - }, - ) - require.NoError(t, err) - - return *result.Cluster -} - -func createEcsEC2Cluster(t *testing.T, awsSession *session.Session, name string, instanceProfile iam.InstanceProfile) (ecs.Cluster, ec2.Instance) { - cluster := createEcsFargateCluster(t, awsSession, name) - - ec2Svc := ec2.New(awsSession) - imageID, err := getAMIIdByName(ec2Svc, "amzn-ami-2018.03.20211120-amazon-ecs-optimized") - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - - rawUserDataText := fmt.Sprintf("#!/bin/bash\necho 'ECS_CLUSTER=%s' >> /etc/ecs/ecs.config", *cluster.ClusterName) - userDataText := base64.StdEncoding.EncodeToString([]byte(rawUserDataText)) - - instanceProfileSpecification := &ec2.IamInstanceProfileSpecification{ - Arn: instanceProfile.Arn, - } - params := &ec2.RunInstancesInput{ - ImageId: awsgo.String(imageID), - InstanceType: awsgo.String("t3.small"), - MinCount: awsgo.Int64(1), - MaxCount: awsgo.Int64(1), - DisableApiTermination: awsgo.Bool(false), - IamInstanceProfile: instanceProfileSpecification, - UserData: awsgo.String(userDataText), - } - instance, err := runAndWaitForInstance(ec2Svc, name, params) - if err != nil { - assert.Fail(t, err.Error()) - } - - // At this point we assume the instance successfully - // registered itself to the cluster - return cluster, instance -} - -func deleteEcsCluster(awsSession *session.Session, cluster ecs.Cluster) error { - svc := ecs.New(awsSession) - params := &ecs.DeleteClusterInput{Cluster: cluster.ClusterArn} - _, err := svc.DeleteCluster(params) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - return nil -} - -func createEcsService(t *testing.T, awsSession *session.Session, serviceName string, cluster ecs.Cluster, launchType string, taskDefinition ecs.TaskDefinition, schedulingStrategy string) ecs.Service { - svc := ecs.New(awsSession) - createServiceParams := &ecs.CreateServiceInput{ - Cluster: cluster.ClusterArn, - LaunchType: awsgo.String(launchType), - ServiceName: awsgo.String(serviceName), - TaskDefinition: taskDefinition.TaskDefinitionArn, - } - if launchType == "EC2" && schedulingStrategy == "DAEMON" { - createServiceParams.SetSchedulingStrategy(schedulingStrategy) - } else { - createServiceParams.SetDesiredCount(1) - } - - if launchType == "FARGATE" { - vpcConfiguration, err := getVpcConfiguration(awsSession) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - networkConfiguration := &ecs.NetworkConfiguration{ - AwsvpcConfiguration: &vpcConfiguration, - } - createServiceParams.SetNetworkConfiguration(networkConfiguration) - } - result, err := svc.CreateService(createServiceParams) - require.NoError(t, err) - - // Wait for the service to come up before continuing. We try at most two times to wait for the service. Oftentimes - // the service wait times out on the first try, but eventually succeeds. - retry.DoWithRetry( - t, - fmt.Sprintf("Waiting for service %s to be stable", awsgo.StringValue(result.Service.ServiceArn)), - 2, - 0*time.Second, - func() (string, error) { - err := svc.WaitUntilServicesStable(&ecs.DescribeServicesInput{ - Cluster: cluster.ClusterArn, - Services: []*string{result.Service.ServiceArn}, - }) - return "", err - }, - ) - - return *result.Service -} - -func createEcsTaskDefinition(t *testing.T, awsSession *session.Session, taskFamilyName string, launchType string) ecs.TaskDefinition { - svc := ecs.New(awsSession) - containerDefinition := &ecs.ContainerDefinition{ - Image: awsgo.String("nginx:latest"), - Name: awsgo.String("nginx"), - } - registerTaskParams := &ecs.RegisterTaskDefinitionInput{ - ContainerDefinitions: []*ecs.ContainerDefinition{containerDefinition}, - Cpu: awsgo.String("256"), - Memory: awsgo.String("512"), - Family: awsgo.String(taskFamilyName), - } - if launchType == "FARGATE" { - registerTaskParams.SetNetworkMode("awsvpc") - registerTaskParams.SetRequiresCompatibilities([]*string{awsgo.String("FARGATE")}) - } - result, err := svc.RegisterTaskDefinition(registerTaskParams) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - return *result.TaskDefinition -} - -func deleteEcsTaskDefinition(awsSession *session.Session, taskDefinition ecs.TaskDefinition) error { - svc := ecs.New(awsSession) - deregisterTaskDefinitionParams := &ecs.DeregisterTaskDefinitionInput{ - TaskDefinition: taskDefinition.TaskDefinitionArn, - } - _, err := svc.DeregisterTaskDefinition(deregisterTaskDefinitionParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - return nil -} - -func createEcsInstanceProfile(t *testing.T, awsSession *session.Session, instanceProfileName string, role iam.Role) iam.InstanceProfile { - svc := iam.New(awsSession) - createInstanceProfileParams := &iam.CreateInstanceProfileInput{ - InstanceProfileName: awsgo.String(instanceProfileName), - } - result, err := svc.CreateInstanceProfile(createInstanceProfileParams) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - instanceProfile := result.InstanceProfile - addRoleToInstanceProfileParams := &iam.AddRoleToInstanceProfileInput{ - InstanceProfileName: instanceProfile.InstanceProfileName, - RoleName: role.RoleName, - } - _, err = svc.AddRoleToInstanceProfile(addRoleToInstanceProfileParams) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - return *instanceProfile -} - -func deleteInstanceProfile(awsSession *session.Session, instanceProfile iam.InstanceProfile) error { - svc := iam.New(awsSession) - getInstanceProfileParams := &iam.GetInstanceProfileInput{ - InstanceProfileName: instanceProfile.InstanceProfileName, - } - result, err := svc.GetInstanceProfile(getInstanceProfileParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - refreshedInstanceProfile := result.InstanceProfile - for _, role := range refreshedInstanceProfile.Roles { - removeRoleParams := &iam.RemoveRoleFromInstanceProfileInput{ - InstanceProfileName: refreshedInstanceProfile.InstanceProfileName, - RoleName: role.RoleName, - } - _, err := svc.RemoveRoleFromInstanceProfile(removeRoleParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - } - deleteInstanceProfileParams := &iam.DeleteInstanceProfileInput{ - InstanceProfileName: refreshedInstanceProfile.InstanceProfileName, - } - _, err = svc.DeleteInstanceProfile(deleteInstanceProfileParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - return nil - -} - -func createEcsRole(t *testing.T, awsSession *session.Session, roleName string) iam.Role { - svc := iam.New(awsSession) - createRoleParams := &iam.CreateRoleInput{ - AssumeRolePolicyDocument: awsgo.String(ECS_ASSUME_ROLE_POLICY), - RoleName: awsgo.String(roleName), - } - result, err := svc.CreateRole(createRoleParams) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - putRolePolicyParams := &iam.PutRolePolicyInput{ - RoleName: awsgo.String(roleName), - PolicyDocument: awsgo.String(ECS_ROLE_POLICY), - PolicyName: awsgo.String(roleName + "Policy"), - } - _, err = svc.PutRolePolicy(putRolePolicyParams) - if err != nil { - assert.Fail(t, gruntworkerrors.WithStackTrace(err).Error()) - } - return *result.Role -} - -func deleteRole(awsSession *session.Session, role iam.Role) error { - svc := iam.New(awsSession) - listRolePoliciesParams := &iam.ListRolePoliciesInput{ - RoleName: role.RoleName, - } - result, err := svc.ListRolePolicies(listRolePoliciesParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - for _, policyName := range result.PolicyNames { - deleteRolePolicyParams := &iam.DeleteRolePolicyInput{ - RoleName: role.RoleName, - PolicyName: policyName, - } - _, err := svc.DeleteRolePolicy(deleteRolePolicyParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - } - deleteRoleParams := &iam.DeleteRoleInput{ - RoleName: role.RoleName, - } - _, err = svc.DeleteRole(deleteRoleParams) - if err != nil { - return gruntworkerrors.WithStackTrace(err) - } - return nil -} - -func getVpcConfiguration(awsSession *session.Session) (ecs.AwsVpcConfiguration, error) { - ec2Svc := ec2.New(awsSession) - describeVpcsParams := &ec2.DescribeVpcsInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: awsgo.String("isDefault"), - Values: []*string{awsgo.String("true")}, - }, - }, - } - vpcs, err := ec2Svc.DescribeVpcs(describeVpcsParams) - if err != nil { - return ecs.AwsVpcConfiguration{}, gruntworkerrors.WithStackTrace(err) - } - if len(vpcs.Vpcs) == 0 { - err := errors.New(fmt.Sprintf("Could not find any default VPC in region %s", *awsSession.Config.Region)) - return ecs.AwsVpcConfiguration{}, gruntworkerrors.WithStackTrace(err) - } - defaultVpc := vpcs.Vpcs[0] - - describeSubnetsParams := &ec2.DescribeSubnetsInput{ - Filters: []*ec2.Filter{ - &ec2.Filter{ - Name: awsgo.String("vpc-id"), - Values: []*string{defaultVpc.VpcId}, - }, - }, - } - subnets, err := ec2Svc.DescribeSubnets(describeSubnetsParams) - if err != nil { - return ecs.AwsVpcConfiguration{}, gruntworkerrors.WithStackTrace(err) - } - if len(subnets.Subnets) == 0 { - err := errors.New(fmt.Sprintf("Could not find any subnets for default VPC in region %s", *awsSession.Config.Region)) - return ecs.AwsVpcConfiguration{}, gruntworkerrors.WithStackTrace(err) - } - var subnetIds []*string - for _, subnet := range subnets.Subnets { - // Only use public subnets for testing simplicity - if !collections.ListContainsElement(AvailabilityZoneBlackList, awsgo.StringValue(subnet.AvailabilityZone)) && awsgo.BoolValue(subnet.MapPublicIpOnLaunch) { - subnetIds = append(subnetIds, subnet.SubnetId) - } - } - vpcConfig := ecs.AwsVpcConfiguration{ - Subnets: subnetIds, - AssignPublicIp: awsgo.String(ecs.AssignPublicIpEnabled), - } - return vpcConfig, nil -} - -const ECS_ASSUME_ROLE_POLICY = `{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Principal": { - "Service": "ec2.amazonaws.com" - }, - "Action": "sts:AssumeRole" - } - ] -}` - -const ECS_ROLE_POLICY = `{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - "ecr:BatchCheckLayerAvailability", - "ecr:BatchGetImage", - "ecr:DescribeRepositories", - "ecr:GetAuthorizationToken", - "ecr:GetDownloadUrlForLayer", - "ecr:GetRepositoryPolicy", - "ecr:ListImages", - "ecs:CreateCluster", - "ecs:DeregisterContainerInstance", - "ecs:DiscoverPollEndpoint", - "ecs:Poll", - "ecs:RegisterContainerInstance", - "ecs:StartTask", - "ecs:StartTelemetrySession", - "ecs:SubmitContainerStateChange", - "ecs:SubmitTaskStateChange" - ], - "Resource": [ - "*" - ] - } - ] -}`