Skip to content

Commit

Permalink
Refactor code to pass around context for potential features (e.g., ti…
Browse files Browse the repository at this point in the history
…meout, accountId match) (#588)
  • Loading branch information
james03160927 authored Sep 21, 2023
1 parent dc85c15 commit fbb1842
Show file tree
Hide file tree
Showing 178 changed files with 472 additions and 296 deletions.
12 changes: 6 additions & 6 deletions aws/aws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aws

import (
"context"
"fmt"
"github.com/gruntwork-io/cloud-nuke/util"
"github.com/pterm/pterm"
Expand All @@ -11,7 +12,6 @@ import (
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/progressbar"
Expand All @@ -21,7 +21,7 @@ import (
)

// GetAllResources - Lists all aws resources
func GetAllResources(query *Query, configObj config.Config) (*AwsAccountResources, error) {
func GetAllResources(c context.Context, query *Query, configObj config.Config) (*AwsAccountResources, error) {

configObj.AddExcludeAfterTime(query.ExcludeAfter)
configObj.AddIncludeAfterTime(query.IncludeAfter)
Expand All @@ -33,10 +33,10 @@ func GetAllResources(query *Query, configObj config.Config) (*AwsAccountResource
spinner, _ := pterm.DefaultSpinner.WithRemoveWhenDone(true).Start()
for _, region := range query.Regions {
cloudNukeSession := NewSession(region)
stsService := sts.New(cloudNukeSession)
resp, err := stsService.GetCallerIdentity(&sts.GetCallerIdentityInput{})
accountId, err := util.GetCurrentAccountId(cloudNukeSession)
if err == nil {
telemetry.SetAccountId(*resp.Account)
telemetry.SetAccountId(accountId)
c = context.WithValue(c, util.AccountIdKey, accountId)
}

awsResource := AwsRegionResource{}
Expand All @@ -46,7 +46,7 @@ func GetAllResources(query *Query, configObj config.Config) (*AwsAccountResource
spinner.UpdateText(
fmt.Sprintf("Searching %s resources in %s", (*resource).ResourceName(), region))
start := time.Now()
identifiers, err := (*resource).GetAndSetIdentifiers(configObj)
identifiers, err := (*resource).GetAndSetIdentifiers(c, configObj)
if err != nil {
ge := report.GeneralError{
Error: err,
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/access_analyzer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"github.com/aws/aws-sdk-go/service/accessanalyzer"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
Expand All @@ -11,7 +12,7 @@ import (
"sync"
)

func (analyzer *AccessAnalyzer) getAll(configObj config.Config) ([]*string, error) {
func (analyzer *AccessAnalyzer) getAll(c context.Context, configObj config.Config) ([]*string, error) {
allAnalyzers := []*string{}
err := analyzer.Client.ListAnalyzersPages(
&accessanalyzer.ListAnalyzersInput{},
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/access_analyzer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/accessanalyzer"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestAccessAnalyzer_GetAll(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
names, err := aa.getAll(config.Config{
names, err := aa.getAll(context.Background(), config.Config{
AccessAnalyzer: tc.configObj,
})
require.NoError(t, err)
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/access_analyzer_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/accessanalyzer"
Expand Down Expand Up @@ -37,8 +38,8 @@ func (analyzer *AccessAnalyzer) MaxBatchSize() int {
return 10
}

func (analyzer *AccessAnalyzer) GetAndSetIdentifiers(configObj config.Config) ([]string, error) {
identifiers, err := analyzer.getAll(configObj)
func (analyzer *AccessAnalyzer) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := analyzer.getAll(c, configObj)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/acm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"github.com/aws/aws-sdk-go/service/acm"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
Expand All @@ -11,7 +12,7 @@ import (
)

// Returns a list of strings of ACM ARNs
func (a *ACM) getAll(configObj config.Config) ([]*string, error) {
func (a *ACM) getAll(c context.Context, configObj config.Config) ([]*string, error) {

params := &acm.ListCertificatesInput{}

Expand Down
9 changes: 5 additions & 4 deletions aws/resources/acm_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"regexp"
"testing"
"time"
Expand Down Expand Up @@ -52,12 +53,12 @@ func TestACMGetAll(t *testing.T) {
}

// without any filters
acms, err := acm.getAll(config.Config{})
acms, err := acm.getAll(context.Background(), config.Config{})
require.NoError(t, err)
require.Contains(t, aws.StringValueSlice(acms), testArn)

// filtering domain names
acms, err = acm.getAll(config.Config{
acms, err = acm.getAll(context.Background(), config.Config{
ACM: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{{
Expand All @@ -68,7 +69,7 @@ func TestACMGetAll(t *testing.T) {
require.NotContains(t, aws.StringValueSlice(acms), testArn)

// filtering with time
acms, err = acm.getAll(config.Config{
acms, err = acm.getAll(context.Background(), config.Config{
ACM: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(-1)),
Expand Down Expand Up @@ -98,7 +99,7 @@ func TestACMGetAll_FilterInUse(t *testing.T) {
},
}

acms, err := acm.getAll(config.Config{})
acms, err := acm.getAll(context.Background(), config.Config{})
require.NoError(t, err)
require.NotContains(t, acms, testArn)
}
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/acm_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acm"
Expand Down Expand Up @@ -35,8 +36,8 @@ func (a *ACM) MaxBatchSize() int {
return 10
}

func (a *ACM) GetAndSetIdentifiers(configObj config.Config) ([]string, error) {
identifiers, err := a.getAll(configObj)
func (a *ACM) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := a.getAll(c, configObj)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/acmpca.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"fmt"
"sync"
"time"
Expand All @@ -18,7 +19,7 @@ import (
)

// GetAll returns a list of all arns of ACMPCA, which can be deleted.
func (ap *ACMPCA) getAll(configObj config.Config) ([]*string, error) {
func (ap *ACMPCA) getAll(c context.Context, configObj config.Config) ([]*string, error) {
var arns []*string
paginationErr := ap.Client.ListCertificateAuthoritiesPages(
&acmpca.ListCertificateAuthoritiesInput{},
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/acmpca_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -63,12 +64,12 @@ func TestAcmPcaGetAll(t *testing.T) {
}

// without filters
arns, err := acmPca.getAll(config.Config{})
arns, err := acmPca.getAll(context.Background(), config.Config{})
require.NoError(t, err)
require.Contains(t, awsgo.StringValueSlice(arns), testArn)

// with exclude after filter
arns, err = acmPca.getAll(config.Config{
arns, err = acmPca.getAll(context.Background(), config.Config{
ACMPCA: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: awsgo.Time(now.Add(-1))}},
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/acmpca_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/acmpca"
Expand Down Expand Up @@ -35,8 +36,8 @@ func (ap *ACMPCA) MaxBatchSize() int {
return 10
}

func (ap *ACMPCA) GetAndSetIdentifiers(configObj config.Config) ([]string, error) {
identifiers, err := ap.getAll(configObj)
func (ap *ACMPCA) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := ap.getAll(c, configObj)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/ami.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"strings"
"time"

Expand All @@ -17,7 +18,7 @@ import (
)

// Returns a formatted string of AMI Image ids
func (ami *AMIs) getAll(configObj config.Config) ([]*string, error) {
func (ami *AMIs) getAll(c context.Context, configObj config.Config) ([]*string, error) {
params := &ec2.DescribeImagesInput{
Owners: []*string{awsgo.String("self")},
}
Expand Down
9 changes: 5 additions & 4 deletions aws/resources/ami_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/gruntwork-io/cloud-nuke/config"
Expand Down Expand Up @@ -67,7 +68,7 @@ func TestAMIGetAll_SkipAWSManaged(t *testing.T) {
},
}

amis, err := acm.getAll(config.Config{})
amis, err := acm.getAll(context.Background(), config.Config{})
assert.NoError(t, err)
assert.NotContains(t, awsgo.StringValueSlice(amis), testImageId1)
assert.NotContains(t, awsgo.StringValueSlice(amis), testImageId2)
Expand All @@ -93,12 +94,12 @@ func TestAMIGetAll(t *testing.T) {
}

// without filters
amis, err := acm.getAll(config.Config{})
amis, err := acm.getAll(context.Background(), config.Config{})
assert.NoError(t, err)
assert.Contains(t, awsgo.StringValueSlice(amis), testImageId)

// with name filter
amis, err = acm.getAll(config.Config{
amis, err = acm.getAll(context.Background(), config.Config{
AMI: config.ResourceType{
ExcludeRule: config.FilterRule{
NamesRegExp: []config.Expression{{
Expand All @@ -108,7 +109,7 @@ func TestAMIGetAll(t *testing.T) {
assert.NotContains(t, awsgo.StringValueSlice(amis), testImageId)

// with time filter
amis, err = acm.getAll(config.Config{
amis, err = acm.getAll(context.Background(), config.Config{
AMI: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: awsgo.Time(now.Add(-12 * time.Hour))}}})
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/ami_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
Expand Down Expand Up @@ -35,8 +36,8 @@ func (ami *AMIs) MaxBatchSize() int {
return 49
}

func (ami *AMIs) GetAndSetIdentifiers(configObj config.Config) ([]string, error) {
identifiers, err := ami.getAll(configObj)
func (ami *AMIs) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := ami.getAll(c, configObj)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion aws/resources/apigateway.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"sync"

commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
Expand All @@ -15,7 +16,7 @@ import (
"github.com/hashicorp/go-multierror"
)

func (gateway *ApiGateway) getAll(configObj config.Config) ([]*string, error) {
func (gateway *ApiGateway) getAll(c context.Context, configObj config.Config) ([]*string, error) {
result, err := gateway.Client.GetRestApis(&apigateway.GetRestApisInput{})
if err != nil {
return []*string{}, errors.WithStackTrace(err)
Expand Down
9 changes: 5 additions & 4 deletions aws/resources/apigateway_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -47,7 +48,7 @@ func TestAPIGatewayGetAllAndNukeAll(t *testing.T) {
},
}

apis, err := apiGateway.getAll(config.Config{})
apis, err := apiGateway.getAll(context.Background(), config.Config{})
require.NoError(t, err)
require.Contains(t, awsgo.StringValueSlice(apis), testApiID)

Expand All @@ -73,7 +74,7 @@ func TestAPIGatewayGetAllTimeFilter(t *testing.T) {
}

// test API is not excluded from the filter
IDs, err := apiGateway.getAll(config.Config{
IDs, err := apiGateway.getAll(context.Background(), config.Config{
APIGateway: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(1)),
Expand All @@ -84,7 +85,7 @@ func TestAPIGatewayGetAllTimeFilter(t *testing.T) {
assert.Contains(t, aws.StringValueSlice(IDs), testApiID)

// test API being excluded from the filter
apiGwIdsOlder, err := apiGateway.getAll(config.Config{
apiGwIdsOlder, err := apiGateway.getAll(context.Background(), config.Config{
APIGateway: config.ResourceType{
ExcludeRule: config.FilterRule{
TimeAfter: aws.Time(now.Add(-1)),
Expand Down Expand Up @@ -113,7 +114,7 @@ func TestNukeAPIGatewayMoreThanOne(t *testing.T) {
},
}

apis, err := apiGateway.getAll(config.Config{})
apis, err := apiGateway.getAll(context.Background(), config.Config{})
require.NoError(t, err)
require.Contains(t, awsgo.StringValueSlice(apis), testApiID1)
require.Contains(t, awsgo.StringValueSlice(apis), testApiID2)
Expand Down
5 changes: 3 additions & 2 deletions aws/resources/apigateway_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resources

import (
"context"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigateway"
Expand Down Expand Up @@ -31,8 +32,8 @@ func (gateway *ApiGateway) MaxBatchSize() int {
return 10
}

func (gateway *ApiGateway) GetAndSetIdentifiers(configObj config.Config) ([]string, error) {
identifiers, err := gateway.getAll(configObj)
func (gateway *ApiGateway) GetAndSetIdentifiers(c context.Context, configObj config.Config) ([]string, error) {
identifiers, err := gateway.getAll(c, configObj)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit fbb1842

Please sign in to comment.