Skip to content

Commit

Permalink
Fix Unit tests in AWS cloud-prepare
Browse files Browse the repository at this point in the history
Signed-off-by: Aswin Suryanarayanan <[email protected]>
  • Loading branch information
aswinsuryan committed Jun 24, 2024
1 parent 07d55e6 commit 9364c71
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 18 deletions.
6 changes: 5 additions & 1 deletion pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,14 @@ func (ac *awsCloud) setSuffixes(vpcID string) error {
}

publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil || len(publicSubnets) == 0 {
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}

if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
}

pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region))
re := regexp.MustCompile(pattern)

Expand Down
6 changes: 6 additions & 0 deletions pkg/aws/aws_cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func testOpenPorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)

retError = t.cloud.OpenPorts([]api.PortSpec{
{
Expand Down Expand Up @@ -87,6 +89,7 @@ func testOpenPorts() {
When("authorize security group ingress validation fails", func() {
BeforeEach(func() {
t.expectDescribeVpcs(vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectValidateAuthorizeSecurityGroupIngress(errors.New("mock error"))
})

Expand Down Expand Up @@ -114,6 +117,9 @@ func testClosePorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

retError = t.cloud.ClosePorts(reporter.Stdout())
})
Expand Down
69 changes: 52 additions & 17 deletions pkg/aws/aws_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ import (
)

const (
infraID = "test-infra"
region = "test-region"
vpcID = "test-vpc"
workerGroupID = "worker-group"
masterGroupID = "master-group"
gatewayGroupID = "gateway-group"
internalTraffic = "Internal Submariner traffic"
availabilityZone1 = "availability-zone-1"
availabilityZone2 = "availability-zone-2"
subnetID1 = "subnet-1"
subnetID2 = "subnet-2"
instanceImageID = "test-image"
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
infraID = "test-infra"
region = "test-region"
vpcID = "test-vpc"
workerGroupID = "worker-group"
masterGroupID = "master-group"
gatewayGroupID = "gateway-group"
internalTraffic = "Internal Submariner traffic"
availabilityZone1 = "availability-zone-1"
availabilityZone2 = "availability-zone-2"
subnetID1 = "subnet-1"
subnetID2 = "subnet-2"
instanceImageID = "test-image"
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID
)

var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic)
Expand All @@ -64,6 +65,7 @@ type fakeAWSClientBase struct {
awsClient *fake.MockInterface
mockCtrl *gomock.Controller
vpcID string
subnets []types.Subnet
describeSubnetsErr error
authorizeSecurityGroupIngressErr error
createTagsErr error
Expand All @@ -74,6 +76,7 @@ func (f *fakeAWSClientBase) beforeEach() {
f.mockCtrl = gomock.NewController(GinkgoT())
f.awsClient = fake.NewMockInterface(f.mockCtrl)
f.vpcID = vpcID
f.subnets = []types.Subnet{newSubnet(availabilityZone1, subnetID1), newSubnet(availabilityZone2, subnetID2)}
f.describeSubnetsErr = nil
f.authorizeSecurityGroupIngressErr = nil
f.createTagsErr = nil
Expand Down Expand Up @@ -113,6 +116,25 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) {
})).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes()
}

func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) {
var vpcs []types.Vpc
if vpcID != "" {
vpcs = []types.Vpc{
{
VpcId: awssdk.String(vpcID),
},
}
}

f.awsClient.EXPECT().DescribeVpcs(gomock.Any(), eqFilters(types.Filter{
Name: awssdk.String("tag:Name"),
Values: []string{infraID + "-vpc"},
}, types.Filter{
Name: awssdk.String(clusterFilterTagNameSigs),
Values: []string{"owned"},
})).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).AnyTimes()
}

func (f *fakeAWSClientBase) expectValidateAuthorizeSecurityGroupIngress(authErr error) *gomock.Call {
return f.awsClient.EXPECT().AuthorizeSecurityGroupIngress(gomock.Any(), mock.Eq(&ec2.AuthorizeSecurityGroupIngressInput{
DryRun: awssdk.Bool(true),
Expand Down Expand Up @@ -143,7 +165,7 @@ func (f *fakeAWSClientBase) expectValidateRevokeSecurityGroupIngress(retErr erro
func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{
Name: awssdk.String("tag:Name"),
Values: []string{infraID + "-public-" + region + "*"},
Values: []string{infraID + "*-public-" + region + "*"},
}, types.Filter{
Name: awssdk.String("vpc-id"),
Values: []string{f.vpcID},
Expand All @@ -153,6 +175,19 @@ func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subn
})).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).AnyTimes()
}

func (f *fakeAWSClientBase) expectDescribePublicSubnetsSigs(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{
Name: awssdk.String("tag:Name"),
Values: []string{infraID + "*-public-" + region + "*"},
}, types.Filter{
Name: awssdk.String("vpc-id"),
Values: []string{f.vpcID},
}, types.Filter{
Name: awssdk.String(clusterFilterTagNameSigs),
Values: []string{"owned"},
})).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).AnyTimes()
}

func (f *fakeAWSClientBase) expectDescribeGatewaySubnets(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(gomock.Any(), eqFilters(types.Filter{
Name: awssdk.String("tag:submariner.io/gateway"),
Expand Down
3 changes: 3 additions & 0 deletions pkg/aws/ocpgwdeployer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver {
t.expectDescribeSecurityGroups(gatewaySGName, t.gatewayGroupID)
t.expectDescribeInstances(instanceImageID)
t.expectDescribeSecurityGroups(workerSGName, workerGroupID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

var err error

Expand Down

0 comments on commit 9364c71

Please sign in to comment.