Skip to content

Commit

Permalink
Wait for AWS NAT Gateways to be ready
Browse files Browse the repository at this point in the history
At least in theory, NAT-Gateway creation is asynchronous on AWS.  We
now wait to be sure it is successfully created.
  • Loading branch information
justinsb committed Oct 30, 2023
1 parent 86f808f commit c2f15e6
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 89 deletions.
74 changes: 45 additions & 29 deletions cloudmock/aws/mockec2/natgateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func (m *MockEC2) CreateNatGatewayWithId(request *ec2.CreateNatGatewayInput, id
Tags: tags,
}

// Immediately mark it ready
ngw.State = aws.String("available")

if request.AllocationId != nil {
var eip *ec2.Address
for _, address := range m.Addresses {
Expand Down Expand Up @@ -71,10 +74,7 @@ func (m *MockEC2) CreateNatGatewayWithId(request *ec2.CreateNatGatewayInput, id
}

func (m *MockEC2) CreateNatGateway(request *ec2.CreateNatGatewayInput) (*ec2.CreateNatGatewayOutput, error) {
klog.Infof("CreateNatGateway: %v", request)

id := m.allocateId("nat")
return m.CreateNatGatewayWithId(request, id)
panic("Not implemented")
}

func (m *MockEC2) WaitUntilNatGatewayAvailable(request *ec2.DescribeNatGatewaysInput) error {
Expand All @@ -93,7 +93,7 @@ func (m *MockEC2) WaitUntilNatGatewayAvailable(request *ec2.DescribeNatGatewaysI
}

// We just immediately mark it ready
ngw.State = aws.String("Available")
ngw.State = aws.String("available")

return nil
}
Expand All @@ -102,19 +102,50 @@ func (m *MockEC2) WaitUntilNatGatewayAvailableWithContext(aws.Context, *ec2.Desc
panic("Not implemented")
}

func (m *MockEC2) CreateNatGatewayWithContext(aws.Context, *ec2.CreateNatGatewayInput, ...request.Option) (*ec2.CreateNatGatewayOutput, error) {
panic("Not implemented")
func (m *MockEC2) CreateNatGatewayWithContext(ctx aws.Context, request *ec2.CreateNatGatewayInput, options ...request.Option) (*ec2.CreateNatGatewayOutput, error) {
klog.Infof("CreateNatGateway: %v", request)

id := m.allocateId("nat")
return m.CreateNatGatewayWithId(request, id)
}

func (m *MockEC2) CreateNatGatewayRequest(*ec2.CreateNatGatewayInput) (*request.Request, *ec2.CreateNatGatewayOutput) {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) {
func (m *MockEC2) DescribeNatGateways(*ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGatewaysWithContext(ctx aws.Context, request *ec2.DescribeNatGatewaysInput, options ...request.Option) (*ec2.DescribeNatGatewaysOutput, error) {
var pages []*ec2.DescribeNatGatewaysOutput
callback := func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool {
pages = append(pages, page)
return true
}

if err := m.DescribeNatGatewaysPagesWithContext(ctx, request, callback, options...); err != nil {
return nil, err
}
if len(pages) == 0 {
return nil, fmt.Errorf("DescribeNatGatewaysPagesWithContext did not return any pages")
}
return pages[0], nil
}

func (m *MockEC2) DescribeNatGatewaysRequest(*ec2.DescribeNatGatewaysInput) (*request.Request, *ec2.DescribeNatGatewaysOutput) {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGatewaysPages(*ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool) error {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGatewaysPagesWithContext(ctx aws.Context, request *ec2.DescribeNatGatewaysInput, callback func(*ec2.DescribeNatGatewaysOutput, bool) bool, options ...request.Option) error {
m.mutex.Lock()
defer m.mutex.Unlock()

klog.Infof("DescribeNatGateways: %v", request)
klog.Infof("DescribeNatGatewaysPagesWithContext: %v", request)

var ngws []*ec2.NatGateway

Expand All @@ -137,7 +168,7 @@ func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*e
if strings.HasPrefix(*filter.Name, "tag:") {
match = m.hasTag(ec2.ResourceTypeNatgateway, *ngw.NatGatewayId, filter)
} else {
return nil, fmt.Errorf("unknown filter name: %q", *filter.Name)
return fmt.Errorf("unknown filter name: %q", *filter.Name)
}
}

Expand All @@ -160,26 +191,15 @@ func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*e
NatGateways: ngws,
}

return response, nil
}

func (m *MockEC2) DescribeNatGatewaysWithContext(aws.Context, *ec2.DescribeNatGatewaysInput, ...request.Option) (*ec2.DescribeNatGatewaysOutput, error) {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGatewaysRequest(*ec2.DescribeNatGatewaysInput) (*request.Request, *ec2.DescribeNatGatewaysOutput) {
panic("Not implemented")
}

func (m *MockEC2) DescribeNatGatewaysPages(*ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool) error {
panic("Not implemented")
callback(response, false)
return nil
}

func (m *MockEC2) DescribeNatGatewaysPagesWithContext(aws.Context, *ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool, ...request.Option) error {
func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) {
panic("Not implemented")
}

func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) {
func (m *MockEC2) DeleteNatGatewayWithContext(ctx aws.Context, request *ec2.DeleteNatGatewayInput, options ...request.Option) (*ec2.DeleteNatGatewayOutput, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -195,10 +215,6 @@ func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.Del
return &ec2.DeleteNatGatewayOutput{}, nil
}

func (m *MockEC2) DeleteNatGatewayWithContext(aws.Context, *ec2.DeleteNatGatewayInput, ...request.Option) (*ec2.DeleteNatGatewayOutput, error) {
panic("Not implemented")
}

func (m *MockEC2) DeleteNatGatewayRequest(*ec2.DeleteNatGatewayInput) (*request.Request, *ec2.DeleteNatGatewayOutput) {
panic("Not implemented")
}
37 changes: 22 additions & 15 deletions pkg/resources/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package aws

import (
"context"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -735,6 +736,8 @@ func DeleteSubnet(cloud fi.Cloud, tracker *resources.Resource) error {
}

func ListSubnets(cloud fi.Cloud, vpcID, clusterName string) ([]*resources.Resource, error) {
ctx := context.TODO()

c := cloud.(awsup.AWSCloud)
subnets, err := DescribeSubnets(cloud)
if err != nil {
Expand Down Expand Up @@ -833,20 +836,20 @@ func ListSubnets(cloud fi.Cloud, vpcID, clusterName string) ([]*resources.Resour

klog.V(2).Infof("Querying Nat Gateways")
request := &ec2.DescribeNatGatewaysInput{}
response, err := c.EC2().DescribeNatGateways(request)
if err != nil {
return nil, fmt.Errorf("error describing NatGateways: %v", err)
}
if err := c.EC2().DescribeNatGatewaysPagesWithContext(ctx, request, func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool {
for _, ngw := range page.NatGateways {
id := aws.StringValue(ngw.NatGatewayId)
if !natGatewayIds.Has(id) {
continue
}

for _, ngw := range response.NatGateways {
id := aws.StringValue(ngw.NatGatewayId)
if !natGatewayIds.Has(id) {
continue
forceShared := sharedNgwIds.Has(id) || !ownedNatGatewayIds.Has(id)
r := buildNatGatewayResource(ngw, forceShared, clusterName)
resourceTrackers = append(resourceTrackers, r)
}

forceShared := sharedNgwIds.Has(id) || !ownedNatGatewayIds.Has(id)
r := buildNatGatewayResource(ngw, forceShared, clusterName)
resourceTrackers = append(resourceTrackers, r)
return true
}); err != nil {
return nil, fmt.Errorf("listing NatGateways: %w", err)
}
}

Expand Down Expand Up @@ -1296,6 +1299,8 @@ func FindAutoScalingLaunchTemplates(cloud fi.Cloud, clusterName string) ([]*reso
}

func FindNatGateways(cloud fi.Cloud, routeTables map[string]*resources.Resource, clusterName string) ([]*resources.Resource, error) {
ctx := context.TODO()

if len(routeTables) == 0 {
return nil, nil
}
Expand Down Expand Up @@ -1347,7 +1352,7 @@ func FindNatGateways(cloud fi.Cloud, routeTables map[string]*resources.Resource,
request := &ec2.DescribeNatGatewaysInput{
NatGatewayIds: []*string{aws.String(natGatewayId)},
}
response, err := c.EC2().DescribeNatGateways(request)
response, err := c.EC2().DescribeNatGatewaysWithContext(ctx, request)
if err != nil {
if awsup.AWSErrorCode(err) == "NatGatewayNotFound" {
klog.V(2).Infof("Got NatGatewayNotFound describing NatGateway %s; will treat as already-deleted", natGatewayId)
Expand Down Expand Up @@ -1768,6 +1773,8 @@ func DeleteElasticIP(cloud fi.Cloud, t *resources.Resource) error {
}

func DeleteNatGateway(cloud fi.Cloud, t *resources.Resource) error {
ctx := context.TODO()

c := cloud.(awsup.AWSCloud)

id := t.ID
Expand All @@ -1776,12 +1783,12 @@ func DeleteNatGateway(cloud fi.Cloud, t *resources.Resource) error {
request := &ec2.DeleteNatGatewayInput{
NatGatewayId: &id,
}
_, err := c.EC2().DeleteNatGateway(request)
_, err := c.EC2().DeleteNatGatewayWithContext(ctx, request)
if err != nil {
if IsDependencyViolation(err) {
return err
}
return fmt.Errorf("error deleting ngw %q: %v", t.Name, err)
return fmt.Errorf("error deleting nat gateway %q: %v", t.Name, err)
}
return nil
}
Expand Down
10 changes: 6 additions & 4 deletions upup/pkg/fi/cloudup/awstasks/elastic_ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package awstasks

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -59,18 +60,19 @@ func (e *ElasticIP) CompareWithID() *string {
}

// Find returns the actual ElasticIP state, or nil if not found
func (e *ElasticIP) Find(context *fi.CloudupContext) (*ElasticIP, error) {
return e.find(context.T.Cloud.(awsup.AWSCloud))
func (e *ElasticIP) Find(c *fi.CloudupContext) (*ElasticIP, error) {
ctx := c.Context()
return e.find(ctx, c.T.Cloud.(awsup.AWSCloud))
}

// find will attempt to look up the elastic IP from AWS
func (e *ElasticIP) find(cloud awsup.AWSCloud) (*ElasticIP, error) {
func (e *ElasticIP) find(ctx context.Context, cloud awsup.AWSCloud) (*ElasticIP, error) {
publicIP := e.PublicIP
allocationID := e.ID

// Find via RouteTable -> NatGateway -> ElasticIP
if allocationID == nil && publicIP == nil && e.AssociatedNatGatewayRouteTable != nil {
ngw, err := findNatGatewayFromRouteTable(cloud, e.AssociatedNatGatewayRouteTable)
ngw, err := findNatGatewayFromRouteTable(ctx, cloud, e.AssociatedNatGatewayRouteTable)
if err != nil {
return nil, fmt.Errorf("error finding AssociatedNatGatewayRouteTable: %v", err)
}
Expand Down
Loading

0 comments on commit c2f15e6

Please sign in to comment.