From 1c2211320d083228e8dde49400e2bcd9d7ae05c1 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+hongil0316@users.noreply.github.com> Date: Wed, 26 Jul 2023 09:28:37 -0400 Subject: [PATCH] Refactor DynamoDB (#515) --- aws/aws.go | 2 +- aws/dynamodb.go | 90 +++++----------- aws/dynamodb_test.go | 233 +++++++++++++----------------------------- aws/dynamodb_types.go | 12 +-- 4 files changed, 104 insertions(+), 233 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 66ff6346..67700662 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -1260,7 +1260,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } if IsNukeable(DynamoDB.ResourceName(), resourceTypes) { start := time.Now() - tablenames, err := getAllDynamoTables(cloudNukeSession, excludeAfter, configObj, DynamoDB) + tablenames, err := DynamoDB.getAll(configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/dynamodb.go b/aws/dynamodb.go index 48968384..8205746e 100644 --- a/aws/dynamodb.go +++ b/aws/dynamodb.go @@ -1,98 +1,60 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" - "log" - "time" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/telemetry" + commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" "github.com/gruntwork-io/gruntwork-cli/errors" + "log" ) -func getAllDynamoTables(session *session.Session, excludeAfter time.Time, configObj config.Config, db DynamoDB) ([]*string, error) { +func (ddb DynamoDB) getAll(configObj config.Config) ([]*string, error) { var tableNames []*string - svc := dynamodb.New(session) - var lastTableName *string - // Run count is used for pagination if the list tables exceeds max value - // Tells loop to rerun - PaginationRunCount := 1 - for PaginationRunCount > 0 { - result, err := svc.ListTables(&dynamodb.ListTablesInput{ExclusiveStartTableName: lastTableName, Limit: aws.Int64(int64(DynamoDB.MaxBatchSize(db)))}) - - lastTableName = result.LastEvaluatedTableName - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Error() { - case dynamodb.ErrCodeInternalServerError: - return nil, errors.WithStackTrace(aerr) - default: - return nil, errors.WithStackTrace(aerr) + err := ddb.Client.ListTablesPages( + &dynamodb.ListTablesInput{}, func(page *dynamodb.ListTablesOutput, lastPage bool) bool { + for _, table := range page.TableNames { + tableDetail, err := ddb.Client.DescribeTable(&dynamodb.DescribeTableInput{TableName: table}) + if err != nil { + log.Fatalf("There was an error describing table: %v\n", err) } - } - } - - tableLen := len(result.TableNames) - // Check table length if it matches the max value add 1 to rerun - if tableLen == DynamoDB.MaxBatchSize(db) { - // Tell the user that this will be run twice due to max tables detected - logging.Logger.Debugf("The tables detected exceed the 100. Running more than once") - // Adds one to the count as it will = 2 runs at least until this loops again to check if it's another max. - PaginationRunCount += 1 - } - for _, table := range result.TableNames { - responseDescription, err := svc.DescribeTable(&dynamodb.DescribeTableInput{TableName: table}) - if err != nil { - log.Fatalf("There was an error describing table: %v\n", err) - } - if shouldIncludeTable(responseDescription.Table, excludeAfter, configObj) { - tableNames = append(tableNames, table) + if configObj.DynamoDB.ShouldInclude(config.ResourceValue{ + Time: tableDetail.Table.CreationDateTime, + Name: tableDetail.Table.TableName, + }) { + tableNames = append(tableNames, table) + } } - } - // Remove 1 from the counter if it's one run the loop will end as PaginationRunCount will = 0 - PaginationRunCount -= 1 - } - return tableNames, nil -} -func shouldIncludeTable(table *dynamodb.TableDescription, excludeAfter time.Time, configObj config.Config) bool { - if table == nil { - return false - } + return !lastPage + }) - if table.CreationDateTime != nil && excludeAfter.Before(*table.CreationDateTime) { - return false + if err != nil { + return nil, err } - return config.ShouldInclude( - aws.StringValue(table.TableName), - configObj.DynamoDB.IncludeRule.NamesRegExp, - configObj.DynamoDB.ExcludeRule.NamesRegExp, - ) + return tableNames, nil } -func nukeAllDynamoDBTables(session *session.Session, tables []*string) error { - svc := dynamodb.New(session) +func (ddb DynamoDB) nukeAll(tables []*string) error { if len(tables) == 0 { - logging.Logger.Debugf("No DynamoDB tables to nuke in region %s", *session.Config.Region) + logging.Logger.Debugf("No DynamoDB tables to nuke in region %s", ddb.Region) return nil } - logging.Logger.Debugf("Deleting all DynamoDB tables in region %s", *session.Config.Region) + logging.Logger.Debugf("Deleting all DynamoDB tables in region %s", ddb.Region) for _, table := range tables { input := &dynamodb.DeleteTableInput{ TableName: aws.String(*table), } - _, err := svc.DeleteTable(input) + _, err := ddb.Client.DeleteTable(input) // Record status of this resource e := report.Entry{ @@ -107,7 +69,7 @@ func nukeAllDynamoDBTables(session *session.Session, tables []*string) error { telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking DynamoDB Table", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": ddb.Region, }) switch aerr.Error() { case dynamodb.ErrCodeInternalServerError: diff --git a/aws/dynamodb_test.go b/aws/dynamodb_test.go index 7edd7070..5d6f7d0d 100644 --- a/aws/dynamodb_test.go +++ b/aws/dynamodb_test.go @@ -1,208 +1,117 @@ package aws import ( + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/gruntwork-io/cloud-nuke/telemetry" - "log" "regexp" "testing" "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/gruntwork-io/cloud-nuke/config" - "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/gruntwork-cli/errors" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func createTestDynamoTables(t *testing.T, tableName, region string) { - awsSession, err := session.NewSession(&aws.Config{ - Region: aws.String(region)}, - ) - - svc := dynamodb.New(awsSession) - // THE INFORMATION TO CREATE THE TABLE - input := &dynamodb.CreateTableInput{ - TableName: &tableName, - AttributeDefinitions: []*dynamodb.AttributeDefinition{ - { - AttributeName: aws.String("Nuke" + string(rune(1))), - AttributeType: aws.String("S"), - }, - { - AttributeName: aws.String("TypeofNuke" + string(rune(1))), - AttributeType: aws.String("S"), - }, - }, - KeySchema: []*dynamodb.KeySchemaElement{ - { - AttributeName: aws.String("Nuke" + string(rune(1))), - KeyType: aws.String("HASH"), - }, - { - AttributeName: aws.String("TypeofNuke" + string(rune(1))), - KeyType: aws.String("RANGE"), - }, - }, - ProvisionedThroughput: &dynamodb.ProvisionedThroughput{ - ReadCapacityUnits: aws.Int64(1), - WriteCapacityUnits: aws.Int64(1), - }, - } - // CREATING THE TABLE FROM THE INPUT - _, err = svc.CreateTable(input) - require.NoError(t, err) - +type mockedDynamoDB struct { + dynamodbiface.DynamoDBAPI + DescribeTableOutputMap map[string]dynamodb.DescribeTableOutput + ListTablesOutput dynamodb.ListTablesOutput + DeleteTableOutput dynamodb.DeleteTableOutput } -func getTableStatus(TableName string, region string) *string { - awsSession, err := session.NewSession(&aws.Config{ - Region: aws.String(region)}, - ) - - svc := dynamodb.New(awsSession) - - tableInput := &dynamodb.DescribeTableInput{TableName: &TableName} - - result, err := svc.DescribeTable(tableInput) - if err != nil { - log.Fatalf("There was an error describing tables %v", err) - } +func (m mockedDynamoDB) ListTablesPages(input *dynamodb.ListTablesInput, fn func(*dynamodb.ListTablesOutput, bool) bool) error { + fn(&m.ListTablesOutput, true) + return nil +} - return result.Table.TableStatus +func (m mockedDynamoDB) DescribeTable(input *dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) { + output := m.DescribeTableOutputMap[*input.TableName] + return &output, nil +} +func (m mockedDynamoDB) DeleteTable(input *dynamodb.DeleteTableInput) (*dynamodb.DeleteTableOutput, error) { + return &m.DeleteTableOutput, nil } -func TestShouldIncludeTable(t *testing.T) { +func TestDynamoDB_GetAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") - mockTable := &dynamodb.TableDescription{ - TableName: aws.String("cloud-nuke-test"), - CreationDateTime: aws.Time(time.Now()), - } - - mockExpression, err := regexp.Compile("^cloud-nuke-*") - if err != nil { - log.Fatalf("There was an error compiling regex expression %v", err) - } + t.Parallel() - mockExcludeConfig := config.Config{ - DynamoDB: config.ResourceType{ - ExcludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *mockExpression, - }, + testName1 := "table1" + testName2 := "table2" + now := time.Now() + ddb := DynamoDB{ + Client: mockedDynamoDB{ + ListTablesOutput: dynamodb.ListTablesOutput{ + TableNames: []*string{ + aws.String(testName1), + aws.String(testName2), }, }, - }, - } - - mockIncludeConfig := config.Config{ - DynamoDB: config.ResourceType{ - IncludeRule: config.FilterRule{ - NamesRegExp: []config.Expression{ - { - RE: *mockExpression, + DescribeTableOutputMap: map[string]dynamodb.DescribeTableOutput{ + testName1: { + Table: &dynamodb.TableDescription{ + TableName: aws.String(testName1), + CreationDateTime: aws.Time(now), + }, + }, + testName2: { + Table: &dynamodb.TableDescription{ + TableName: aws.String(testName2), + CreationDateTime: aws.Time(now.Add(1)), }, }, }, }, } - cases := []struct { - Name string - Table *dynamodb.TableDescription - Config config.Config - ExcludeAfter time.Time - Expected bool + tests := map[string]struct { + configObj config.ResourceType + expected []string }{ - { - Name: "ConfigExclude", - Table: mockTable, - Config: mockExcludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Expected: false, + "emptyFilter": { + configObj: config.ResourceType{}, + expected: []string{testName1, testName2}, }, - { - Name: "ConfigInclude", - Table: mockTable, - Config: mockIncludeConfig, - ExcludeAfter: time.Now().Add(1 * time.Hour), - Expected: true, + "nameExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + NamesRegExp: []config.Expression{{ + RE: *regexp.MustCompile(testName1), + }}}, + }, + expected: []string{testName2}, }, - { - Name: "NotOlderThan", - Table: mockTable, - Config: config.Config{}, - ExcludeAfter: time.Now().Add(1 * time.Hour * -1), - Expected: false, + "timeAfterExclusionFilter": { + configObj: config.ResourceType{ + ExcludeRule: config.FilterRule{ + TimeAfter: aws.Time(now.Add(-1)), + }}, + expected: []string{}, }, } - - for _, c := range cases { - t.Run(c.Name, func(t *testing.T) { - result := shouldIncludeTable(c.Table, c.ExcludeAfter, c.Config) - assert.Equal(t, c.Expected, result) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + names, err := ddb.getAll(config.Config{ + DynamoDB: tc.configObj, + }) + require.NoError(t, err) + require.Equal(t, tc.expected, aws.StringValueSlice(names)) }) } } -func TestGetTablesDynamo(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - db := DynamoDB{} - awsSession, err := session.NewSession(&aws.Config{ - Region: aws.String(region)}, - ) - require.NoError(t, err) - - _, err = getAllDynamoTables(awsSession, time.Now().Add(1*time.Hour*-1), config.Config{}, db) - require.NoError(t, err) -} - -func TestNukeAllDynamoDBTables(t *testing.T) { +func TestDynamoDb_NukeAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - db := DynamoDB{} - region, err := getRandomRegion() - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - awsSession, err := session.NewSession(&aws.Config{ - Region: aws.String(region)}, - ) - require.NoError(t, err) - - tableName := "cloud-nuke-test-" + util.UniqueID() - defer nukeAllDynamoDBTables(awsSession, []*string{&tableName}) - createTestDynamoTables(t, tableName, region) - COUNTER := 0 - for COUNTER <= 1 { - tableStatus := getTableStatus(tableName, region) - if *tableStatus == "ACTIVE" { - COUNTER += 1 - log.Printf("Created a table: %v\n", tableName) - } else { - log.Printf("Table not ready yet: %v", tableName) - } + ddb := DynamoDB{ + Client: mockedDynamoDB{ + DeleteTableOutput: dynamodb.DeleteTableOutput{}, + }, } - nukeErr := nukeAllDynamoDBTables(awsSession, []*string{&tableName}) - require.NoError(t, nukeErr) - - time.Sleep(5 * time.Second) - tables, err := getAllDynamoTables(awsSession, time.Now().Add(1*time.Hour*-1), config.Config{}, db) + err := ddb.nukeAll([]*string{aws.String("table1"), aws.String("table2")}) require.NoError(t, err) - - for _, table := range tables { - if tableName == *table { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - } } diff --git a/aws/dynamodb_types.go b/aws/dynamodb_types.go index 7cae1d17..cdf2ddb0 100644 --- a/aws/dynamodb_types.go +++ b/aws/dynamodb_types.go @@ -13,22 +13,22 @@ type DynamoDB struct { DynamoTableNames []string } -func (tables DynamoDB) ResourceName() string { +func (ddb DynamoDB) ResourceName() string { return "dynamodb" } -func (tables DynamoDB) ResourceIdentifiers() []string { - return tables.DynamoTableNames +func (ddb DynamoDB) ResourceIdentifiers() []string { + return ddb.DynamoTableNames } -func (tables DynamoDB) MaxBatchSize() int { +func (ddb DynamoDB) MaxBatchSize() int { // Tentative batch size to ensure AWS doesn't throttle return 49 } // Nuke - nuke all Dynamo DB Tables -func (tables DynamoDB) Nuke(awsSession *session.Session, identifiers []string) error { - if err := nukeAllDynamoDBTables(awsSession, awsgo.StringSlice(identifiers)); err != nil { +func (ddb DynamoDB) Nuke(awsSession *session.Session, identifiers []string) error { + if err := ddb.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) } return nil