Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor EC2 keypair #551

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
KeyPairs := EC2KeyPairs{}
if IsNukeable(KeyPairs.ResourceName(), resourceTypes) {
start := time.Now()
keyPairIds, err := getAllEc2KeyPairs(cloudNukeSession, excludeAfter, configObj)
keyPairIds, err := KeyPairs.getAll(configObj)
if err != nil {
return nil, errors.WithStackTrace(err)
}
Expand Down
50 changes: 15 additions & 35 deletions aws/ec2_key_pair.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,42 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"time"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"github.com/gruntwork-io/gruntwork-cli/errors"
"github.com/hashicorp/go-multierror"
)

// getAllEc2KeyPairs extracts the list of existing ec2 key pairs.
func getAllEc2KeyPairs(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) {
svc := ec2.New(session)

result, err := svc.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{})
func (k EC2KeyPairs) getAll(configObj config.Config) ([]*string, error) {
result, err := k.Client.DescribeKeyPairs(&ec2.DescribeKeyPairsInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, keyPair := range result.KeyPairs {
if shouldIncludeEc2KeyPair(keyPair, excludeAfter, configObj) {
if configObj.EC2KeyPairs.ShouldInclude(config.ResourceValue{
Name: keyPair.KeyName,
Time: keyPair.CreateTime,
}) {
ids = append(ids, keyPair.KeyPairId)
}
}

return ids, nil
}

func shouldIncludeEc2KeyPair(keyPairInfo *ec2.KeyPairInfo, excludeAfter time.Time, configObj config.Config) bool {
if keyPairInfo == nil || keyPairInfo.KeyName == nil {
return false
}

if keyPairInfo.CreateTime != nil && excludeAfter.Before(*keyPairInfo.CreateTime) {
return false
}

return config.ShouldInclude(
*keyPairInfo.KeyName,
configObj.EC2KeyPairs.IncludeRule.NamesRegExp,
configObj.EC2KeyPairs.ExcludeRule.NamesRegExp,
)
}

// deleteKeyPair is a helper method that deletes the given ec2 key pair.
func deleteKeyPair(svc *ec2.EC2, keyPairId *string) error {
func (k EC2KeyPairs) deleteKeyPair(keyPairId *string) error {
params := &ec2.DeleteKeyPairInput{
KeyPairId: keyPairId,
}

_, err := svc.DeleteKeyPair(params)
_, err := k.Client.DeleteKeyPair(params)
if err != nil {
return errors.WithStackTrace(err)
}
Expand All @@ -63,24 +45,22 @@ func deleteKeyPair(svc *ec2.EC2, keyPairId *string) error {
}

// nukeAllEc2KeyPairs attempts to delete given ec2 key pair IDs.
func nukeAllEc2KeyPairs(session *session.Session, keypairIds []*string) error {
svc := ec2.New(session)

func (k EC2KeyPairs) nukeAll(keypairIds []*string) error {
if len(keypairIds) == 0 {
logging.Logger.Infof("No EC2 key pairs to nuke in region %s", *session.Config.Region)
logging.Logger.Infof("No EC2 key pairs to nuke in region %s", k.Region)
return nil
}

logging.Logger.Infof("Terminating all EC2 key pairs in region %s", *session.Config.Region)
logging.Logger.Infof("Terminating all EC2 key pairs in region %s", k.Region)

deletedKeyPairs := 0
var multiErr *multierror.Error
for _, keypair := range keypairIds {
if err := deleteKeyPair(svc, keypair); err != nil {
if err := k.deleteKeyPair(keypair); err != nil {
telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking EC2 Key Pair",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": k.Region,
})
logging.Logger.Errorf("[Failed] %s", err)
multiErr = multierror.Append(multiErr, err)
Expand Down
145 changes: 77 additions & 68 deletions aws/ec2_key_pair_test.go
Original file line number Diff line number Diff line change
@@ -1,97 +1,106 @@
package aws

import (
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/gruntwork-io/cloud-nuke/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"regexp"
"testing"
"time"
)

// createTestEc2KeyPair is a helper method to create a test ec2 key pair
func createTestEc2KeyPair(t *testing.T, svc *ec2.EC2) *ec2.CreateKeyPairOutput {
keyPair, err := svc.CreateKeyPair(&ec2.CreateKeyPairInput{
KeyName: awsgo.String(util.UniqueID()),
})

require.NoError(t, err)
type mockedEC2KeyPairs struct {
ec2iface.EC2API
DescribeKeyPairsOutput ec2.DescribeKeyPairsOutput
DeleteKeyPairOutput ec2.DeleteKeyPairOutput
}

err = svc.WaitUntilKeyPairExists(&ec2.DescribeKeyPairsInput{
KeyPairIds: awsgo.StringSlice([]string{*keyPair.KeyPairId}),
})
func (m mockedEC2KeyPairs) DescribeKeyPairs(input *ec2.DescribeKeyPairsInput) (*ec2.DescribeKeyPairsOutput, error) {
return &m.DescribeKeyPairsOutput, nil
}

require.NoError(t, err)
return keyPair
func (m mockedEC2KeyPairs) DeleteKeyPair(input *ec2.DeleteKeyPairInput) (*ec2.DeleteKeyPairOutput, error) {
return &m.DeleteKeyPairOutput, nil
}

func TestEc2KeyPairListAndNuke(t *testing.T) {
func TestEC2KeyPairs_GetAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)

testSession, err := session.NewSession(&awsgo.Config{
Region: awsgo.String(region)},
)

require.NoError(t, err)

svc := ec2.New(testSession)
createdKeyPair := createTestEc2KeyPair(t, svc)
testExcludeAfterTime := time.Now().Add(24 * time.Hour)
keyPairIds, err := getAllEc2KeyPairs(testSession, testExcludeAfterTime, config.Config{})

assert.Contains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId)

// Note: nuking the ec2 key pair created for testing purpose
err = nukeAllEc2KeyPairs(testSession, []*string{createdKeyPair.KeyPairId})
require.NoError(t, err)
now := time.Now()
testId1 := "test-keypair-id1"
testName1 := "test-keypair1"
testId2 := "test-keypair-id2"
testName2 := "test-keypair2"
k := EC2KeyPairs{
Client: mockedEC2KeyPairs{
DescribeKeyPairsOutput: ec2.DescribeKeyPairsOutput{
KeyPairs: []*ec2.KeyPairInfo{
{
KeyName: awsgo.String(testName1),
KeyPairId: awsgo.String(testId1),
CreateTime: awsgo.Time(now),
},
{
KeyName: awsgo.String(testName2),
KeyPairId: awsgo.String(testId2),
CreateTime: awsgo.Time(now.Add(1)),
},
},
},
},
}

// Check whether the key still exist or not.
keyPairIds, err = getAllEc2KeyPairs(testSession, testExcludeAfterTime, config.Config{})
require.NoError(t, err)
require.NotContains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId)
tests := map[string]struct {
configObj config.ResourceType
expected []string
}{
"emptyFilter": {
configObj: config.ResourceType{},
expected: []string{testId1, testId2},
},
"nameExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{{
RE: *regexp.MustCompile(testName1),
}}},
},
expected: []string{testId2},
},
"timeAfterExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now),
}},
expected: []string{testId1},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := k.getAll(config.Config{
EC2KeyPairs: tc.configObj,
})
require.NoError(t, err)
require.Equal(t, tc.expected, awsgo.StringValueSlice(names))
})
}
}

