Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom vpc support in AWS cloud prepare #1007

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 94 additions & 24 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/pkg/errors"
"github.com/submariner-io/admiral/pkg/reporter"
"github.com/submariner-io/cloud-prepare/pkg/api"
Expand All @@ -40,36 +41,84 @@ const (
messageValidatedPrerequisites = "Validated pre-requisites"
)

type CloudOption func(*awsCloud)

const (
ControlPlaneSecurityGroupIDKey = "controlPlaneSecurityGroupID"
WorkerSecurityGroupIDKey = "workerSecurityGroupID"
PublicSubnetListKey = "PublicSubnetList"
VPCIDKey = "VPCID"
)

func WithControlPlaneSecurityGroup(id string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[ControlPlaneSecurityGroupIDKey] = id
}
}

func WithWorkerSecurityGroup(id string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[WorkerSecurityGroupIDKey] = id
}
}

func WithPublicSubnetList(id []string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[PublicSubnetListKey] = id
}
}

func WithVPCName(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.cloudConfig[VPCIDKey] = name
}
}

type awsCloud struct {
client awsClient.Interface
infraID string
region string
nodeSGSuffix string
controlPlaneSGSuffix string
cloudConfig map[string]interface{}
}

// NewCloud creates a new api.Cloud instance which can prepare AWS for Submariner to be deployed on it.
func NewCloud(client awsClient.Interface, infraID, region string) api.Cloud {
return &awsCloud{
client: client,
infraID: infraID,
region: region,
func NewCloud(client awsClient.Interface, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: client,
infraID: infraID,
region: region,
cloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromConfig creates a new api.Cloud instance based on an AWS configuration
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromConfig(cfg *aws.Config, infraID, region string) api.Cloud {
return &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
func NewCloudFromConfig(cfg *aws.Config, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
cloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromSettings creates a new api.Cloud instance using the given credentials file and profile
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api.Cloud, error) {
func NewCloudFromSettings(credentialsFile, profile, infraID, region string, opts ...CloudOption) (api.Cloud, error) {
options := []func(*config.LoadOptions) error{config.WithRegion(region), config.WithSharedConfigProfile(profile)}
if credentialsFile != DefaultCredentialsFile() {
options = append(options, config.WithSharedCredentialsFiles([]string{credentialsFile}))
Expand All @@ -80,7 +129,7 @@ func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api
return nil, errors.Wrap(err, "error loading default config")
}

return NewCloudFromConfig(&cfg, infraID, region), nil
return NewCloudFromConfig(&cfg, infraID, region, opts...), nil
}

// DefaultCredentialsFile returns the default credentials file name.
Expand All @@ -98,13 +147,30 @@ func (ac *awsCloud) setSuffixes(vpcID string) error {
return nil
}

publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}
var publicSubnets []types.Subnet

if subnets, exists := ac.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := ac.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}

if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
}
}

pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region))
Expand Down Expand Up @@ -137,9 +203,11 @@ func (ac *awsCloud) OpenPorts(ports []api.PortSpec, status reporter.Interface) e
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.cloudConfig[VPCIDKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down Expand Up @@ -180,9 +248,11 @@ func (ac *awsCloud) ClosePorts(status reporter.Interface) error {
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.cloudConfig[VPCIDKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down
2 changes: 0 additions & 2 deletions pkg/aws/aws_cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func testOpenPorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)

retError = t.cloud.OpenPorts([]api.PortSpec{
Expand Down Expand Up @@ -118,7 +117,6 @@ func testClosePorts() {
JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

retError = t.cloud.ClosePorts(reporter.Stdout())
Expand Down
21 changes: 3 additions & 18 deletions pkg/aws/aws_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ const (
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
providerAWSTagPrefix = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID
clusterFilterTagNameSigs = providerAWSTagPrefix + infraID
)

var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic)
Expand Down Expand Up @@ -110,24 +111,8 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) {
}, {
Name: ptr.To(clusterFilterTagName),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}

func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) {
var vpcs []types.Vpc
if vpcID != "" {
vpcs = []types.Vpc{
{
VpcId: ptr.To(vpcID),
},
}
}

f.awsClient.EXPECT().DescribeVpcs(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:Name"),
Values: []string{infraID + "-vpc"},
}, {
Name: ptr.To(clusterFilterTagNameSigs),
Name: ptr.To(providerAWSTagPrefix + infraID),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}
Expand Down
74 changes: 59 additions & 15 deletions pkg/aws/ocpgwdeployer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,35 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.cloudConfig[VPCIDKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)

publicSubnets, err := d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
var publicSubnets []types.Subnet

if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := d.aws.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err = d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code block looks the same as the block starting on L152 above except the latter checks for len(publicSubnets). I think we can refactor to a function and let the caller check len(publicSubnets).

}

err = d.validateDeployPrerequisites(vpcID, input, publicSubnets)
Expand All @@ -97,9 +116,15 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte

status.Success("Created Submariner gateway security group %s", gatewaySG)

return d.processSubnets(vpcID, gatewaySG, publicSubnets, input, status)
}

func (d *ocpGatewayDeployer) processSubnets(vpcID, gatewaySG string, publicSubnets []types.Subnet,
input api.GatewayDeployInput, status reporter.Interface,
) error {
subnets, err := d.aws.getSubnetsSupportingInstanceType(publicSubnets, d.instanceType)
if err != nil {
return status.Error(err, "unable to create security group")
return status.Error(err, "unable to get subnets supporting instance type")
}

taggedSubnets, _ := filterSubnets(subnets, func(subnet *types.Subnet) (bool, error) {
Expand Down Expand Up @@ -313,9 +338,11 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error {

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.cloudConfig[VPCIDKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)
Expand All @@ -327,13 +354,30 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error {

status.Success(messageValidatedPrerequisites)

subnets, err := d.aws.getTaggedPublicSubnets(vpcID)
if err != nil {
return err
var publicSubnets []types.Subnet

if subnets, exists := d.aws.cloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := d.aws.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}

publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err = d.aws.getTaggedPublicSubnets(vpcID)
if err != nil {
return err
}
}

for i := range subnets {
subnet := &subnets[i]
for i := range publicSubnets {
subnet := &publicSubnets[i]
subnetName := extractName(subnet.Tags)

status.Start("Removing gateway node for public subnet %s", subnetName)
Expand Down
1 change: 0 additions & 1 deletion pkg/aws/ocpgwdeployer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver {
t.expectDescribeInstances(instanceImageID)
t.expectDescribeSecurityGroups(workerSGName, workerGroupID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

var err error
Expand Down
Loading
Loading