-
-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement support for API Gateway v1 and v2 (#352)
- Loading branch information
1 parent
3f5a923
commit e9b3ae9
Showing
23 changed files
with
860 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package aws | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/apigateway" | ||
"github.com/gruntwork-io/cloud-nuke/config" | ||
"github.com/gruntwork-io/cloud-nuke/logging" | ||
"github.com/gruntwork-io/go-commons/errors" | ||
"github.com/hashicorp/go-multierror" | ||
) | ||
|
||
func getAllAPIGateways(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) { | ||
svc := apigateway.New(session) | ||
|
||
result, err := svc.GetRestApis(&apigateway.GetRestApisInput{}) | ||
if err != nil { | ||
return []*string{}, errors.WithStackTrace(err) | ||
} | ||
Ids := []*string{} | ||
for _, apigateway := range result.Items { | ||
if shouldIncludeAPIGateway(apigateway, excludeAfter, configObj) { | ||
Ids = append(Ids, apigateway.Id) | ||
} | ||
} | ||
|
||
return Ids, nil | ||
} | ||
|
||
func shouldIncludeAPIGateway(apigw *apigateway.RestApi, excludeAfter time.Time, configObj config.Config) bool { | ||
if apigw == nil { | ||
return false | ||
} | ||
|
||
if apigw.CreatedDate != nil { | ||
if excludeAfter.Before(aws.TimeValue(apigw.CreatedDate)) { | ||
return false | ||
} | ||
} | ||
|
||
return config.ShouldInclude( | ||
aws.StringValue(apigw.Name), | ||
configObj.APIGateway.IncludeRule.NamesRegExp, | ||
configObj.APIGateway.ExcludeRule.NamesRegExp, | ||
) | ||
} | ||
|
||
func nukeAllAPIGateways(session *session.Session, identifiers []*string) error { | ||
region := aws.StringValue(session.Config.Region) | ||
|
||
svc := apigateway.New(session) | ||
|
||
if len(identifiers) == 0 { | ||
logging.Logger.Infof("No API Gateways (v1) to nuke in region %s", region) | ||
} | ||
|
||
if len(identifiers) > 100 { | ||
logging.Logger.Errorf("Nuking too many API Gateways (v1) at once (100): halting to avoid hitting AWS API rate limiting") | ||
return TooManyApiGatewayErr{} | ||
} | ||
|
||
// There is no bulk delete Api Gateway API, so we delete the batch of gateways concurrently using goroutines | ||
logging.Logger.Infof("Deleting Api Gateways (v1) in region %s", region) | ||
wg := new(sync.WaitGroup) | ||
wg.Add(len(identifiers)) | ||
errChans := make([]chan error, len(identifiers)) | ||
for i, apigwID := range identifiers { | ||
errChans[i] = make(chan error, 1) | ||
go deleteApiGatewayAsync(wg, errChans[i], svc, apigwID, region) | ||
} | ||
wg.Wait() | ||
|
||
var allErrs *multierror.Error | ||
for _, errChan := range errChans { | ||
if err := <-errChan; err != nil { | ||
allErrs = multierror.Append(allErrs, err) | ||
logging.Logger.Errorf("[Failed] %s", err) | ||
} | ||
} | ||
finalErr := allErrs.ErrorOrNil() | ||
if finalErr != nil { | ||
return errors.WithStackTrace(finalErr) | ||
} | ||
return nil | ||
} | ||
|
||
func deleteApiGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *apigateway.APIGateway, apigwID *string, region string) { | ||
defer wg.Done() | ||
|
||
input := &apigateway.DeleteRestApiInput{RestApiId: apigwID} | ||
_, err := svc.DeleteRestApi(input) | ||
errChan <- err | ||
|
||
if err == nil { | ||
logging.Logger.Infof("[OK] API Gateway (v1) %s deleted in %s", aws.StringValue(apigwID), region) | ||
} else { | ||
logging.Logger.Errorf("[Failed] Error deleting API Gateway (v1) %s in %s", aws.StringValue(apigwID), region) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
package aws | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
awsgo "github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/apigateway" | ||
"github.com/gruntwork-io/cloud-nuke/config" | ||
"github.com/gruntwork-io/cloud-nuke/util" | ||
"github.com/gruntwork-io/go-commons/errors" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type TestAPIGateway struct { | ||
ID *string | ||
Name *string | ||
} | ||
|
||
func createTestAPIGateway(t *testing.T, session *session.Session, name string) (*TestAPIGateway, error) { | ||
svc := apigateway.New(session) | ||
|
||
testGw := &TestAPIGateway{ | ||
Name: aws.String(name), | ||
} | ||
|
||
param := &apigateway.CreateRestApiInput{ | ||
Name: aws.String(name), | ||
} | ||
|
||
output, err := svc.CreateRestApi(param) | ||
if err != nil { | ||
assert.Failf(t, "Could not create test API Gateway: %s", errors.WithStackTrace(err).Error()) | ||
} | ||
|
||
testGw.ID = output.Id | ||
|
||
return testGw, nil | ||
} | ||
|
||
func TestListAPIGateways(t *testing.T) { | ||
t.Parallel() | ||
|
||
region, err := getRandomRegion() | ||
require.NoError(t, err) | ||
session, err := session.NewSession(&awsgo.Config{ | ||
Region: awsgo.String(region), | ||
}, | ||
) | ||
if err != nil { | ||
assert.Fail(t, errors.WithStackTrace(err).Error()) | ||
} | ||
|
||
apigwName := "aws-nuke-test-" + util.UniqueID() | ||
testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName) | ||
require.NoError(t, createTestGwErr) | ||
// clean up after this test | ||
defer nukeAllAPIGateways(session, []*string{testGw.ID}) | ||
|
||
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) | ||
if err != nil { | ||
assert.Fail(t, "Unable to fetch list of API Gateways (v1)") | ||
} | ||
|
||
assert.Contains(t, awsgo.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) | ||
} | ||
|
||
func TestTimeFilterExclusionNewlyCreatedAPIGateway(t *testing.T) { | ||
t.Parallel() | ||
|
||
region, err := getRandomRegion() | ||
require.NoError(t, err) | ||
|
||
session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) | ||
require.NoError(t, err) | ||
|
||
apigwName := "aws-nuke-test-" + util.UniqueID() | ||
|
||
testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName) | ||
require.NoError(t, createTestGwErr) | ||
defer nukeAllAPIGateways(session, []*string{testGw.ID}) | ||
|
||
// Assert API Gateway is picked up without filters | ||
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) | ||
require.NoError(t, err) | ||
assert.Contains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) | ||
|
||
// Assert API Gateway doesn't appear when we look at API Gateways older than 1 Hour | ||
olderThan := time.Now().Add(-1 * time.Hour) | ||
apiGwIdsOlder, err := getAllAPIGateways(session, olderThan, config.Config{}) | ||
require.NoError(t, err) | ||
assert.NotContains(t, aws.StringValueSlice(apiGwIdsOlder), aws.StringValue(testGw.ID)) | ||
} | ||
|
||
func TestNukeAPIGatewayOne(t *testing.T) { | ||
t.Parallel() | ||
|
||
region, err := getRandomRegion() | ||
require.NoError(t, err) | ||
|
||
session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) | ||
require.NoError(t, err) | ||
|
||
apigwName := "aws-nuke-test-" + util.UniqueID() | ||
// We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. | ||
testGw, createTestErr := createTestAPIGateway(t, session, apigwName) | ||
require.NoError(t, createTestErr) | ||
|
||
nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID}) | ||
require.NoError(t, nukeErr) | ||
|
||
// Make sure the API Gateway was deleted | ||
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) | ||
require.NoError(t, err) | ||
|
||
assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) | ||
} | ||
|
||
func TestNukeAPIGatewayMoreThanOne(t *testing.T) { | ||
t.Parallel() | ||
|
||
region, err := getRandomRegion() | ||
require.NoError(t, err) | ||
|
||
session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) | ||
require.NoError(t, err) | ||
|
||
apigwName := "aws-nuke-test-" + util.UniqueID() | ||
apigwName2 := "aws-nuke-test-" + util.UniqueID() | ||
// We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. | ||
testGw, createTestErr := createTestAPIGateway(t, session, apigwName) | ||
require.NoError(t, createTestErr) | ||
testGw2, createTestErr2 := createTestAPIGateway(t, session, apigwName2) | ||
require.NoError(t, createTestErr2) | ||
|
||
nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID, testGw2.ID}) | ||
require.NoError(t, nukeErr) | ||
|
||
// Make sure the API Gateway was deleted | ||
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) | ||
require.NoError(t, err) | ||
|
||
assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) | ||
assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw2.ID)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package aws | ||
|
||
import ( | ||
awsgo "github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/gruntwork-io/go-commons/errors" | ||
) | ||
|
||
type ApiGateway struct { | ||
Ids []string | ||
} | ||
|
||
func (apigateway ApiGateway) ResourceName() string { | ||
return "apigateway" | ||
} | ||
|
||
func (apigateway ApiGateway) ResourceIdentifiers() []string { | ||
return apigateway.Ids | ||
} | ||
|
||
func (apigateway ApiGateway) MaxBatchSize() int { | ||
return 10 | ||
} | ||
|
||
func (apigateway ApiGateway) Nuke(session *session.Session, identifiers []string) error { | ||
if err := nukeAllAPIGateways(session, awsgo.StringSlice(identifiers)); err != nil { | ||
return errors.WithStackTrace(err) | ||
} | ||
return nil | ||
} | ||
|
||
type TooManyApiGatewayErr struct{} | ||
|
||
func (err TooManyApiGatewayErr) Error() string { | ||
return "Too many Api Gateways requested at once." | ||
} |
Oops, something went wrong.