Skip to content

Commit

Permalink
Refactor Transit Gateway (#750)
Browse files Browse the repository at this point in the history
* fix:code refactoring of transit_gateway file

* fix:code refactoring of transit_gateway file

---------

Co-authored-by: James Kwon <[email protected]>
  • Loading branch information
james03160927 and james03160927 authored Jul 23, 2024
1 parent 31de63a commit 95ea781
Show file tree
Hide file tree
Showing 12 changed files with 683 additions and 566 deletions.
53 changes: 53 additions & 0 deletions aws/resources/tgw_peering_attachment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package resources

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/go-commons/errors"
)

func (tgpa *TransitGatewayPeeringAttachment) getAll(c context.Context, configObj config.Config) ([]*string, error) {
var ids []*string
err := tgpa.Client.DescribeTransitGatewayPeeringAttachmentsPagesWithContext(tgpa.Context, &ec2.DescribeTransitGatewayPeeringAttachmentsInput{}, func(result *ec2.DescribeTransitGatewayPeeringAttachmentsOutput, lastPage bool) bool {
for _, attachment := range result.TransitGatewayPeeringAttachments {
if configObj.TransitGatewayPeeringAttachment.ShouldInclude(config.ResourceValue{
Time: attachment.CreationTime,
}) {
ids = append(ids, attachment.TransitGatewayAttachmentId)
}
}

return !lastPage
})
if err != nil {
return nil, errors.WithStackTrace(err)
}

return ids, nil
}

func (tgpa *TransitGatewayPeeringAttachment) nukeAll(ids []*string) error {
for _, id := range ids {
_, err := tgpa.Client.DeleteTransitGatewayPeeringAttachmentWithContext(tgpa.Context, &ec2.DeleteTransitGatewayPeeringAttachmentInput{
TransitGatewayAttachmentId: id,
})
// Record status of this resource
report.Record(report.Entry{
Identifier: aws.StringValue(id),
ResourceType: tgpa.ResourceName(),
Error: err,
})
if err != nil {
logging.Errorf("[Failed] %s", err)
} else {
logging.Debugf("Deleted Transit Gateway Peering Attachment: %s", *id)
}
}

return nil
}
95 changes: 95 additions & 0 deletions aws/resources/tgw_peering_attachment_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package resources

import (
"context"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/stretchr/testify/require"
)

type mockedTransitGatewayPeeringAttachment struct {
ec2iface.EC2API
DescribeTransitGatewayPeeringAttachmentsOutput ec2.DescribeTransitGatewayPeeringAttachmentsOutput
DeleteTransitGatewayPeeringAttachmentOutput ec2.DeleteTransitGatewayPeeringAttachmentOutput
}

func (m mockedTransitGatewayPeeringAttachment) DescribeTransitGatewayPeeringAttachmentsPagesWithContext(_ awsgo.Context, _ *ec2.DescribeTransitGatewayPeeringAttachmentsInput, fn func(*ec2.DescribeTransitGatewayPeeringAttachmentsOutput, bool) bool, _ ...request.Option) error {
fn(&m.DescribeTransitGatewayPeeringAttachmentsOutput, true)
return nil
}

func (m mockedTransitGatewayPeeringAttachment) DeleteTransitGatewayPeeringAttachmentWithContext(_ awsgo.Context, _ *ec2.DeleteTransitGatewayPeeringAttachmentInput, _ ...request.Option) (*ec2.DeleteTransitGatewayPeeringAttachmentOutput, error) {
return &m.DeleteTransitGatewayPeeringAttachmentOutput, nil
}

