diff --git a/README.md b/README.md index 5ede9e49..569f7202 100644 --- a/README.md +++ b/README.md @@ -194,17 +194,19 @@ Dry run mode is only available within: You can import cloud-nuke into other projects and use it as a library for programmatically inspecting and counting resources. ```golang + package main import ( "fmt" "time" + "github.com/aws/aws-sdk-go/aws" nuke_aws "github.com/gruntwork-io/cloud-nuke/aws" + "github.com/gruntwork-io/cloud-nuke/externalcreds" ) func main() { - // You can scan multiple regions at once, or just pass a single region for speed targetRegions := []string{"us-east-1", "us-west-1", "us-west-2"} excludeRegions := []string{} @@ -214,6 +216,17 @@ func main() { // excludeAfter is parsed identically to the --older-than flag excludeAfter := time.Now() + // Any custom settings you want + myCustomConfig := &aws.Config{} + + myCustomConfig.WithMaxRetries(3) + myCustomConfig.WithLogLevel(aws.LogDebugWithRequestErrors) + // Optionally, set custom credentials + // myCustomConfig.WithCredentials() + + // Be sure to set your config prior to calling any library methods such as NewQuery + externalcreds.Set(myCustomConfig) + // NewQuery is a convenience method for configuring parameters you want to pass to your resource search query, err := nuke_aws.NewQuery( targetRegions, @@ -222,7 +235,6 @@ func main() { excludeResourceTypes, excludeAfter, ) - if err != nil { fmt.Println(err) } @@ -246,7 +258,7 @@ func main() { // countOfEc2InUsWest1: 2 fmt.Printf("usWest1Resources.ResourceTypePresent(\"ec2\"):%b\n", usWest1Resources.ResourceTypePresent("ec2")) - //usWest1Resources.ResourceTypePresent("ec2"): true + // usWest1Resources.ResourceTypePresent("ec2"): true // Get all the resource identifiers for a given resource type // In this example, we're only looking for ec2 instances @@ -254,7 +266,6 @@ func main() { fmt.Printf("resourceIds: %s", resourceIds) // resourceIds: [i-0c5d16c3ef28dda24 i-09d9739e1f4d27814] - } ``` diff --git a/aws/aws.go b/aws/aws.go index 0e67b725..fe4aee18 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/externalcreds" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/go-commons/collections" "github.com/gruntwork-io/go-commons/errors" @@ -54,16 +55,7 @@ const ( ) func newSession(region string) *session.Session { - return session.Must( - session.NewSessionWithOptions( - session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: awsgo.Config{ - Region: awsgo.String(region), - }, - }, - ), - ) + return externalcreds.Get(region) } // Try a describe regions command with the most likely enabled regions @@ -227,13 +219,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp logging.Logger.Infof("Checking region [%d/%d]: %s", count, totalRegions, region) - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }, - ) - if err != nil { - return nil, errors.WithStackTrace(err) - } + cloudNukeSession := newSession(region) resourcesInRegion := AwsRegionResource{} @@ -243,7 +229,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // ACMPCA arns acmpca := ACMPCA{} if IsNukeable(acmpca.ResourceName(), resourceTypes) { - arns, err := getAllACMPCA(session, region, excludeAfter) + arns, err := getAllACMPCA(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -257,7 +243,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // ASG Names asGroups := ASGroups{} if IsNukeable(asGroups.ResourceName(), resourceTypes) { - groupNames, err := getAllAutoScalingGroups(session, region, excludeAfter, configObj) + groupNames, err := getAllAutoScalingGroups(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -271,7 +257,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Launch Configuration Names configs := LaunchConfigs{} if IsNukeable(configs.ResourceName(), resourceTypes) { - configNames, err := getAllLaunchConfigurations(session, region, excludeAfter, configObj) + configNames, err := getAllLaunchConfigurations(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -285,7 +271,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // LoadBalancer Names loadBalancers := LoadBalancers{} if IsNukeable(loadBalancers.ResourceName(), resourceTypes) { - elbNames, err := getAllElbInstances(session, region, excludeAfter) + elbNames, err := getAllElbInstances(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -299,7 +285,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // LoadBalancerV2 Arns loadBalancersV2 := LoadBalancersV2{} if IsNukeable(loadBalancersV2.ResourceName(), resourceTypes) { - elbv2Arns, err := getAllElbv2Instances(session, region, excludeAfter, configObj) + elbv2Arns, err := getAllElbv2Instances(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -313,7 +299,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // SQS Queues sqsQueue := SqsQueue{} if IsNukeable(sqsQueue.ResourceName(), resourceTypes) { - queueUrls, err := getAllSqsQueue(session, region, excludeAfter) + queueUrls, err := getAllSqsQueue(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -326,12 +312,12 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // TransitGatewayVpcAttachment transitGatewayVpcAttachments := TransitGatewaysVpcAttachment{} - transitGatewayIsAvailable, err := tgIsAvailableInRegion(session, region) + transitGatewayIsAvailable, err := tgIsAvailableInRegion(cloudNukeSession, region) if err != nil { return nil, errors.WithStackTrace(err) } if IsNukeable(transitGatewayVpcAttachments.ResourceName(), resourceTypes) && transitGatewayIsAvailable { - transitGatewayVpcAttachmentIds, err := getAllTransitGatewayVpcAttachments(session, region, excludeAfter) + transitGatewayVpcAttachmentIds, err := getAllTransitGatewayVpcAttachments(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -345,7 +331,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // TransitGatewayRouteTable transitGatewayRouteTables := TransitGatewaysRouteTables{} if IsNukeable(transitGatewayRouteTables.ResourceName(), resourceTypes) && transitGatewayIsAvailable { - transitGatewayRouteTableIds, err := getAllTransitGatewayRouteTables(session, region, excludeAfter) + transitGatewayRouteTableIds, err := getAllTransitGatewayRouteTables(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -359,7 +345,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // TransitGateway transitGateways := TransitGateways{} if IsNukeable(transitGateways.ResourceName(), resourceTypes) && transitGatewayIsAvailable { - transitGatewayIds, err := getAllTransitGatewayInstances(session, region, excludeAfter) + transitGatewayIds, err := getAllTransitGatewayInstances(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -373,7 +359,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // NATGateway natGateways := NatGateways{} if IsNukeable(natGateways.ResourceName(), resourceTypes) { - ngwIDs, err := getAllNatGateways(session, excludeAfter, configObj) + ngwIDs, err := getAllNatGateways(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -387,7 +373,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // OpenSearch Domains domains := OpenSearchDomains{} if IsNukeable(domains.ResourceName(), resourceTypes) { - domainNames, err := getOpenSearchDomainsToNuke(session, excludeAfter, configObj) + domainNames, err := getOpenSearchDomainsToNuke(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -401,7 +387,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // EC2 Instances ec2Instances := EC2Instances{} if IsNukeable(ec2Instances.ResourceName(), resourceTypes) { - instanceIds, err := getAllEc2Instances(session, region, excludeAfter, configObj) + instanceIds, err := getAllEc2Instances(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -415,7 +401,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // EBS Volumes ebsVolumes := EBSVolumes{} if IsNukeable(ebsVolumes.ResourceName(), resourceTypes) { - volumeIds, err := getAllEbsVolumes(session, region, excludeAfter, configObj) + volumeIds, err := getAllEbsVolumes(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -429,7 +415,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // EIP Addresses eipAddresses := EIPAddresses{} if IsNukeable(eipAddresses.ResourceName(), resourceTypes) { - allocationIds, err := getAllEIPAddresses(session, region, excludeAfter, configObj) + allocationIds, err := getAllEIPAddresses(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -443,7 +429,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // AMIs amis := AMIs{} if IsNukeable(amis.ResourceName(), resourceTypes) { - imageIds, err := getAllAMIs(session, region, excludeAfter) + imageIds, err := getAllAMIs(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -457,7 +443,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Snapshots snapshots := Snapshots{} if IsNukeable(snapshots.ResourceName(), resourceTypes) { - snapshotIds, err := getAllSnapshots(session, region, excludeAfter) + snapshotIds, err := getAllSnapshots(cloudNukeSession, region, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -471,12 +457,12 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // ECS resources ecsServices := ECSServices{} if IsNukeable(ecsServices.ResourceName(), resourceTypes) { - clusterArns, err := getAllEcsClusters(session) + clusterArns, err := getAllEcsClusters(cloudNukeSession) if err != nil { return nil, errors.WithStackTrace(err) } if len(clusterArns) > 0 { - serviceArns, serviceClusterMap, err := getAllEcsServices(session, clusterArns, excludeAfter, configObj) + serviceArns, serviceClusterMap, err := getAllEcsServices(cloudNukeSession, clusterArns, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -488,7 +474,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp ecsClusters := ECSClusters{} if IsNukeable(ecsClusters.ResourceName(), resourceTypes) { - ecsClusterArns, err := getAllEcsClustersOlderThan(session, excludeAfter, configObj) + ecsClusterArns, err := getAllEcsClustersOlderThan(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -502,7 +488,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // EKS resources eksClusters := EKSClusters{} if IsNukeable(eksClusters.ResourceName(), resourceTypes) { - eksClusterNames, err := getAllEksClusters(session, excludeAfter, configObj) + eksClusterNames, err := getAllEksClusters(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -516,7 +502,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // RDS DB Instances dbInstances := DBInstances{} if IsNukeable(dbInstances.ResourceName(), resourceTypes) { - instanceNames, err := getAllRdsInstances(session, excludeAfter) + instanceNames, err := getAllRdsInstances(cloudNukeSession, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -533,7 +519,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // has different abstractions for each. dbClusters := DBClusters{} if IsNukeable(dbClusters.ResourceName(), resourceTypes) { - clustersNames, err := getAllRdsClusters(session, excludeAfter) + clustersNames, err := getAllRdsClusters(cloudNukeSession, excludeAfter) if err != nil { return nil, errors.WithStackTrace(err) } @@ -548,7 +534,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Lambda Functions lambdaFunctions := LambdaFunctions{} if IsNukeable(lambdaFunctions.ResourceName(), resourceTypes) { - lambdaFunctionNames, err := getAllLambdaFunctions(session, excludeAfter, configObj, lambdaFunctions.MaxBatchSize()) + lambdaFunctionNames, err := getAllLambdaFunctions(cloudNukeSession, excludeAfter, configObj, lambdaFunctions.MaxBatchSize()) if err != nil { return nil, errors.WithStackTrace(err) } @@ -563,7 +549,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Secrets Manager Secrets secretsManagerSecrets := SecretsManagerSecrets{} if IsNukeable(secretsManagerSecrets.ResourceName(), resourceTypes) { - secrets, err := getAllSecretsManagerSecrets(session, excludeAfter, configObj) + secrets, err := getAllSecretsManagerSecrets(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -578,7 +564,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // AccessAnalyzer accessAnalyzer := AccessAnalyzer{} if IsNukeable(accessAnalyzer.ResourceName(), resourceTypes) { - analyzerNames, err := getAllAccessAnalyzers(session, excludeAfter, configObj) + analyzerNames, err := getAllAccessAnalyzers(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -592,7 +578,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // CloudWatchDashboard cloudwatchDashboards := CloudWatchDashboards{} if IsNukeable(cloudwatchDashboards.ResourceName(), resourceTypes) { - cwdbNames, err := getAllCloudWatchDashboards(session, excludeAfter, configObj) + cwdbNames, err := getAllCloudWatchDashboards(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -606,7 +592,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // CloudWatchLogGroup cloudwatchLogGroups := CloudWatchLogGroups{} if IsNukeable(cloudwatchLogGroups.ResourceName(), resourceTypes) { - lgNames, err := getAllCloudWatchLogGroups(session, excludeAfter, configObj) + lgNames, err := getAllCloudWatchLogGroups(cloudNukeSession, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -636,7 +622,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp bucketNamesPerRegion, ok := resourcesCache["S3"] if !ok { - bucketNamesPerRegion, err = getAllS3Buckets(session, excludeAfter, targetRegions, "", s3Buckets.MaxConcurrentGetSize(), configObj) + bucketNamesPerRegion, err = getAllS3Buckets(cloudNukeSession, excludeAfter, targetRegions, "", s3Buckets.MaxConcurrentGetSize(), configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -659,7 +645,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp DynamoDB := DynamoDB{} if IsNukeable(DynamoDB.ResourceName(), resourceTypes) { - tablenames, err := getAllDynamoTables(session, excludeAfter, configObj, DynamoDB) + tablenames, err := getAllDynamoTables(cloudNukeSession, excludeAfter, configObj, DynamoDB) if err != nil { return nil, errors.WithStackTrace(err) } @@ -674,7 +660,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // EC2 VPCS ec2Vpcs := EC2VPCs{} if IsNukeable(ec2Vpcs.ResourceName(), resourceTypes) { - vpcids, vpcs, err := getAllVpcs(session, region, excludeAfter, configObj) + vpcids, vpcs, err := getAllVpcs(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -690,7 +676,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // Elasticaches elasticaches := Elasticaches{} if IsNukeable(elasticaches.ResourceName(), resourceTypes) { - clusterIds, err := getAllElasticacheClusters(session, region, excludeAfter, configObj) + clusterIds, err := getAllElasticacheClusters(cloudNukeSession, region, excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -705,7 +691,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // KMS Customer managed keys customerKeys := KmsCustomerKeys{} if IsNukeable(customerKeys.ResourceName(), resourceTypes) { - keys, err := getAllKmsUserKeys(session, customerKeys.MaxBatchSize(), excludeAfter, configObj) + keys, err := getAllKmsUserKeys(cloudNukeSession, customerKeys.MaxBatchSize(), excludeAfter, configObj) if err != nil { return nil, errors.WithStackTrace(err) } @@ -720,7 +706,7 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // GuardDuty detectors guardDutyDetectors := GuardDuty{} if IsNukeable(guardDutyDetectors.ResourceName(), resourceTypes) { - detectors, err := getAllGuardDutyDetectors(session, excludeAfter, configObj, guardDutyDetectors.MaxBatchSize()) + detectors, err := getAllGuardDutyDetectors(cloudNukeSession, excludeAfter, configObj, guardDutyDetectors.MaxBatchSize()) if err != nil { return nil, errors.WithStackTrace(err) } diff --git a/aws/inspect.go b/aws/inspect.go index f47352ea..eb8748a1 100644 --- a/aws/inspect.go +++ b/aws/inspect.go @@ -31,7 +31,6 @@ func ExtractResourcesForPrinting(account *AwsAccountResources) []string { } func ensureValidResourceTypes(resourceTypes []string) ([]string, error) { - invalidresourceTypes := []string{} for _, resourceType := range resourceTypes { if resourceType == "all" { @@ -79,7 +78,6 @@ func HandleResourceTypeSelections( } func InspectResources(q *Query) (*AwsAccountResources, error) { - // Log which resource types will be inspected logging.Logger.Info("The following resource types will be inspected:") if len(q.ResourceTypes) > 0 { diff --git a/aws/s3.go b/aws/s3.go index 1f173d4a..b6eecc88 100644 --- a/aws/s3.go +++ b/aws/s3.go @@ -94,8 +94,8 @@ type S3Bucket struct { // getAllS3Buckets returns a map of per region AWS S3 buckets which were created before excludeAfter func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, - targetRegions []string, bucketNameSubStr string, batchSize int, configObj config.Config) (map[string][]*string, error) { - + targetRegions []string, bucketNameSubStr string, batchSize int, configObj config.Config, +) (map[string][]*string, error) { if batchSize <= 0 { return nil, fmt.Errorf("Invalid batchsize - %d - should be > 0", batchSize) } @@ -112,7 +112,7 @@ func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, return nil, errors.WithStackTrace(err) } - var bucketNamesPerRegion = make(map[string][]*string) + bucketNamesPerRegion := make(map[string][]*string) totalBuckets := len(output.Buckets) if totalBuckets == 0 { return bucketNamesPerRegion, nil @@ -127,7 +127,6 @@ func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, logging.Logger.Infof("Getting - %d-%d buckets of batch %d/%d", batchStart+1, batchEnd, batchCount, totalBatches) targetBuckets := output.Buckets[batchStart:batchEnd] currBucketNamesPerRegion, err := getBucketNamesPerRegion(svc, targetBuckets, excludeAfter, regionClients, bucketNameSubStr, configObj) - if err != nil { return bucketNamesPerRegion, err } @@ -145,15 +144,12 @@ func getAllS3Buckets(awsSession *session.Session, excludeAfter time.Time, // getRegions creates s3 clients for target regions func getRegionClients(regions []string) (map[string]*s3.S3, error) { - var regionClients = make(map[string]*s3.S3) + regionClients := make(map[string]*s3.S3) for _, region := range regions { logging.Logger.Debugf("S3 - creating session - region %s", region) - awsSession, err := session.NewSession(&aws.Config{ - Region: aws.String(region)}, - ) - if err != nil { - return regionClients, err - } + + awsSession := newSession(region) + regionClients[region] = s3.New(awsSession) } return regionClients, nil @@ -161,10 +157,10 @@ func getRegionClients(regions []string) (map[string]*s3.S3, error) { // getBucketNamesPerRegions gets valid bucket names concurrently from list of target buckets func getBucketNamesPerRegion(svc *s3.S3, targetBuckets []*s3.Bucket, excludeAfter time.Time, regionClients map[string]*s3.S3, - bucketNameSubStr string, configObj config.Config) (map[string][]*string, error) { - - var bucketNamesPerRegion = make(map[string][]*string) - var bucketCh = make(chan *S3Bucket, len(targetBuckets)) + bucketNameSubStr string, configObj config.Config, +) (map[string][]*string, error) { + bucketNamesPerRegion := make(map[string][]*string) + bucketCh := make(chan *S3Bucket, len(targetBuckets)) var wg sync.WaitGroup for _, bucket := range targetBuckets { diff --git a/aws/types.go b/aws/types.go index a84383af..e2542dba 100644 --- a/aws/types.go +++ b/aws/types.go @@ -83,7 +83,6 @@ type Query struct { // NewQuery configures and returns a Query struct that can be passed into the InspectResources method func NewQuery(regions, excludeRegions, resourceTypes, excludeResourceTypes []string, excludeAfter time.Time) (*Query, error) { - q := &Query{ Regions: regions, ExcludeRegions: excludeRegions, @@ -104,7 +103,6 @@ func NewQuery(regions, excludeRegions, resourceTypes, excludeResourceTypes []str // Validate ensures the configured values for a Query are valid, returning an error if there are // any invalid params, or nil if the Query is valid func (q *Query) Validate() error { - resourceTypes, err := HandleResourceTypeSelections(q.ResourceTypes, q.ExcludeResourceTypes) if err != nil { return err @@ -118,7 +116,6 @@ func (q *Query) Validate() error { } targetRegions, err := GetTargetRegions(enabledRegions, q.Regions, q.ExcludeRegions) - if err != nil { return CouldNotSelectRegionError{Underlying: err} } diff --git a/commands/cli.go b/commands/cli.go index 03098d7f..0b136d3b 100644 --- a/commands/cli.go +++ b/commands/cli.go @@ -160,7 +160,6 @@ func awsNuke(c *cli.Context) error { if configFilePath != "" { configObjPtr, err := config.GetConfig(configFilePath) - if err != nil { return fmt.Errorf("Error reading config - %s - %s", configFilePath, err) } @@ -212,7 +211,6 @@ func awsNuke(c *cli.Context) error { logging.Logger.Infof("Retrieving active AWS resources in [%s]", strings.Join(targetRegions[:], ", ")) account, err := aws.GetAllResources(targetRegions, *excludeAfter, resourceTypes, configObj) - if err != nil { return errors.WithStackTrace(err) } @@ -372,7 +370,6 @@ func confirmationPrompt(prompt string, maxPrompts int) (bool, error) { prompts := 0 for prompts < maxPrompts { input, err := shell.PromptUserForInput(prompt, &shellOptions) - if err != nil { return false, errors.WithStackTrace(err) } @@ -417,13 +414,11 @@ func awsInspect(c *cli.Context) error { c.StringSlice("exclude-resource-type"), *excludeAfter, ) - if err != nil { return aws.QueryCreationError{Underlying: err} } accountResources, err := aws.InspectResources(query) - if err != nil { return errors.WithStackTrace(aws.ResourceInspectionError{Underlying: err}) } diff --git a/externalcreds/creds.go b/externalcreds/creds.go new file mode 100644 index 00000000..73b2b176 --- /dev/null +++ b/externalcreds/creds.go @@ -0,0 +1,30 @@ +package externalcreds + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" +) + +var externalConfig *aws.Config + +func Set(opts *aws.Config) { + externalConfig = opts +} + +func Get(region string) *session.Session { + config := aws.Config{ + Region: aws.String(region), + } + // If external config was passed in, use its credentials + if externalConfig != nil { + config.Credentials = externalConfig.Credentials + } + return session.Must( + session.NewSessionWithOptions( + session.Options{ + SharedConfigState: session.SharedConfigEnable, + Config: config, + }, + ), + ) +}