Skip to content

Commit

Permalink
Add custom vpc support in AWS cloud prepare
Browse files Browse the repository at this point in the history
Signed-off-by: Aswin Suryanarayanan <[email protected]>
  • Loading branch information
aswinsuryan committed Sep 24, 2024
1 parent fa86cf1 commit 1da2efd
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 68 deletions.
117 changes: 93 additions & 24 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package aws
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"regexp"
"strings"

Expand All @@ -40,36 +41,84 @@ const (
messageValidatedPrerequisites = "Validated pre-requisites"
)

type CloudOption func(*awsCloud)

const (
ControlPlaneSecurityGroupKey = "controlPlaneSecurityGroup"
WorkerSecurityGroupKey = "workerSecurityGroup"
PublicSubnetListKey = "PublicSubnetName"
VPCNameKey = "VPCName"
)

func WithControlPlaneSecurityGroup(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.CloudConfig[ControlPlaneSecurityGroupKey] = name
}
}

func WithWorkerSecurityGroup(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.CloudConfig[WorkerSecurityGroupKey] = name
}
}

func WithPublicSubnetList(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.CloudConfig[PublicSubnetListKey] = name
}
}

func WithVPCName(name string) CloudOption {
return func(cloud *awsCloud) {
cloud.CloudConfig[VPCNameKey] = name
}
}

type awsCloud struct {
client awsClient.Interface
infraID string
region string
nodeSGSuffix string
controlPlaneSGSuffix string
CloudConfig map[string]interface{}
}

// NewCloud creates a new api.Cloud instance which can prepare AWS for Submariner to be deployed on it.
func NewCloud(client awsClient.Interface, infraID, region string) api.Cloud {
return &awsCloud{
client: client,
infraID: infraID,
region: region,
func NewCloud(client awsClient.Interface, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: client,
infraID: infraID,
region: region,
CloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromConfig creates a new api.Cloud instance based on an AWS configuration
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromConfig(cfg *aws.Config, infraID, region string) api.Cloud {
return &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
func NewCloudFromConfig(cfg *aws.Config, infraID, region string, opts ...CloudOption) api.Cloud {
cloud := &awsCloud{
client: ec2.NewFromConfig(*cfg),
infraID: infraID,
region: region,
CloudConfig: make(map[string]interface{}),
}

for _, opt := range opts {
opt(cloud)
}

return cloud
}

// NewCloudFromSettings creates a new api.Cloud instance using the given credentials file and profile
// which can prepare AWS for Submariner to be deployed on it.
func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api.Cloud, error) {
func NewCloudFromSettings(credentialsFile, profile, infraID, region string, opts ...CloudOption) (api.Cloud, error) {
options := []func(*config.LoadOptions) error{config.WithRegion(region), config.WithSharedConfigProfile(profile)}
if credentialsFile != DefaultCredentialsFile() {
options = append(options, config.WithSharedCredentialsFiles([]string{credentialsFile}))
Expand All @@ -80,7 +129,7 @@ func NewCloudFromSettings(credentialsFile, profile, infraID, region string) (api
return nil, errors.Wrap(err, "error loading default config")
}

return NewCloudFromConfig(&cfg, infraID, region), nil
return NewCloudFromConfig(&cfg, infraID, region, opts...), nil
}

// DefaultCredentialsFile returns the default credentials file name.
Expand All @@ -98,13 +147,29 @@ func (ac *awsCloud) setSuffixes(vpcID string) error {
return nil
}

publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}
var publicSubnets []types.Subnet

if subnets, exists := ac.CloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := ac.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}
publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}

if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
}
}

pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region))
Expand Down Expand Up @@ -137,9 +202,11 @@ func (ac *awsCloud) OpenPorts(ports []api.PortSpec, status reporter.Interface) e
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.CloudConfig[VPCNameKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down Expand Up @@ -180,9 +247,11 @@ func (ac *awsCloud) ClosePorts(status reporter.Interface) error {
return status.Error(err, "unable to retrieve the VPC ID")
}

err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := ac.CloudConfig[VPCNameKey]; !found {
err = ac.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Success(messageRetrievedVPCID, vpcID)
Expand Down
38 changes: 29 additions & 9 deletions pkg/aws/ocpgwdeployer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,34 @@ func (d *ocpGatewayDeployer) Deploy(input api.GatewayDeployInput, status reporte

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.CloudConfig[VPCNameKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)

publicSubnets, err := d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
var publicSubnets []types.Subnet

if subnets, exists := d.aws.CloudConfig[PublicSubnetListKey]; exists {
if subnetIDs, ok := subnets.([]string); ok && len(subnetIDs) > 0 {
for _, id := range subnetIDs {
subnet, err := d.aws.getSubnetByID(id)
if err != nil {
return errors.Wrapf(err, "unable to find subnet with ID %s", id)
}
publicSubnets = append(publicSubnets, *subnet)
}
} else {
return errors.New("Subnet IDs must be a valid non-empty slice of strings")
}
} else {
publicSubnets, err = d.aws.findPublicSubnets(vpcID, d.aws.filterByName("{infraID}*-public-{region}*"))
if err != nil {
return status.Error(err, "unable to find public subnets")
}
}

err = d.validateDeployPrerequisites(vpcID, input, publicSubnets)
Expand Down Expand Up @@ -313,9 +331,11 @@ func (d *ocpGatewayDeployer) Cleanup(status reporter.Interface) error {

status.Success(messageRetrievedVPCID, vpcID)

err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
if _, found := d.aws.CloudConfig[VPCNameKey]; !found {
err = d.aws.setSuffixes(vpcID)
if err != nil {
return status.Error(err, "unable to retrieve the security group names")
}
}

status.Start(messageValidatePrerequisites)
Expand Down
66 changes: 55 additions & 11 deletions pkg/aws/securitygroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,48 @@ func (ac *awsCloud) createClusterSGRule(srcGroup, destGroup *string, port uint16
}

func (ac *awsCloud) allowPortInCluster(vpcID string, port uint16, protocol string) error {
workerGroupID, err := ac.getSecurityGroupID(vpcID, withInfraIDPrefix(ac.nodeSGSuffix))
if err != nil {
return err
var workerGroupID, controlPlaneGroupID *string
var err error

if id, exists := ac.CloudConfig[WorkerSecurityGroupKey]; exists {
if workerGroupIDStr, ok := id.(string); ok && workerGroupIDStr != "" {
workerGroupID = &workerGroupIDStr
} else {
return errors.New("Worker Security Group ID must be a valid non-empty string")
}
} else {
workerGroupName := withInfraIDPrefix(ac.nodeSGSuffix)
workerGroupID, err = ac.getSecurityGroupID(vpcID, workerGroupName)
if err != nil {
return err
}
}

masterGroupID, err := ac.getSecurityGroupID(vpcID, withInfraIDPrefix(ac.controlPlaneSGSuffix))
if err != nil {
return err
if id, exists := ac.CloudConfig[ControlPlaneSecurityGroupKey]; exists {
if controlPlaneGroupIDStr, ok := id.(string); ok && controlPlaneGroupIDStr != "" {
controlPlaneGroupID = &controlPlaneGroupIDStr
} else {
return errors.New("Control Plane Security Group ID must be a valid non-empty string")
}
} else {
controlPlaneGroupName := withInfraIDPrefix(ac.controlPlaneSGSuffix)
controlPlaneGroupID, err = ac.getSecurityGroupID(vpcID, controlPlaneGroupName)
if err != nil {
return err
}
}

err = ac.createClusterSGRule(workerGroupID, workerGroupID, port, protocol, fmt.Sprintf("%s between the workers", internalTraffic))
if err != nil {
return err
}

err = ac.createClusterSGRule(workerGroupID, masterGroupID, port, protocol, fmt.Sprintf("%s from worker to master nodes", internalTraffic))
err = ac.createClusterSGRule(workerGroupID, controlPlaneGroupID, port, protocol, fmt.Sprintf("%s from worker to control plane nodes", internalTraffic))
if err != nil {
return err
}

return ac.createClusterSGRule(masterGroupID, workerGroupID, port, protocol, fmt.Sprintf("%s from master to worker nodes", internalTraffic))
return ac.createClusterSGRule(controlPlaneGroupID, workerGroupID, port, protocol, fmt.Sprintf("%s from control plane to worker nodes", internalTraffic))
}

func (ac *awsCloud) createPublicSGRule(groupID *string, port uint16, protocol, description string) error {
Expand Down Expand Up @@ -219,12 +240,35 @@ func (ac *awsCloud) deleteGatewaySG(vpcID string) error {
}

func (ac *awsCloud) revokePortsInCluster(vpcID string) error {
workerGroup, err := ac.getSecurityGroup(vpcID, withInfraIDPrefix(ac.nodeSGSuffix))
var workerGroupName, controlPlaneGroupName string
var err error

if name, exists := ac.CloudConfig[WorkerSecurityGroupKey]; exists {
if workerGroupNameStr, ok := name.(string); ok && workerGroupNameStr != "" {
workerGroupName = workerGroupNameStr
} else {
return errors.New("Worker Security Group name needs to be a valid non-empty string")
}
} else {
workerGroupName = withInfraIDPrefix(ac.nodeSGSuffix)
}

workerGroup, err := ac.getSecurityGroup(vpcID, workerGroupName)
if err != nil {
return err
}

masterGroup, err := ac.getSecurityGroup(vpcID, withInfraIDPrefix(ac.controlPlaneSGSuffix))
if name, exists := ac.CloudConfig[ControlPlaneSecurityGroupKey]; exists {
if controlPlaneGroupNameStr, ok := name.(string); ok && controlPlaneGroupNameStr != "" {
controlPlaneGroupName = controlPlaneGroupNameStr
} else {
return errors.New("Control Plane Security Group name needs to be a valid non-empty string")
}
} else {
controlPlaneGroupName = withInfraIDPrefix(ac.controlPlaneSGSuffix)
}

controlPlaneGroup, err := ac.getSecurityGroup(vpcID, controlPlaneGroupName)
if err != nil {
return err
}
Expand All @@ -234,7 +278,7 @@ func (ac *awsCloud) revokePortsInCluster(vpcID string) error {
return err
}

return ac.revokePortsFromGroup(&masterGroup)
return ac.revokePortsFromGroup(&controlPlaneGroup)
}

func (ac *awsCloud) revokePortsFromGroup(group *types.SecurityGroup) error {
Expand Down
15 changes: 15 additions & 0 deletions pkg/aws/subnets.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,18 @@ func (ac *awsCloud) untagPublicSubnet(subnetID *string) error {

return errors.Wrap(err, "error deleting AWS tag")
}

func (ac *awsCloud) getSubnetByID(subnetID string) (*types.Subnet, error) {
output, err := ac.client.DescribeSubnets(context.TODO(), &ec2.DescribeSubnetsInput{
SubnetIds: []string{subnetID},
})
if err != nil {
return nil, errors.Wrapf(err, "unable to describe subnet %s", subnetID)
}

if len(output.Subnets) == 0 {
return nil, errors.New("subnet not found")
}

return &output.Subnets[0], nil
}
Loading

0 comments on commit 1da2efd

Please sign in to comment.