func TestTransitGatewayPeeringAttachment_getAll(t *testing.T) {

t.Parallel()

now := time.Now()
attachment1 := "attachement1"
attachment2 := "attachement2"
tgpa := TransitGatewayPeeringAttachment{
Client: mockedTransitGatewayPeeringAttachment{
DescribeTransitGatewayPeeringAttachmentsOutput: ec2.DescribeTransitGatewayPeeringAttachmentsOutput{
TransitGatewayPeeringAttachments: []*ec2.TransitGatewayPeeringAttachment{
{
TransitGatewayAttachmentId: aws.String(attachment1),
CreationTime: aws.Time(now),
},
{
TransitGatewayAttachmentId: aws.String(attachment2),
CreationTime: aws.Time(now.Add(1)),
},
},
},
},
}

tests := map[string]struct {
configObj config.ResourceType
expected []string
}{
"emptyFilter": {
configObj: config.ResourceType{},
expected: []string{attachment1, attachment2},
},
"timeAfterExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now),
}},
expected: []string{attachment1},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := tgpa.getAll(context.Background(), config.Config{
TransitGatewayPeeringAttachment: tc.configObj,
})
require.NoError(t, err)
require.Equal(t, tc.expected, aws.StringValueSlice(names))
})
}
}

func TestTransitGatewayPeeringAttachment_nukeAll(t *testing.T) {

t.Parallel()

tgw := TransitGatewayPeeringAttachment{
Client: mockedTransitGatewayPeeringAttachment{
DeleteTransitGatewayPeeringAttachmentOutput: ec2.DeleteTransitGatewayPeeringAttachmentOutput{},
},
}

err := tgw.nukeAll([]*string{aws.String("test-attachment")})
require.NoError(t, err)
}
58 changes: 58 additions & 0 deletions aws/resources/tgw_peering_attachment_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package resources

import (
"context"

awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/go-commons/errors"
)

// TransitGatewayPeeringAttachment - represents all transit gateways peering attachment
type TransitGatewayPeeringAttachment struct {
BaseAwsResource
Client ec2iface.EC2API
Region string
Ids []string
}

func (tgpa *TransitGatewayPeeringAttachment) Init(session *session.Session) {
tgpa.Client = ec2.New(session)
}

func (tgpa *TransitGatewayPeeringAttachment) ResourceName() string {
return "transit-gateway-peering-attachment"
}

func (tgpa *TransitGatewayPeeringAttachment) MaxBatchSize() int {
return maxBatchSize
}

func (tgpa *TransitGatewayPeeringAttachment) ResourceIdentifiers() []string {
return tgpa.Ids
}

func (tgpa *TransitGatewayPeeringAttachment) GetAndSetResourceConfig(configObj config.Config) config.ResourceType {
return configObj.TransitGateway
}

func (tgpa *TransitGatewayPeeringAttachment) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := tgpa.getAll(c, configObj)
if err != nil {
return nil, err
}

tgpa.Ids = awsgo.StringValueSlice(identifiers)
return tgpa.Ids, nil
}

func (tgpa *TransitGatewayPeeringAttachment) Nuke(identifiers []string) error {
if err := tgpa.nukeAll(awsgo.StringSlice(identifiers)); err != nil {
return errors.WithStackTrace(err)
}

return nil
}
70 changes: 70 additions & 0 deletions aws/resources/tgw_route_tables.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package resources

import (
"context"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/go-commons/errors"
)

// Returns a formatted string of TransitGatewayRouteTable IDs
func (tgw *TransitGatewaysRouteTables) getAll(c context.Context, configObj config.Config) ([]*string, error) {
// Remove default route table, that will be deleted along with its TransitGateway
param := &ec2.DescribeTransitGatewayRouteTablesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("default-association-route-table"),
Values: []*string{
aws.String("false"),
},
},
},
}

result, err := tgw.Client.DescribeTransitGatewayRouteTablesWithContext(tgw.Context, param)
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, transitGatewayRouteTable := range result.TransitGatewayRouteTables {
if configObj.TransitGatewayRouteTable.ShouldInclude(config.ResourceValue{Time: transitGatewayRouteTable.CreationTime}) &&
awsgo.StringValue(transitGatewayRouteTable.State) != "deleted" && awsgo.StringValue(transitGatewayRouteTable.State) != "deleting" {
ids = append(ids, transitGatewayRouteTable.TransitGatewayRouteTableId)
}
}