func TestEc2KeyPairListWithConfig(t *testing.T) {
func TestEC2KeyPairs_NukeAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
region, err := getRandomRegion()
require.NoError(t, err)

testSession, err := session.NewSession(&awsgo.Config{
Region: awsgo.String(region)},
)

require.NoError(t, err)

svc := ec2.New(testSession)
createdKeyPair := createTestEc2KeyPair(t, svc)
createdKeyPair2 := createTestEc2KeyPair(t, svc)
t.Parallel()

// Regex expression to not include first key pair
nameRegexExp, err := regexp.Compile(fmt.Sprintf("^%s*", *createdKeyPair.KeyName))
excludeConfig := config.Config{
EC2KeyPairs: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{
{
RE: *nameRegexExp,
},
},
},
h := EC2KeyPairs{
Client: mockedEC2KeyPairs{
DeleteKeyPairOutput: ec2.DeleteKeyPairOutput{},
},
}

testExcludeAfterTime := time.Now().Add(24 * time.Hour)
keyPairIds, err := getAllEc2KeyPairs(testSession, testExcludeAfterTime, excludeConfig)
assert.NotContains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair.KeyPairId)
assert.Contains(t, awsgo.StringValueSlice(keyPairIds), *createdKeyPair2.KeyPairId)
err := h.nukeAll([]*string{awsgo.String("test-keypair-id-1"), awsgo.String("test-keypair-id-2")})
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion aws/ec2_key_pair_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (k EC2KeyPairs) MaxBatchSize() int {
}

func (k EC2KeyPairs) Nuke(session *session.Session, identifiers []string) error {
if err := nukeAllEc2KeyPairs(session, awsgo.StringSlice(identifiers)); err != nil {
if err := k.nukeAll(awsgo.StringSlice(identifiers)); err != nil {
return errors.WithStackTrace(err)
}

Expand Down