diff --git a/aws/aws.go b/aws/aws.go index 44b0944b..e6f1d61b 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -1584,7 +1584,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(redshiftClusters.ResourceName(), resourceTypes) { start := time.Now() - clusters, err := getAllRedshiftClusters(cloudNukeSession, region, excludeAfter, configObj) + clusters, err := redshiftClusters.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/redshift.go b/aws/redshift.go index 38a48de6..b09e5659 100644 --- a/aws/redshift.go +++ b/aws/redshift.go @@ -2,7 +2,6 @@ package aws import ( "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/redshift" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" @@ -10,55 +9,46 @@ import ( "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/gruntwork-io/go-commons/errors" commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" - "time" ) -func getAllRedshiftClusters(session *session.Session, region string, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - svc := redshift.New(session) +func (rc RedshiftClusters) getAll(configObj config.Config) ([]*string, error) { var clusterIds []*string - err := svc.DescribeClustersPages( + err := rc.Client.DescribeClustersPages( &redshift.DescribeClustersInput{}, func(page *redshift.DescribeClustersOutput, lastPage bool) bool { for _, cluster := range page.Clusters { - if shouldIncludeRedshiftCluster(cluster, excludeAfter, configObj) { + if configObj.Redshift.ShouldInclude(config.ResourceValue{ + Time: cluster.ClusterCreateTime, + Name: cluster.ClusterIdentifier, + }) { clusterIds = append(clusterIds, cluster.ClusterIdentifier) } } + return !lastPage }, ) - return clusterIds, errors.WithStackTrace(err) -} -func shouldIncludeRedshiftCluster(cluster *redshift.Cluster, excludeAfter time.Time, configObj config.Config) bool { - if cluster == nil { - return false - } - if excludeAfter.Before(*cluster.ClusterCreateTime) { - return false - } - return config.ShouldInclude( - aws.StringValue(cluster.ClusterIdentifier), - configObj.Redshift.IncludeRule.NamesRegExp, - configObj.Redshift.ExcludeRule.NamesRegExp, - ) + return clusterIds, errors.WithStackTrace(err) } -func nukeAllRedshiftClusters(session *session.Session, identifiers []*string) error { - svc := redshift.New(session) +func (rc RedshiftClusters) nukeAll(identifiers []*string) error { if len(identifiers) == 0 { - logging.Logger.Debugf("No Redshift Clusters to nuke in region %s", *session.Config.Region) + logging.Logger.Debugf("No Redshift Clusters to nuke in region %s", rc.Region) return nil } - logging.Logger.Debugf("Deleting all Redshift Clusters in region %s", *session.Config.Region) + logging.Logger.Debugf("Deleting all Redshift Clusters in region %s", rc.Region) deletedIds := []*string{} for _, id := range identifiers { - _, err := svc.DeleteCluster(&redshift.DeleteClusterInput{ClusterIdentifier: id, SkipFinalClusterSnapshot: aws.Bool(true)}) + _, err := rc.Client.DeleteCluster(&redshift.DeleteClusterInput{ + ClusterIdentifier: id, + SkipFinalClusterSnapshot: aws.Bool(true), + }) if err != nil { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking RedshiftCluster", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": rc.Region, }) logging.Logger.Errorf("[Failed] %s: %s", *id, err) } else { @@ -69,7 +59,7 @@ func nukeAllRedshiftClusters(session *session.Session, identifiers []*string) er if len(deletedIds) > 0 { for _, id := range deletedIds { - err := svc.WaitUntilClusterDeleted(&redshift.DescribeClustersInput{ClusterIdentifier: id}) + err := rc.Client.WaitUntilClusterDeleted(&redshift.DescribeClustersInput{ClusterIdentifier: id}) // Record status of this resource e := report.Entry{ Identifier: aws.StringValue(id), @@ -81,13 +71,13 @@ func nukeAllRedshiftClusters(session *session.Session, identifiers []*string) er telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking Redshift Cluster", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": rc.Region, }) logging.Logger.Errorf("[Failed] %s", err) return errors.WithStackTrace(err) } } } - logging.Logger.Debugf("[OK] %d Redshift Cluster(s) deleted in %s", len(deletedIds), *session.Config.Region) + logging.Logger.Debugf("[OK] %d Redshift Cluster(s) deleted in %s", len(deletedIds), rc.Region) return nil } diff --git a/aws/redshift_test.go b/aws/redshift_test.go index 44ffe22d..33998665 100644 --- a/aws/redshift_test.go +++ b/aws/redshift_test.go @@ -2,67 +2,106 @@ package aws import ( "github.com/aws/aws-sdk-go/aws" - awsSession "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/redshift" + "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/telemetry" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "strings" + "regexp" "testing" "time" ) -func TestNukeRedshiftClusters(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) +type mockedRedshift struct { + redshiftiface.RedshiftAPI - session, err := awsSession.NewSession(&aws.Config{ - Region: aws.String(region), - }) - require.NoError(t, err) + DeleteClusterOutput redshift.DeleteClusterOutput + DescribeClustersOutput redshift.DescribeClustersOutput +} + +func (m mockedRedshift) DescribeClustersPages(input *redshift.DescribeClustersInput, fn func(*redshift.DescribeClustersOutput, bool) bool) error { + fn(&m.DescribeClustersOutput, true) + return nil +} - svc := redshift.New(session) +func (m mockedRedshift) DeleteCluster(input *redshift.DeleteClusterInput) (*redshift.DeleteClusterOutput, error) { + return &m.DeleteClusterOutput, nil +} - clusterName := "test-" + strings.ToLower(util.UniqueID()) +func (m mockedRedshift) WaitUntilClusterDeleted(*redshift.DescribeClustersInput) error { + return nil +} - //create cluster - _, err = svc.CreateCluster( - &redshift.CreateClusterInput{ - ClusterIdentifier: aws.String(clusterName), - MasterUsername: aws.String("grunty"), - MasterUserPassword: aws.String("Gruntysecurepassword1"), - NodeType: aws.String("dc2.large"), - NumberOfNodes: aws.Int64(2), - }, - ) - require.NoError(t, err) - err = svc.WaitUntilClusterAvailable(&redshift.DescribeClustersInput{ - ClusterIdentifier: aws.String(clusterName), - }) - require.NoError(t, err) - defer svc.DeleteCluster(&redshift.DeleteClusterInput{ClusterIdentifier: aws.String(clusterName)}) +func TestRedshiftCluster_GetAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() - //Sleep for a minute for consistency in aws - sleepTime, err := time.ParseDuration("1m") - time.Sleep(sleepTime) + now := time.Now() + testName1 := "test-cluster1" + testName2 := "test-cluster2" + rc := RedshiftClusters{ + Client: mockedRedshift{ + DescribeClustersOutput: redshift.DescribeClustersOutput{ + Clusters: []*redshift.Cluster{ + { + ClusterIdentifier: aws.String(testName1), + ClusterCreateTime: aws.Time(now), + }, + { + ClusterIdentifier: aws.String(testName2), + ClusterCreateTime: aws.Time(now.Add(1)), + }, + }, + }, + }, + } - //test list clusters - clusters, err := getAllRedshiftClusters(session, region, time.Now().Add(1*time.Hour), config.Config{}) - require.NoError(t, err) + tests := map[string]struct { + configObj config.ResourceType + expected []string + }{ + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testName1, testName2}, + }, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testName2}, + }, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-1 * time.Hour)), + }}, + expected: []string{}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := rc.getAll(config.Config{ + Redshift: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) + }) + } +} - //Ensure our cluster exists - assert.Contains(t, aws.StringValueSlice(clusters), clusterName) +func TestRedshiftCluster_NukeAll(t *testing.T) { + telemetry.InitTelemetry("cloud-nuke", "") + t.Parallel() - //nuke cluster - err = nukeAllRedshiftClusters(session, aws.StringSlice([]string{clusterName})) - require.NoError(t, err) + rc := RedshiftClusters{ + Client: mockedRedshift{ + DeleteClusterOutput: redshift.DeleteClusterOutput{}, + }, + } - //check that the cluster no longer exists - clusters, err = getAllRedshiftClusters(session, region, time.Now().Add(1*time.Hour), config.Config{}) + err := rc.nukeAll([]*string{aws.String("test")}) require.NoError(t, err) - assert.NotContains(t, aws.StringValueSlice(clusters), aws.StringSlice([]string{clusterName})) } diff --git a/aws/redshift_types.go b/aws/redshift_types.go index d247244f..91651c36 100644 --- a/aws/redshift_types.go +++ b/aws/redshift_types.go @@ -13,23 +13,23 @@ type RedshiftClusters struct { ClusterIdentifiers []string } -func (cluster RedshiftClusters) ResourceName() string { +func (rc RedshiftClusters) ResourceName() string { return "redshift" } // ResourceIdentifiers - The instance names of the rds db instances -func (cluster RedshiftClusters) ResourceIdentifiers() []string { - return cluster.ClusterIdentifiers +func (rc RedshiftClusters) ResourceIdentifiers() []string { + return rc.ClusterIdentifiers } -func (cluster RedshiftClusters) MaxBatchSize() int { +func (rc RedshiftClusters) MaxBatchSize() int { // Tentative batch size to ensure AWS doesn't throttle return 49 } // Nuke - nuke 'em all!!! -func (cluster RedshiftClusters) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllRedshiftClusters(session, awsgo.StringSlice(identifiers)); err != nil { +func (rc RedshiftClusters) Nuke(session *session.Session, identifiers []string) error { + if err := rc.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) }