From 9364c71b81cbdc6179ea29271bd91ba6e9ac2f7d Mon Sep 17 00:00:00 2001 From: Aswin Suryanarayanan Date: Thu, 20 Jun 2024 17:53:58 -0400 Subject: [PATCH] Fix Unit tests in AWS cloud-prepare Signed-off-by: Aswin Suryanarayanan --- pkg/aws/aws.go | 6 ++- pkg/aws/aws_cloud_test.go | 6 +++ pkg/aws/aws_suite_test.go | 69 ++++++++++++++++++++++++++--------- pkg/aws/ocpgwdeployer_test.go | 3 ++ 4 files changed, 66 insertions(+), 18 deletions(-) diff --git a/pkg/aws/aws.go b/pkg/aws/aws.go index 68ed99d0..8bc5b2cb 100644 --- a/pkg/aws/aws.go +++ b/pkg/aws/aws.go @@ -99,10 +99,14 @@ func (ac *awsCloud) setSuffixes(vpcID string) error { } publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*")) - if err != nil || len(publicSubnets) == 0 { + if err != nil { return errors.Wrapf(err, "unable to find the public subnet") } + if len(publicSubnets) == 0 { + return errors.New("no public subnet found") + } + pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region)) re := regexp.MustCompile(pattern) diff --git a/pkg/aws/aws_cloud_test.go b/pkg/aws/aws_cloud_test.go index b05369ff..e5c82a97 100644 --- a/pkg/aws/aws_cloud_test.go +++ b/pkg/aws/aws_cloud_test.go @@ -40,6 +40,8 @@ func testOpenPorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnets(t.subnets...) retError = t.cloud.OpenPorts([]api.PortSpec{ { @@ -87,6 +89,7 @@ func testOpenPorts() { When("authorize security group ingress validation fails", func() { BeforeEach(func() { t.expectDescribeVpcs(vpcID) + t.expectDescribePublicSubnets(t.subnets...) t.expectValidateAuthorizeSecurityGroupIngress(errors.New("mock error")) }) @@ -114,6 +117,9 @@ func testClosePorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) + t.expectDescribePublicSubnets(t.subnets...) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnetsSigs(t.subnets...) retError = t.cloud.ClosePorts(reporter.Stdout()) }) diff --git a/pkg/aws/aws_suite_test.go b/pkg/aws/aws_suite_test.go index 806257e9..08767c25 100644 --- a/pkg/aws/aws_suite_test.go +++ b/pkg/aws/aws_suite_test.go @@ -35,22 +35,23 @@ import ( ) const ( - infraID = "test-infra" - region = "test-region" - vpcID = "test-vpc" - workerGroupID = "worker-group" - masterGroupID = "master-group" - gatewayGroupID = "gateway-group" - internalTraffic = "Internal Submariner traffic" - availabilityZone1 = "availability-zone-1" - availabilityZone2 = "availability-zone-2" - subnetID1 = "subnet-1" - subnetID2 = "subnet-2" - instanceImageID = "test-image" - masterSGName = infraID + "-master-sg" - workerSGName = infraID + "-worker-sg" - gatewaySGName = infraID + "-submariner-gw-sg" - clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID + infraID = "test-infra" + region = "test-region" + vpcID = "test-vpc" + workerGroupID = "worker-group" + masterGroupID = "master-group" + gatewayGroupID = "gateway-group" + internalTraffic = "Internal Submariner traffic" + availabilityZone1 = "availability-zone-1" + availabilityZone2 = "availability-zone-2" + subnetID1 = "subnet-1" + subnetID2 = "subnet-2" + instanceImageID = "test-image" + masterSGName = infraID + "-master-sg" + workerSGName = infraID + "-worker-sg" + gatewaySGName = infraID + "-submariner-gw-sg" + clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID + clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID ) var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic) @@ -64,6 +65,7 @@ type fakeAWSClientBase struct { awsClient *fake.MockInterface mockCtrl *gomock.Controller vpcID string + subnets []types.Subnet describeSubnetsErr error authorizeSecurityGroupIngressErr error createTagsErr error @@ -74,6 +76,7 @@ func (f *fakeAWSClientBase) beforeEach() { f.mockCtrl = gomock.NewController(GinkgoT()) f.awsClient = fake.NewMockInterface(f.mockCtrl) f.vpcID = vpcID + f.subnets = []types.Subnet{newSubnet(availabilityZone1, subnetID1), newSubnet(availabilityZone2, subnetID2)} f.describeSubnetsErr = nil f.authorizeSecurityGroupIngressErr = nil f.createTagsErr = nil @@ -113,6 +116,25 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) { })).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes() } +func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) { + var vpcs []types.Vpc + if vpcID != "" { + vpcs = []types.Vpc{ + { + VpcId: awssdk.String(vpcID), + }, + } + } + + f.awsClient.EXPECT().DescribeVpcs(gomock.Any(), eqFilters(types.Filter{ + Name: awssdk.String("tag:Name"), + Values: []string{infraID + "-vpc"}, + }, types.Filter{ + Name: awssdk.String(clusterFilterTagNameSigs), + Values: []string{"owned"}, + })).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes() +} + func (f *fakeAWSClientBase) expectValidateAuthorizeSecurityGroupIngress(authErr error) *gomock.Call { return f.awsClient.EXPECT().AuthorizeSecurityGroupIngress(gomock.Any(), mock.Eq(&ec2.AuthorizeSecurityGroupIngressInput{ DryRun: awssdk.Bool(true), @@ -143,7 +165,7 @@ func (f *fakeAWSClientBase) expectValidateRevokeSecurityGroupIngress(retErr erro func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subnet) { f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{ Name: awssdk.String("tag:Name"), - Values: []string{infraID + "-public-" + region + "*"}, + Values: []string{infraID + "*-public-" + region + "*"}, }, types.Filter{ Name: awssdk.String("vpc-id"), Values: []string{f.vpcID}, @@ -153,6 +175,19 @@ func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subn })).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).AnyTimes() } +func (f *fakeAWSClientBase) expectDescribePublicSubnetsSigs(retSubnets ...types.Subnet) { + f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{ + Name: awssdk.String("tag:Name"), + Values: []string{infraID + "*-public-" + region + "*"}, + }, types.Filter{ + Name: awssdk.String("vpc-id"), + Values: []string{f.vpcID}, + }, types.Filter{ + Name: awssdk.String(clusterFilterTagNameSigs), + Values: []string{"owned"}, + })).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).AnyTimes() +} + func (f *fakeAWSClientBase) expectDescribeGatewaySubnets(retSubnets ...types.Subnet) { f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{ Name: awssdk.String("tag:submariner.io/gateway"), diff --git a/pkg/aws/ocpgwdeployer_test.go b/pkg/aws/ocpgwdeployer_test.go index 833f7338..edac284e 100644 --- a/pkg/aws/ocpgwdeployer_test.go +++ b/pkg/aws/ocpgwdeployer_test.go @@ -283,6 +283,9 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver { t.expectDescribeSecurityGroups(gatewaySGName, t.gatewayGroupID) t.expectDescribeInstances(instanceImageID) t.expectDescribeSecurityGroups(workerSGName, workerGroupID) + t.expectDescribePublicSubnets(t.subnets...) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnetsSigs(t.subnets...) var err error