diff --git a/api/v1beta2/conditions_consts.go b/api/v1beta2/conditions_consts.go index bfbb96c77a..604ef8e1d5 100644 --- a/api/v1beta2/conditions_consts.go +++ b/api/v1beta2/conditions_consts.go @@ -69,6 +69,14 @@ const ( EgressOnlyInternetGatewayFailedReason = "EgressOnlyInternetGatewayFailed" ) +const ( + // CarrierGatewayReadyCondition reports on the successful reconciliation of carrier gateways. + // Only applicable to managed clusters. + CarrierGatewayReadyCondition clusterv1.ConditionType = "CarrierGatewayReady" + // CarrierGatewayFailedReason used when errors occur during carrier gateway reconciliation. + CarrierGatewayFailedReason = "CarrierGatewayFailed" +) + const ( // NatGatewaysReadyCondition reports successful reconciliation of NAT gateways. // Only applicable to managed clusters. diff --git a/pkg/cloud/awserrors/errors.go b/pkg/cloud/awserrors/errors.go index 5312e4fe42..d51b41595c 100644 --- a/pkg/cloud/awserrors/errors.go +++ b/pkg/cloud/awserrors/errors.go @@ -33,6 +33,7 @@ const ( GatewayNotFound = "InvalidGatewayID.NotFound" GroupNotFound = "InvalidGroup.NotFound" InternetGatewayNotFound = "InvalidInternetGatewayID.NotFound" + InvalidCarrierGatewayNotFound = "InvalidCarrierGatewayID.NotFound" EgressOnlyInternetGatewayNotFound = "InvalidEgressOnlyInternetGatewayID.NotFound" InUseIPAddress = "InvalidIPAddress.InUse" InvalidAccessKeyID = "InvalidAccessKeyId" diff --git a/pkg/cloud/services/network/carriergateways.go b/pkg/cloud/services/network/carriergateways.go new file mode 100644 index 0000000000..6237df9052 --- /dev/null +++ b/pkg/cloud/services/network/carriergateways.go @@ -0,0 +1,145 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package network + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/pkg/errors" + + infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/awserrors" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/converters" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/filter" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/wait" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/tags" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/record" + "sigs.k8s.io/cluster-api/util/conditions" +) + +func (s *Service) reconcileCarrierGateway() error { + if s.scope.VPC().IsUnmanaged(s.scope.Name()) { + s.scope.Trace("Skipping carrier gateway reconcile in unmanaged mode") + return nil + } + + if !s.scope.Subnets().HasPublicSubnetWavelength() { + s.scope.Trace("Skipping carrier gateway reconcile in VPC without subnets in zone type wavelength-zone") + return nil + } + + s.scope.Debug("Reconciling carrier gateway") + + cagw, err := s.describeVpcCarrierGateway() + if awserrors.IsNotFound(err) { + if s.scope.VPC().IsUnmanaged(s.scope.Name()) { + return errors.Errorf("failed to validate network: no carrier gateway found in VPC %q", s.scope.VPC().ID) + } + + cg, err := s.createCarrierGateway() + if err != nil { + return err + } + cagw = cg + } else if err != nil { + return err + } + + s.scope.VPC().CarrierGatewayID = cagw.CarrierGatewayId + + // Make sure tags are up-to-date. + if err := wait.WaitForWithRetryable(wait.NewBackoff(), func() (bool, error) { + buildParams := s.getGatewayTagParams(*cagw.CarrierGatewayId) + tagsBuilder := tags.New(&buildParams, tags.WithEC2(s.EC2Client)) + if err := tagsBuilder.Ensure(converters.TagsToMap(cagw.Tags)); err != nil { + return false, err + } + return true, nil + }, awserrors.InvalidCarrierGatewayNotFound); err != nil { + record.Warnf(s.scope.InfraCluster(), "FailedTagCarrierGateway", "Failed to tag managed Carrier Gateway %q: %v", cagw.CarrierGatewayId, err) + return errors.Wrapf(err, "failed to tag carrier gateway %q", *cagw.CarrierGatewayId) + } + conditions.MarkTrue(s.scope.InfraCluster(), infrav1.CarrierGatewayReadyCondition) + return nil +} + +func (s *Service) deleteCarrierGateway() error { + if s.scope.VPC().IsUnmanaged(s.scope.Name()) { + s.scope.Trace("Skipping carrier gateway deletion in unmanaged mode") + return nil + } + + cagw, err := s.describeVpcCarrierGateway() + if awserrors.IsNotFound(err) { + return nil + } else if err != nil { + return err + } + + deleteReq := &ec2.DeleteCarrierGatewayInput{ + CarrierGatewayId: cagw.CarrierGatewayId, + } + + if _, err = s.EC2Client.DeleteCarrierGatewayWithContext(context.TODO(), deleteReq); err != nil { + record.Warnf(s.scope.InfraCluster(), "FailedDeleteCarrierGateway", "Failed to delete Carrier Gateway %q previously attached to VPC %q: %v", *cagw.CarrierGatewayId, s.scope.VPC().ID, err) + return errors.Wrapf(err, "failed to delete carrier gateway %q", *cagw.CarrierGatewayId) + } + + record.Eventf(s.scope.InfraCluster(), "SuccessfulDeleteCarrierGateway", "Deleted Carrier Gateway %q previously attached to VPC %q", *cagw.CarrierGatewayId, s.scope.VPC().ID) + s.scope.Info("Deleted Carrier Gateway in VPC", "carrier-gateway-id", *cagw.CarrierGatewayId, "vpc-id", s.scope.VPC().ID) + + return nil +} + +func (s *Service) createCarrierGateway() (*ec2.CarrierGateway, error) { + ig, err := s.EC2Client.CreateCarrierGatewayWithContext(context.TODO(), &ec2.CreateCarrierGatewayInput{ + VpcId: aws.String(s.scope.VPC().ID), + TagSpecifications: []*ec2.TagSpecification{ + tags.BuildParamsToTagSpecification(ec2.ResourceTypeCarrierGateway, s.getGatewayTagParams(services.TemporaryResourceID)), + }, + }) + if err != nil { + record.Warnf(s.scope.InfraCluster(), "FailedCreateCarrierGateway", "Failed to create new managed Internet Gateway: %v", err) + return nil, errors.Wrap(err, "failed to create carrier gateway") + } + record.Eventf(s.scope.InfraCluster(), "SuccessfulCreateCarrierGateway", "Created new managed Internet Gateway %q", *ig.CarrierGateway.CarrierGatewayId) + s.scope.Info("Created Internet gateway for VPC", "internet-gateway-id", *ig.CarrierGateway.CarrierGatewayId, "vpc-id", s.scope.VPC().ID) + + return ig.CarrierGateway, nil +} + +func (s *Service) describeVpcCarrierGateway() (*ec2.CarrierGateway, error) { + out, err := s.EC2Client.DescribeCarrierGatewaysWithContext(context.TODO(), &ec2.DescribeCarrierGatewaysInput{ + Filters: []*ec2.Filter{ + filter.EC2.VPC(s.scope.VPC().ID), + }, + }) + if err != nil { + record.Eventf(s.scope.InfraCluster(), "FailedDescribeCarrierGateway", "Failed to describe carrier gateways in vpc %q: %v", s.scope.VPC().ID, err) + return nil, errors.Wrapf(err, "failed to describe carrier gateways in vpc %q", s.scope.VPC().ID) + } + + if len(out.CarrierGateways) == 0 { + return nil, awserrors.NewNotFound(fmt.Sprintf("no carrier gateways found in vpc %q", s.scope.VPC().ID)) + } + + return out.CarrierGateways[0], nil +} diff --git a/pkg/cloud/services/network/carriergateways_test.go b/pkg/cloud/services/network/carriergateways_test.go new file mode 100644 index 0000000000..6608375c72 --- /dev/null +++ b/pkg/cloud/services/network/carriergateways_test.go @@ -0,0 +1,257 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package network + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/golang/mock/gomock" + . "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" + clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" +) + +func TestReconcileCarrierGateway(t *testing.T) { + testCases := []struct { + name string + input *infrav1.NetworkSpec + expect func(m *mocks.MockEC2APIMockRecorder) + }{ + { + name: "has cagw", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-cagw", + Tags: infrav1.Tags{ + infrav1.ClusterTagKey("test-cluster"): "owned", + }, + }, + }, + expect: func(m *mocks.MockEC2APIMockRecorder) { + m.DescribeCarrierGatewaysWithContext(context.TODO(), gomock.Eq(&ec2.DescribeCarrierGatewaysInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: aws.StringSlice([]string{"vpc-cagw"}), + }, + }, + })). + Return(&ec2.DescribeCarrierGatewaysOutput{ + CarrierGateways: []*ec2.CarrierGateway{ + { + CarrierGatewayId: ptr.To("cagw-01"), + }, + }, + }, nil).AnyTimes() + + m.CreateTagsWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.CreateTagsInput{})). + Return(nil, nil).AnyTimes() + }, + }, + { + name: "no cagw attached, creates one", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-cagw", + Tags: infrav1.Tags{ + infrav1.ClusterTagKey("test-cluster"): "owned", + }, + }, + }, + expect: func(m *mocks.MockEC2APIMockRecorder) { + m.DescribeCarrierGatewaysWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.DescribeCarrierGatewaysInput{})). + Return(&ec2.DescribeCarrierGatewaysOutput{}, nil).AnyTimes() + + m.CreateCarrierGatewayWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.CreateCarrierGatewayInput{})). + Return(&ec2.CreateCarrierGatewayOutput{ + CarrierGateway: &ec2.CarrierGateway{ + CarrierGatewayId: aws.String("cagw-1"), + VpcId: aws.String("vpc-cagw"), + Tags: []*ec2.Tag{ + { + Key: aws.String(infrav1.ClusterTagKey("test-cluster")), + Value: aws.String("owned"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/role"), + Value: aws.String("common"), + }, + { + Key: aws.String("Name"), + Value: aws.String("test-cluster-cagw"), + }, + }, + }, + }, nil).AnyTimes() + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + ec2Mock := mocks.NewMockEC2API(mockCtrl) + + scheme := runtime.NewScheme() + _ = infrav1.AddToScheme(scheme) + client := fake.NewClientBuilder().WithScheme(scheme).Build() + scope, err := scope.NewClusterScope(scope.ClusterScopeParams{ + Client: client, + Cluster: &clusterv1.Cluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test-cluster"}, + }, + AWSCluster: &infrav1.AWSCluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + Spec: infrav1.AWSClusterSpec{ + NetworkSpec: *tc.input, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create test context: %v", err) + } + + tc.expect(ec2Mock.EXPECT()) + + s := NewService(scope) + s.EC2Client = ec2Mock + + if err := s.reconcileCarrierGateway(); err != nil { + t.Fatalf("got an unexpected error: %v", err) + } + mockCtrl.Finish() + }) + } +} + +func TestDeleteCarrierGateway(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + testCases := []struct { + name string + input *infrav1.NetworkSpec + expect func(m *mocks.MockEC2APIMockRecorder) + wantErr bool + }{ + { + name: "Should ignore deletion if vpc is unmanaged", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-cagw", + }, + }, + expect: func(m *mocks.MockEC2APIMockRecorder) {}, + }, + { + name: "Should ignore deletion if carrier gateway is not found", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-cagw", + Tags: infrav1.Tags{ + infrav1.ClusterTagKey("test-cluster"): "owned", + }, + }, + }, + expect: func(m *mocks.MockEC2APIMockRecorder) { + m.DescribeCarrierGatewaysWithContext(context.TODO(), gomock.Eq(&ec2.DescribeCarrierGatewaysInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: aws.StringSlice([]string{"vpc-cagw"}), + }, + }, + })).Return(&ec2.DescribeCarrierGatewaysOutput{}, nil) + }, + }, + { + name: "Should successfully delete the carrier gateway", + input: &infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: "vpc-cagw", + Tags: infrav1.Tags{ + infrav1.ClusterTagKey("test-cluster"): "owned", + }, + }, + }, + expect: func(m *mocks.MockEC2APIMockRecorder) { + m.DescribeCarrierGatewaysWithContext(context.TODO(), gomock.AssignableToTypeOf(&ec2.DescribeCarrierGatewaysInput{})). + Return(&ec2.DescribeCarrierGatewaysOutput{ + CarrierGateways: []*ec2.CarrierGateway{ + { + CarrierGatewayId: aws.String("cagw-0"), + VpcId: aws.String("vpc-gateways"), + }, + }, + }, nil) + + m.DeleteCarrierGatewayWithContext(context.TODO(), &ec2.DeleteCarrierGatewayInput{ + CarrierGatewayId: aws.String("cagw-0"), + }).Return(&ec2.DeleteCarrierGatewayOutput{}, nil) + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + ec2Mock := mocks.NewMockEC2API(mockCtrl) + + scheme := runtime.NewScheme() + err := infrav1.AddToScheme(scheme) + g.Expect(err).NotTo(HaveOccurred()) + client := fake.NewClientBuilder().WithScheme(scheme).Build() + + scope, err := scope.NewClusterScope(scope.ClusterScopeParams{ + Client: client, + Cluster: &clusterv1.Cluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test-cluster"}, + }, + AWSCluster: &infrav1.AWSCluster{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + Spec: infrav1.AWSClusterSpec{ + NetworkSpec: *tc.input, + }, + }, + }) + g.Expect(err).NotTo(HaveOccurred()) + + tc.expect(ec2Mock.EXPECT()) + + s := NewService(scope) + s.EC2Client = ec2Mock + + err = s.deleteCarrierGateway() + if tc.wantErr { + g.Expect(err).To(HaveOccurred()) + return + } + g.Expect(err).NotTo(HaveOccurred()) + }) + } +} diff --git a/pkg/cloud/services/network/natgateways_test.go b/pkg/cloud/services/network/natgateways_test.go index 8036424131..29dc45ec13 100644 --- a/pkg/cloud/services/network/natgateways_test.go +++ b/pkg/cloud/services/network/natgateways_test.go @@ -873,6 +873,17 @@ func TestGetdNatGatewayForEdgeSubnet(t *testing.T) { }, expect: "natgw-az-1b-second", }, + { + name: "wavelength zones without Nat GW support, public subnet and Nat Gateway for the parent zone, return parent's zone nat gateway", + input: infrav1.SubnetSpec{ + ID: "subnet-7", + CidrBlock: "10.0.10.0/24", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + ZoneType: ptr.To(infrav1.ZoneTypeWavelengthZone), + ParentZoneName: aws.String("us-east-1x"), + }, + expect: "natgw-az-1b-last", + }, // errors { name: "error if the subnet is public", diff --git a/pkg/cloud/services/network/network.go b/pkg/cloud/services/network/network.go index b2363b5aac..e97024fad7 100644 --- a/pkg/cloud/services/network/network.go +++ b/pkg/cloud/services/network/network.go @@ -55,6 +55,12 @@ func (s *Service) ReconcileNetwork() (err error) { return err } + // Carrier Gateway. + if err := s.reconcileCarrierGateway(); err != nil { + conditions.MarkFalse(s.scope.InfraCluster(), infrav1.CarrierGatewayReadyCondition, infrav1.CarrierGatewayFailedReason, infrautilconditions.ErrorConditionAfterInit(s.scope.ClusterObj()), err.Error()) + return err + } + // Egress Only Internet Gateways. if err := s.reconcileEgressOnlyInternetGateways(); err != nil { conditions.MarkFalse(s.scope.InfraCluster(), infrav1.EgressOnlyInternetGatewayReadyCondition, infrav1.EgressOnlyInternetGatewayFailedReason, infrautilconditions.ErrorConditionAfterInit(s.scope.ClusterObj()), err.Error()) @@ -158,6 +164,15 @@ func (s *Service) DeleteNetwork() (err error) { } conditions.MarkFalse(s.scope.InfraCluster(), infrav1.InternetGatewayReadyCondition, clusterv1.DeletedReason, clusterv1.ConditionSeverityInfo, "") + // Carrier Gateway. + if s.scope.VPC().CarrierGatewayID != nil { + if err := s.deleteCarrierGateway(); err != nil { + conditions.MarkFalse(s.scope.InfraCluster(), infrav1.CarrierGatewayReadyCondition, "DeletingFailed", clusterv1.ConditionSeverityWarning, err.Error()) + return err + } + conditions.MarkFalse(s.scope.InfraCluster(), infrav1.CarrierGatewayReadyCondition, clusterv1.DeletedReason, clusterv1.ConditionSeverityInfo, "") + } + // Egress Only Internet Gateways. conditions.MarkFalse(s.scope.InfraCluster(), infrav1.EgressOnlyInternetGatewayReadyCondition, clusterv1.DeletingReason, clusterv1.ConditionSeverityInfo, "") if err := s.scope.PatchObject(); err != nil { diff --git a/pkg/cloud/services/network/routetables.go b/pkg/cloud/services/network/routetables.go index 0c096315b9..66694b2dd3 100644 --- a/pkg/cloud/services/network/routetables.go +++ b/pkg/cloud/services/network/routetables.go @@ -340,6 +340,13 @@ func (s *Service) getGatewayPublicIPv6Route() *ec2.CreateRouteInput { } } +func (s *Service) getCarrierGatewayPublicIPv4Route() *ec2.CreateRouteInput { + return &ec2.CreateRouteInput{ + DestinationCidrBlock: aws.String(services.AnyIPv4CidrBlock), + CarrierGatewayId: aws.String(*s.scope.VPC().CarrierGatewayID), + } +} + func (s *Service) getRouteTableTagParams(id string, public bool, zone string) infrav1.BuildParams { var name strings.Builder @@ -373,6 +380,14 @@ func (s *Service) getRoutesToPublicSubnet(sn *infrav1.SubnetSpec) ([]*ec2.Create return nil, errors.Errorf("can't determine routes for unsupported ipv6 subnet in zone type %q", sn.ZoneType) } + if sn.IsEdgeWavelength() { + if s.scope.VPC().CarrierGatewayID == nil { + return routes, errors.Errorf("failed to create carrier routing table: carrier gateway for VPC %q is not present", s.scope.VPC().ID) + } + routes = append(routes, s.getCarrierGatewayPublicIPv4Route()) + return routes, nil + } + if s.scope.VPC().InternetGatewayID == nil { return routes, errors.Errorf("failed to create routing tables: internet gateway for VPC %q is not present", s.scope.VPC().ID) } diff --git a/pkg/cloud/services/network/routetables_test.go b/pkg/cloud/services/network/routetables_test.go index 05b13222c3..6b6003a2d7 100644 --- a/pkg/cloud/services/network/routetables_test.go +++ b/pkg/cloud/services/network/routetables_test.go @@ -931,6 +931,34 @@ func TestService_getRoutesForSubnet(t *testing.T) { ZoneType: ptr.To(infrav1.ZoneType("local-zone")), ParentZoneName: ptr.To("us-east-1a"), }, + { + ResourceID: "subnet-wl-invalid2z-private", + AvailabilityZone: "us-east-2-wl1-inv-wlz-1", + IsPublic: false, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: ptr.To("us-east-2z"), + }, + { + ResourceID: "subnet-wl-invalid2z-public", + AvailabilityZone: "us-east-2-wl1-inv-wlz-1", + IsPublic: true, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: ptr.To("us-east-2z"), + }, + { + ResourceID: "subnet-wl-1a-private", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + IsPublic: false, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: ptr.To("us-east-1a"), + }, + { + ResourceID: "subnet-wl-1a-public", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + IsPublic: true, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: ptr.To("us-east-1a"), + }, } vpcName := "vpc-test-for-routes" @@ -938,6 +966,7 @@ func TestService_getRoutesForSubnet(t *testing.T) { VPC: infrav1.VPCSpec{ ID: vpcName, InternetGatewayID: aws.String("vpc-igw"), + CarrierGatewayID: aws.String("vpc-cagw"), IPv6: &infrav1.IPv6{ CidrBlock: "2001:db8:1234:1::/64", EgressOnlyInternetGatewayID: aws.String("vpc-eigw"), @@ -1020,6 +1049,21 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, }, }, + { + name: "public ipv4 subnet, wavelength zone, must have ipv4 default route to carrier gateway", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-public", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + IsPublic: true, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + CarrierGatewayId: aws.String("vpc-cagw"), + }, + }, + }, // public subnet ipv4, GW not found. { name: "public ipv4 subnet, availability zone, must return error when no internet gateway available", @@ -1051,6 +1095,22 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, wantErrMessage: `failed to create routing tables: internet gateway for VPC "vpc-test-for-routes" is not present`, }, + { + name: "public ipv4 subnet, wavelength zone, must return error when no Carrier Gateway found", + specOverrideNet: func() *infrav1.NetworkSpec { + net := defaultNetwork.DeepCopy() + net.VPC.CarrierGatewayID = nil + return net + }(), + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-public", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + IsPublic: true, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: aws.String("us-east-1a"), + }, + wantErrMessage: `failed to create carrier routing table: carrier gateway for VPC "vpc-test-for-routes" is not present`, + }, // public subnet ipv6, unsupported { name: "public ipv6 subnet, local zone, must return error for unsupported ip version", @@ -1064,6 +1124,19 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "local-zone"`, }, + { + name: "public ipv6 subnet, wavelength zone, must return error for unsupported ip version", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-public", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + IsPublic: true, + IsIPv6: true, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: aws.String("us-east-1a"), + }, + wantErr: true, + wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "wavelength-zone"`, + }, // private subnets { name: "private ipv4 subnet, availability zone, must have ipv4 default route to nat gateway", @@ -1095,6 +1168,22 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, }, }, + { + name: "private ipv4 subnet, wavelength zone, must have ipv4 default route to nat gateway", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-private", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: aws.String("us-east-1a"), + IsPublic: false, + }, + want: []*ec2.CreateRouteInput{ + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + NatGatewayId: aws.String("nat-gw-fromZone-us-east-1a"), + }, + }, + }, // egress-only subnet ipv6 { name: "egress-only ipv6 subnet, availability zone, must have ipv6 default route to egress-only gateway", @@ -1143,6 +1232,18 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "local-zone"`, }, + { + name: "private ipv6 subnet, wavelength zone, must return unsupported", + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-private", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: aws.String("us-east-1a"), + IsIPv6: true, + IsPublic: false, + }, + wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "wavelength-zone"`, + }, // private subnet, gateway not found { name: "private ipv4 subnet, availability zone, must return error when invalid gateway", @@ -1183,6 +1284,24 @@ func TestService_getRoutesForSubnet(t *testing.T) { }, wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "local-zone"`, }, + { + name: "private ipv4 subnet, wavelength zone, must return error when invalid gateway", + specOverrideNet: func() *infrav1.NetworkSpec { + net := new(infrav1.NetworkSpec) + *net = defaultNetwork + net.VPC.CarrierGatewayID = nil + return net + }(), + inputSubnet: &infrav1.SubnetSpec{ + ResourceID: "subnet-wl-1a-private", + AvailabilityZone: "us-east-1-wl1-nyc-wlz-1", + IsIPv6: true, + IsPublic: false, + ZoneType: ptr.To(infrav1.ZoneType("wavelength-zone")), + ParentZoneName: aws.String("us-east-1a"), + }, + wantErrMessage: `can't determine routes for unsupported ipv6 subnet in zone type "wavelength-zone"`, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) {