diff --git a/cloudmock/aws/mockec2/natgateway.go b/cloudmock/aws/mockec2/natgateway.go index cbcea93f0f8af..f8de6cf5d5b2e 100644 --- a/cloudmock/aws/mockec2/natgateway.go +++ b/cloudmock/aws/mockec2/natgateway.go @@ -38,6 +38,9 @@ func (m *MockEC2) CreateNatGatewayWithId(request *ec2.CreateNatGatewayInput, id Tags: tags, } + // Immediately mark it ready + ngw.State = aws.String("available") + if request.AllocationId != nil { var eip *ec2.Address for _, address := range m.Addresses { @@ -71,10 +74,7 @@ func (m *MockEC2) CreateNatGatewayWithId(request *ec2.CreateNatGatewayInput, id } func (m *MockEC2) CreateNatGateway(request *ec2.CreateNatGatewayInput) (*ec2.CreateNatGatewayOutput, error) { - klog.Infof("CreateNatGateway: %v", request) - - id := m.allocateId("nat") - return m.CreateNatGatewayWithId(request, id) + panic("Not implemented") } func (m *MockEC2) WaitUntilNatGatewayAvailable(request *ec2.DescribeNatGatewaysInput) error { @@ -93,7 +93,7 @@ func (m *MockEC2) WaitUntilNatGatewayAvailable(request *ec2.DescribeNatGatewaysI } // We just immediately mark it ready - ngw.State = aws.String("Available") + ngw.State = aws.String("available") return nil } @@ -102,19 +102,50 @@ func (m *MockEC2) WaitUntilNatGatewayAvailableWithContext(aws.Context, *ec2.Desc panic("Not implemented") } -func (m *MockEC2) CreateNatGatewayWithContext(aws.Context, *ec2.CreateNatGatewayInput, ...request.Option) (*ec2.CreateNatGatewayOutput, error) { - panic("Not implemented") +func (m *MockEC2) CreateNatGatewayWithContext(ctx aws.Context, request *ec2.CreateNatGatewayInput, options ...request.Option) (*ec2.CreateNatGatewayOutput, error) { + klog.Infof("CreateNatGateway: %v", request) + + id := m.allocateId("nat") + return m.CreateNatGatewayWithId(request, id) } func (m *MockEC2) CreateNatGatewayRequest(*ec2.CreateNatGatewayInput) (*request.Request, *ec2.CreateNatGatewayOutput) { panic("Not implemented") } -func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) { +func (m *MockEC2) DescribeNatGateways(*ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) { + panic("Not implemented") +} + +func (m *MockEC2) DescribeNatGatewaysWithContext(ctx aws.Context, request *ec2.DescribeNatGatewaysInput, options ...request.Option) (*ec2.DescribeNatGatewaysOutput, error) { + var pages []*ec2.DescribeNatGatewaysOutput + callback := func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool { + pages = append(pages, page) + return true + } + + if err := m.DescribeNatGatewaysPagesWithContext(ctx, request, callback, options...); err != nil { + return nil, err + } + if len(pages) == 0 { + return nil, fmt.Errorf("DescribeNatGatewaysPagesWithContext did not return any pages") + } + return pages[0], nil +} + +func (m *MockEC2) DescribeNatGatewaysRequest(*ec2.DescribeNatGatewaysInput) (*request.Request, *ec2.DescribeNatGatewaysOutput) { + panic("Not implemented") +} + +func (m *MockEC2) DescribeNatGatewaysPages(*ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool) error { + panic("Not implemented") +} + +func (m *MockEC2) DescribeNatGatewaysPagesWithContext(ctx aws.Context, request *ec2.DescribeNatGatewaysInput, callback func(*ec2.DescribeNatGatewaysOutput, bool) bool, options ...request.Option) error { m.mutex.Lock() defer m.mutex.Unlock() - klog.Infof("DescribeNatGateways: %v", request) + klog.Infof("DescribeNatGatewaysPagesWithContext: %v", request) var ngws []*ec2.NatGateway @@ -137,7 +168,7 @@ func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*e if strings.HasPrefix(*filter.Name, "tag:") { match = m.hasTag(ec2.ResourceTypeNatgateway, *ngw.NatGatewayId, filter) } else { - return nil, fmt.Errorf("unknown filter name: %q", *filter.Name) + return fmt.Errorf("unknown filter name: %q", *filter.Name) } } @@ -160,26 +191,15 @@ func (m *MockEC2) DescribeNatGateways(request *ec2.DescribeNatGatewaysInput) (*e NatGateways: ngws, } - return response, nil -} - -func (m *MockEC2) DescribeNatGatewaysWithContext(aws.Context, *ec2.DescribeNatGatewaysInput, ...request.Option) (*ec2.DescribeNatGatewaysOutput, error) { - panic("Not implemented") -} - -func (m *MockEC2) DescribeNatGatewaysRequest(*ec2.DescribeNatGatewaysInput) (*request.Request, *ec2.DescribeNatGatewaysOutput) { - panic("Not implemented") -} - -func (m *MockEC2) DescribeNatGatewaysPages(*ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool) error { - panic("Not implemented") + callback(response, false) + return nil } -func (m *MockEC2) DescribeNatGatewaysPagesWithContext(aws.Context, *ec2.DescribeNatGatewaysInput, func(*ec2.DescribeNatGatewaysOutput, bool) bool, ...request.Option) error { +func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) { panic("Not implemented") } -func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) { +func (m *MockEC2) DeleteNatGatewayWithContext(ctx aws.Context, request *ec2.DeleteNatGatewayInput, options ...request.Option) (*ec2.DeleteNatGatewayOutput, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -195,10 +215,6 @@ func (m *MockEC2) DeleteNatGateway(request *ec2.DeleteNatGatewayInput) (*ec2.Del return &ec2.DeleteNatGatewayOutput{}, nil } -func (m *MockEC2) DeleteNatGatewayWithContext(aws.Context, *ec2.DeleteNatGatewayInput, ...request.Option) (*ec2.DeleteNatGatewayOutput, error) { - panic("Not implemented") -} - func (m *MockEC2) DeleteNatGatewayRequest(*ec2.DeleteNatGatewayInput) (*request.Request, *ec2.DeleteNatGatewayOutput) { panic("Not implemented") } diff --git a/pkg/resources/aws/aws.go b/pkg/resources/aws/aws.go index 6c153f2e2401b..d30855a889018 100644 --- a/pkg/resources/aws/aws.go +++ b/pkg/resources/aws/aws.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "errors" "fmt" "strings" @@ -735,6 +736,8 @@ func DeleteSubnet(cloud fi.Cloud, tracker *resources.Resource) error { } func ListSubnets(cloud fi.Cloud, vpcID, clusterName string) ([]*resources.Resource, error) { + ctx := context.TODO() + c := cloud.(awsup.AWSCloud) subnets, err := DescribeSubnets(cloud) if err != nil { @@ -833,20 +836,20 @@ func ListSubnets(cloud fi.Cloud, vpcID, clusterName string) ([]*resources.Resour klog.V(2).Infof("Querying Nat Gateways") request := &ec2.DescribeNatGatewaysInput{} - response, err := c.EC2().DescribeNatGateways(request) - if err != nil { - return nil, fmt.Errorf("error describing NatGateways: %v", err) - } + if err := c.EC2().DescribeNatGatewaysPagesWithContext(ctx, request, func(page *ec2.DescribeNatGatewaysOutput, lastPage bool) bool { + for _, ngw := range page.NatGateways { + id := aws.StringValue(ngw.NatGatewayId) + if !natGatewayIds.Has(id) { + continue + } - for _, ngw := range response.NatGateways { - id := aws.StringValue(ngw.NatGatewayId) - if !natGatewayIds.Has(id) { - continue + forceShared := sharedNgwIds.Has(id) || !ownedNatGatewayIds.Has(id) + r := buildNatGatewayResource(ngw, forceShared, clusterName) + resourceTrackers = append(resourceTrackers, r) } - - forceShared := sharedNgwIds.Has(id) || !ownedNatGatewayIds.Has(id) - r := buildNatGatewayResource(ngw, forceShared, clusterName) - resourceTrackers = append(resourceTrackers, r) + return true + }); err != nil { + return nil, fmt.Errorf("listing NatGateways: %w", err) } } @@ -1296,6 +1299,8 @@ func FindAutoScalingLaunchTemplates(cloud fi.Cloud, clusterName string) ([]*reso } func FindNatGateways(cloud fi.Cloud, routeTables map[string]*resources.Resource, clusterName string) ([]*resources.Resource, error) { + ctx := context.TODO() + if len(routeTables) == 0 { return nil, nil } @@ -1347,7 +1352,7 @@ func FindNatGateways(cloud fi.Cloud, routeTables map[string]*resources.Resource, request := &ec2.DescribeNatGatewaysInput{ NatGatewayIds: []*string{aws.String(natGatewayId)}, } - response, err := c.EC2().DescribeNatGateways(request) + response, err := c.EC2().DescribeNatGatewaysWithContext(ctx, request) if err != nil { if awsup.AWSErrorCode(err) == "NatGatewayNotFound" { klog.V(2).Infof("Got NatGatewayNotFound describing NatGateway %s; will treat as already-deleted", natGatewayId) @@ -1768,6 +1773,8 @@ func DeleteElasticIP(cloud fi.Cloud, t *resources.Resource) error { } func DeleteNatGateway(cloud fi.Cloud, t *resources.Resource) error { + ctx := context.TODO() + c := cloud.(awsup.AWSCloud) id := t.ID @@ -1776,12 +1783,12 @@ func DeleteNatGateway(cloud fi.Cloud, t *resources.Resource) error { request := &ec2.DeleteNatGatewayInput{ NatGatewayId: &id, } - _, err := c.EC2().DeleteNatGateway(request) + _, err := c.EC2().DeleteNatGatewayWithContext(ctx, request) if err != nil { if IsDependencyViolation(err) { return err } - return fmt.Errorf("error deleting ngw %q: %v", t.Name, err) + return fmt.Errorf("error deleting nat gateway %q: %v", t.Name, err) } return nil } diff --git a/upup/pkg/fi/cloudup/awstasks/elastic_ip.go b/upup/pkg/fi/cloudup/awstasks/elastic_ip.go index cca8e57989bb9..e93b930a21dfe 100644 --- a/upup/pkg/fi/cloudup/awstasks/elastic_ip.go +++ b/upup/pkg/fi/cloudup/awstasks/elastic_ip.go @@ -17,6 +17,7 @@ limitations under the License. package awstasks import ( + "context" "fmt" "github.com/aws/aws-sdk-go/aws" @@ -59,18 +60,19 @@ func (e *ElasticIP) CompareWithID() *string { } // Find returns the actual ElasticIP state, or nil if not found -func (e *ElasticIP) Find(context *fi.CloudupContext) (*ElasticIP, error) { - return e.find(context.T.Cloud.(awsup.AWSCloud)) +func (e *ElasticIP) Find(c *fi.CloudupContext) (*ElasticIP, error) { + ctx := c.Context() + return e.find(ctx, c.T.Cloud.(awsup.AWSCloud)) } // find will attempt to look up the elastic IP from AWS -func (e *ElasticIP) find(cloud awsup.AWSCloud) (*ElasticIP, error) { +func (e *ElasticIP) find(ctx context.Context, cloud awsup.AWSCloud) (*ElasticIP, error) { publicIP := e.PublicIP allocationID := e.ID // Find via RouteTable -> NatGateway -> ElasticIP if allocationID == nil && publicIP == nil && e.AssociatedNatGatewayRouteTable != nil { - ngw, err := findNatGatewayFromRouteTable(cloud, e.AssociatedNatGatewayRouteTable) + ngw, err := findNatGatewayFromRouteTable(ctx, cloud, e.AssociatedNatGatewayRouteTable) if err != nil { return nil, fmt.Errorf("error finding AssociatedNatGatewayRouteTable: %v", err) } diff --git a/upup/pkg/fi/cloudup/awstasks/natgateway.go b/upup/pkg/fi/cloudup/awstasks/natgateway.go index 69fc68187f976..4dffe0b9f534d 100644 --- a/upup/pkg/fi/cloudup/awstasks/natgateway.go +++ b/upup/pkg/fi/cloudup/awstasks/natgateway.go @@ -17,10 +17,12 @@ limitations under the License. package awstasks import ( + "context" "fmt" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/klog/v2" raws "k8s.io/kops/pkg/resources/aws" @@ -59,32 +61,21 @@ func (e *NatGateway) CompareWithID() *string { } func (e *NatGateway) Find(c *fi.CloudupContext) (*NatGateway, error) { + ctx := c.Context() + cloud := c.T.Cloud.(awsup.AWSCloud) var ngw *ec2.NatGateway actual := &NatGateway{} if fi.ValueOf(e.ID) != "" { - // We have an existing NGW, lets look up the EIP - var ngwIds []*string - ngwIds = append(ngwIds, e.ID) - - request := &ec2.DescribeNatGatewaysInput{ - NatGatewayIds: ngwIds, - } - - response, err := cloud.EC2().DescribeNatGateways(request) + found, err := findNatGatewayByID(ctx, cloud, fi.ValueOf(e.ID)) if err != nil { - return nil, fmt.Errorf("error listing Nat Gateways %v", err) - } - - if len(response.NatGateways) != 1 { - return nil, fmt.Errorf("found %d Nat Gateways with ID %q, expected 1", len(response.NatGateways), fi.ValueOf(e.ID)) + return nil, err } - ngw = response.NatGateways[0] - - if len(ngw.NatGatewayAddresses) != 1 { + if len(found.NatGatewayAddresses) != 1 { return nil, fmt.Errorf("found %d EIP Addresses for 1 NATGateway, expected 1", len(ngw.NatGatewayAddresses)) } + ngw = found } else { // This is the normal/default path var err error @@ -127,13 +118,15 @@ func (e *NatGateway) Find(c *fi.CloudupContext) (*NatGateway, error) { } func (e *NatGateway) findNatGateway(c *fi.CloudupContext) (*ec2.NatGateway, error) { + ctx := c.Context() + cloud := c.T.Cloud.(awsup.AWSCloud) id := e.ID // Find via route on private route table if id == nil && e.AssociatedRouteTable != nil { - ngw, err := findNatGatewayFromRouteTable(cloud, e.AssociatedRouteTable) + ngw, err := findNatGatewayFromRouteTable(ctx, cloud, e.AssociatedRouteTable) if err != nil { return nil, err } @@ -175,31 +168,29 @@ func (e *NatGateway) findNatGateway(c *fi.CloudupContext) (*ec2.NatGateway, erro } if id != nil { - return findNatGatewayById(cloud, id) + return findNatGatewayByID(ctx, cloud, *id) } return nil, nil } -func findNatGatewayById(cloud awsup.AWSCloud, id *string) (*ec2.NatGateway, error) { - request := &ec2.DescribeNatGatewaysInput{} - request.NatGatewayIds = []*string{id} - response, err := cloud.EC2().DescribeNatGateways(request) - if err != nil { - return nil, fmt.Errorf("error listing NatGateway %q: %v", aws.StringValue(id), err) +func findNatGatewayByID(ctx context.Context, cloud awsup.AWSCloud, id string) (*ec2.NatGateway, error) { + request := &ec2.DescribeNatGatewaysInput{ + NatGatewayIds: []*string{&id}, } - if response == nil || len(response.NatGateways) == 0 { - klog.V(2).Infof("Unable to find NatGateway %q", aws.StringValue(id)) - return nil, nil + response, err := cloud.EC2().DescribeNatGatewaysWithContext(ctx, request) + if err != nil { + return nil, fmt.Errorf("finding NatGateway with id %q: %w", id, err) } + if len(response.NatGateways) != 1 { - return nil, fmt.Errorf("found multiple NatGateways with id %q", aws.StringValue(id)) + return nil, fmt.Errorf("found multiple NatGateways with ID %q (got %d, expected 1)", id, len(response.NatGateways)) } return response.NatGateways[0], nil } -func findNatGatewayFromRouteTable(cloud awsup.AWSCloud, routeTable *RouteTable) (*ec2.NatGateway, error) { +func findNatGatewayFromRouteTable(ctx context.Context, cloud awsup.AWSCloud, routeTable *RouteTable) (*ec2.NatGateway, error) { // Find via route on private route table if routeTable.ID != nil { klog.V(2).Infof("trying to match NatGateway via RouteTable %s", *routeTable.ID) @@ -209,15 +200,15 @@ func findNatGatewayFromRouteTable(cloud awsup.AWSCloud, routeTable *RouteTable) } if rt != nil { - var natGatewayIDs []*string - natGatewayIDsSeen := map[string]bool{} + natGatewayIDsSeen := sets.New[string]() for _, route := range rt.Routes { - if route.NatGatewayId != nil && !natGatewayIDsSeen[*route.NatGatewayId] { - natGatewayIDs = append(natGatewayIDs, route.NatGatewayId) - natGatewayIDsSeen[*route.NatGatewayId] = true + if route.NatGatewayId != nil { + natGatewayIDsSeen.Insert(*route.NatGatewayId) } } + natGatewayIDs := natGatewayIDsSeen.UnsortedList() + if len(natGatewayIDs) == 0 { klog.V(2).Infof("no NatGateway found in route table %s", *rt.RouteTableId) } else if len(natGatewayIDs) > 1 { @@ -227,12 +218,12 @@ func findNatGatewayFromRouteTable(cloud awsup.AWSCloud, routeTable *RouteTable) } filteredNatGateways := []*ec2.NatGateway{} for _, natGatewayID := range natGatewayIDs { - gw, err := findNatGatewayById(cloud, natGatewayID) + gw, err := findNatGatewayByID(ctx, cloud, natGatewayID) if err != nil { return nil, err } - if raws.HasOwnedTag(ec2.ResourceTypeNatgateway+":"+fi.ValueOf(natGatewayID), gw.Tags, clusterName) { + if raws.HasOwnedTag(ec2.ResourceTypeNatgateway+":"+natGatewayID, gw.Tags, clusterName) { filteredNatGateways = append(filteredNatGateways, gw) } } @@ -244,7 +235,7 @@ func findNatGatewayFromRouteTable(cloud awsup.AWSCloud, routeTable *RouteTable) return filteredNatGateways[0], nil } } else { - return findNatGatewayById(cloud, natGatewayIDs[0]) + return findNatGatewayByID(ctx, cloud, natGatewayIDs[0]) } } } @@ -296,11 +287,11 @@ func (e *NatGateway) Run(c *fi.CloudupContext) error { } func (_ *NatGateway) RenderAWS(t *awsup.AWSAPITarget, a, e, changes *NatGateway) error { - // New NGW + ctx := context.TODO() var id *string if a == nil { - + // New NGW if fi.ValueOf(e.Shared) { return fmt.Errorf("NAT gateway %q not found", fi.ValueOf(e.ID)) } @@ -312,7 +303,7 @@ func (_ *NatGateway) RenderAWS(t *awsup.AWSAPITarget, a, e, changes *NatGateway) } request.AllocationId = e.ElasticIP.ID request.SubnetId = e.Subnet.ID - response, err := t.Cloud.EC2().CreateNatGateway(request) + response, err := t.Cloud.EC2().CreateNatGatewayWithContext(ctx, request) if err != nil { return fmt.Errorf("Error creating Nat Gateway: %v", err) } @@ -322,6 +313,34 @@ func (_ *NatGateway) RenderAWS(t *awsup.AWSAPITarget, a, e, changes *NatGateway) id = a.ID } + // Ensure the nat gateway is ready + { + found, err := findNatGatewayByID(ctx, t.Cloud, *e.ID) + if err != nil { + return fmt.Errorf("reading created nat gateway: %w", err) + } + + switch state := aws.StringValue(found.State); state { + case "pending": + // The NAT gateway is being created and is not ready to process traffic. + return fi.NewTryAgainLaterError("waiting for the NAT gateway to be ready") + + case "failed": + // The NAT gateway could not be created. + message := fmt.Sprintf("the NAT gateway failed in AWS; AWS failureCode=%q; AWS failureMessage=%q", aws.StringValue(found.FailureCode), aws.StringValue(found.FailureMessage)) + return fi.NewTryAgainLaterError(message) + + case "available": + // The NAT gateway is able to process traffic. + break + + case "deleting", "deleted": + return fmt.Errorf("the NAT gateway is being deleted (state=%q)", state) + default: + return fmt.Errorf("the NAT gateway is in an unknown state %q", state) + } + } + err := t.AddAWSTags(*e.ID, e.Tags) if err != nil { return fmt.Errorf("unable to tag NatGateway")