return ids, nil
}

// Delete all TransitGatewayRouteTables
func (tgw *TransitGatewaysRouteTables) nukeAll(ids []*string) error {
if len(ids) == 0 {
logging.Debugf("No Transit Gateway Route Tables to nuke in region %s", tgw.Region)
return nil
}

logging.Debugf("Deleting all Transit Gateway Route Tables in region %s", tgw.Region)
var deletedIds []*string

for _, id := range ids {
param := &ec2.DeleteTransitGatewayRouteTableInput{
TransitGatewayRouteTableId: id,
}

_, err := tgw.Client.DeleteTransitGatewayRouteTableWithContext(tgw.Context, param)
if err != nil {
logging.Debugf("[Failed] %s", err)
} else {
deletedIds = append(deletedIds, id)
logging.Debugf("Deleted Transit Gateway Route Table: %s", *id)
}
}

logging.Debugf("[OK] %d Transit Gateway Route Table(s) deleted in %s", len(deletedIds), tgw.Region)
return nil
}
94 changes: 94 additions & 0 deletions aws/resources/tgw_route_tables_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package resources

import (
"context"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/stretchr/testify/require"
)

type mockedTransitGatewayRouteTable struct {
ec2iface.EC2API
DescribeTransitGatewayRouteTablesOutput ec2.DescribeTransitGatewayRouteTablesOutput
DeleteTransitGatewayRouteTableOutput ec2.DeleteTransitGatewayRouteTableOutput
}

func (m mockedTransitGatewayRouteTable) DescribeTransitGatewayRouteTablesWithContext(_ awsgo.Context, _ *ec2.DescribeTransitGatewayRouteTablesInput, _ ...request.Option) (*ec2.DescribeTransitGatewayRouteTablesOutput, error) {
return &m.DescribeTransitGatewayRouteTablesOutput, nil
}

func (m mockedTransitGatewayRouteTable) DeleteTransitGatewayRouteTableWithContext(_ awsgo.Context, _ *ec2.DeleteTransitGatewayRouteTableInput, _ ...request.Option) (*ec2.DeleteTransitGatewayRouteTableOutput, error) {
return &m.DeleteTransitGatewayRouteTableOutput, nil
}
func TestTransitGatewayRouteTables_GetAll(t *testing.T) {

t.Parallel()

now := time.Now()
tableId1 := "table1"
tableId2 := "table2"
tgw := TransitGatewaysRouteTables{
Client: mockedTransitGatewayRouteTable{
DescribeTransitGatewayRouteTablesOutput: ec2.DescribeTransitGatewayRouteTablesOutput{
TransitGatewayRouteTables: []*ec2.TransitGatewayRouteTable{
{
TransitGatewayRouteTableId: aws.String(tableId1),
CreationTime: aws.Time(now),
State: aws.String("available"),
},
{
TransitGatewayRouteTableId: aws.String(tableId2),
CreationTime: aws.Time(now.Add(1)),
State: aws.String("deleting"),
},
},
},
},
}
tests := map[string]struct {
configObj config.ResourceType
expected []string
}{
"emptyFilter": {
configObj: config.ResourceType{},
expected: []string{tableId1},
},
"timeAfterExclusionFilter": {
configObj: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(-1 * time.Hour)),
}},
expected: []string{},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := tgw.getAll(context.Background(), config.Config{
TransitGatewayRouteTable: tc.configObj,
})
require.NoError(t, err)
require.Equal(t, tc.expected, aws.StringValueSlice(names))
})
}
}

func TestTransitGatewayRouteTables_NukeAll(t *testing.T) {

t.Parallel()

tgw := TransitGatewaysRouteTables{
Client: mockedTransitGatewayRouteTable{
DeleteTransitGatewayRouteTableOutput: ec2.DeleteTransitGatewayRouteTableOutput{},
},
}

err := tgw.nukeAll([]*string{aws.String("test-route-table")})
require.NoError(t, err)
}
Loading

0 comments on commit 95ea781

Please sign in to comment.