diff --git a/aws/aws.go b/aws/aws.go index 8347bea1..105ddb0c 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -1314,7 +1314,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp KeyPairs := EC2KeyPairs{} if IsNukeable(KeyPairs.ResourceName(), resourceTypes) { start := time.Now() - keyPairIds, err := getAllEc2KeyPairs(cloudNukeSession, excludeAfter, configObj) + keyPairIds, err := KeyPairs.getAll(configObj) if err != nil { return nil, errors.WithStackTrace(err) } diff --git a/aws/ec2_key_pair.go b/aws/ec2_key_pair.go index de2bfd20..748a272d 100644 --- a/aws/ec2_key_pair.go +++ b/aws/ec2_key_pair.go @@ -1,30 +1,28 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" - "time" - - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/cloud-nuke/telemetry" + commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" "github.com/gruntwork-io/gruntwork-cli/errors" "github.com/hashicorp/go-multierror" ) // getAllEc2KeyPairs extracts the list of existing ec2 key pairs. -func getAllEc2KeyPairs(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - svc := ec2.New(session) - - result, err := svc.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{}) +func (k EC2KeyPairs) getAll(configObj config.Config) ([]*string, error) { + result, err := k.Client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{}) if err != nil { return nil, errors.WithStackTrace(err) } var ids []*string for _, keyPair := range result.KeyPairs { - if shouldIncludeEc2KeyPair(keyPair, excludeAfter, configObj) { + if configObj.EC2KeyPairs.ShouldInclude(config.ResourceValue{ + Name: keyPair.KeyName, + Time: keyPair.CreateTime, + }) { ids = append(ids, keyPair.KeyPairId) } } @@ -32,29 +30,13 @@ func getAllEc2KeyPairs(session *session.Session, excludeAfter time.Time, configO return ids, nil } -func shouldIncludeEc2KeyPair(keyPairInfo *ec2.KeyPairInfo, excludeAfter time.Time, configObj config.Config) bool { - if keyPairInfo == nil || keyPairInfo.KeyName == nil { - return false - } - - if keyPairInfo.CreateTime != nil && excludeAfter.Before(*keyPairInfo.CreateTime) { - return false - } - - return config.ShouldInclude( - *keyPairInfo.KeyName, - configObj.EC2KeyPairs.IncludeRule.NamesRegExp, - configObj.EC2KeyPairs.ExcludeRule.NamesRegExp, - ) -} - // deleteKeyPair is a helper method that deletes the given ec2 key pair. -func deleteKeyPair(svc *ec2.EC2, keyPairId *string) error { +func (k EC2KeyPairs) deleteKeyPair(keyPairId *string) error { params := &ec2.DeleteKeyPairInput{ KeyPairId: keyPairId, } - _, err := svc.DeleteKeyPair(params) + _, err := k.Client.DeleteKeyPair(params) if err != nil { return errors.WithStackTrace(err) } @@ -63,24 +45,22 @@ func deleteKeyPair(svc *ec2.EC2, keyPairId *string) error { } // nukeAllEc2KeyPairs attempts to delete given ec2 key pair IDs. -func nukeAllEc2KeyPairs(session *session.Session, keypairIds []*string) error { - svc := ec2.New(session) - +func (k EC2KeyPairs) nukeAll(keypairIds []*string) error { if len(keypairIds) == 0 { - logging.Logger.Infof("No EC2 key pairs to nuke in region %s", *session.Config.Region) + logging.Logger.Infof("No EC2 key pairs to nuke in region %s", k.Region) return nil } - logging.Logger.Infof("Terminating all EC2 key pairs in region %s", *session.Config.Region) + logging.Logger.Infof("Terminating all EC2 key pairs in region %s", k.Region) deletedKeyPairs := 0 var multiErr *multierror.Error for _, keypair := range keypairIds { - if err := deleteKeyPair(svc, keypair); err != nil { + if err := k.deleteKeyPair(keypair); err != nil { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking EC2 Key Pair", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": k.Region, }) logging.Logger.Errorf("[Failed] %s", err) multiErr = multierror.Append(multiErr, err) diff --git a/aws/ec2_key_pair_test.go b/aws/ec2_key_pair_test.go index b9247908..40752861 100644 --- a/aws/ec2_key_pair_test.go +++ b/aws/ec2_key_pair_test.go @@ -1,97 +1,106 @@ package aws import ( - "fmt" + "github.com/aws/aws-sdk-go-v2/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/telemetry" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "regexp" "testing" "time" ) -// createTestEc2KeyPair is a helper method to create a test ec2 key pair -func createTestEc2KeyPair(t *testing.T, svc *ec2.EC2) *ec2.CreateKeyPairOutput { - keyPair, err := svc.CreateKeyPair(&ec2.CreateKeyPairInput{ - KeyName: awsgo.String(util.UniqueID()), - }) - - require.NoError(t, err) +type mockedEC2KeyPairs struct { + ec2iface.EC2API + DescribeKeyPairsOutput ec2.DescribeKeyPairsOutput + DeleteKeyPairOutput ec2.DeleteKeyPairOutput +} - err = svc.WaitUntilKeyPairExists(&ec2.DescribeKeyPairsInput{ - KeyPairIds: awsgo.StringSlice([]string{*keyPair.KeyPairId}), - }) +func (m mockedEC2KeyPairs) DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error) { + return &m.DescribeKeyPairsOutput, nil +} - require.NoError(t, err) - return keyPair +func (m mockedEC2KeyPairs) DeleteKeyPair(input *ec2.DeleteKeyPairInput) (*ec2.DeleteKeyPairOutput, error) { + return &m.DeleteKeyPairOutput, nil } -func TestEc2KeyPairListAndNuke(t *testing.T) { +func TestEC2KeyPairs_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - testSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - - require.NoError(t, err) - - svc := ec2.New(testSession) - createdKeyPair := createTestEc2KeyPair(t, svc) - testExcludeAfterTime := time.Now().Add(24 * time.Hour) - keyPairIds, err := getAllEc2KeyPairs(testSession, testExcludeAfterTime, config.Config{}) - - assert.Contains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId) - - // Note: nuking the ec2 key pair created for testing purpose - err = nukeAllEc2KeyPairs(testSession, []*string{createdKeyPair.KeyPairId}) - require.NoError(t, err) + now := time.Now() + testId1 := "test-keypair-id1" + testName1 := "test-keypair1" + testId2 := "test-keypair-id2" + testName2 := "test-keypair2" + k := EC2KeyPairs{ + Client: mockedEC2KeyPairs{ + DescribeKeyPairsOutput: ec2.DescribeKeyPairsOutput{ + KeyPairs: []*ec2.KeyPairInfo{ + { + KeyName: awsgo.String(testName1), + KeyPairId: awsgo.String(testId1), + CreateTime: awsgo.Time(now), + }, + { + KeyName: awsgo.String(testName2), + KeyPairId: awsgo.String(testId2), + CreateTime: awsgo.Time(now.Add(1)), + }, + }, + }, + }, + } - // Check whether the key still exist or not. - keyPairIds, err = getAllEc2KeyPairs(testSession, testExcludeAfterTime, config.Config{}) - require.NoError(t, err) - require.NotContains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId) + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testId1, testId2}, + }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testId2}, + }, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now), + }}, + expected: []string{testId1}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := k.getAll(config.Config{ + EC2KeyPairs: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, awsgo.StringValueSlice(names)) + }) + } } -func TestEc2KeyPairListWithConfig(t *testing.T) { +func TestEC2KeyPairs_NukeAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") - region, err := getRandomRegion() - require.NoError(t, err) - - testSession, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region)}, - ) - - require.NoError(t, err) - - svc := ec2.New(testSession) - createdKeyPair := createTestEc2KeyPair(t, svc) - createdKeyPair2 := createTestEc2KeyPair(t, svc) + t.Parallel() - // Regex expression to not include first key pair - nameRegexExp, err := regexp.Compile(fmt.Sprintf("^%s*", *createdKeyPair.KeyName)) - excludeConfig := config.Config{ - EC2KeyPairs: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *nameRegexExp, - }, - }, - }, + h := EC2KeyPairs{ + Client: mockedEC2KeyPairs{ + DeleteKeyPairOutput: ec2.DeleteKeyPairOutput{}, }, } - testExcludeAfterTime := time.Now().Add(24 * time.Hour) - keyPairIds, err := getAllEc2KeyPairs(testSession, testExcludeAfterTime, excludeConfig) - assert.NotContains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId) - assert.Contains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair2.KeyPairId) + err := h.nukeAll([]*string{awsgo.String("test-keypair-id-1"), awsgo.String("test-keypair-id-2")}) + require.NoError(t, err) } diff --git a/aws/ec2_key_pair_types.go b/aws/ec2_key_pair_types.go index 8ad22f46..470da362 100644 --- a/aws/ec2_key_pair_types.go +++ b/aws/ec2_key_pair_types.go @@ -29,7 +29,7 @@ func (k EC2KeyPairs) MaxBatchSize() int { } func (k EC2KeyPairs) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllEc2KeyPairs(session, awsgo.StringSlice(identifiers)); err != nil { + if err := k.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) }