Skip to content

Commit

Permalink
AWS firewall configuration via security groups (#239)
Browse files Browse the repository at this point in the history
Skyplane now supports concurrent transfers in a secure manner. Each instance's IP address is explicitly added to the AWS security group (SG) at the start of a transfer and removed from the SG at the end of that transfer.
  • Loading branch information
ShishirPatil authored May 4, 2022
1 parent 90a2fdf commit 31f44bb
Show file tree
Hide file tree
Showing 11 changed files with 253 additions and 120 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"halo",
"pandas",
"questionary",
"sshtunnel",
"typer",
# shared dependencies
"cachetools",
Expand Down
7 changes: 4 additions & 3 deletions skylark/benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] =
results = do_parallel(
lambda region: provider.get_matching_instances(region=region, **instance_filter),
region_list,
progress_bar=True,
spinner=True,
n=n,
desc="Refreshing instance list",
desc="Querying clouds for active instances",
)
return {r: ilist for r, ilist in results if ilist}

Expand Down Expand Up @@ -67,7 +67,7 @@ def provision(
jobs.append(partial(aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess"))
if aws_regions_to_provision:
for r in set(aws_regions_to_provision):
jobs.append(partial(aws.add_ip_to_security_group, r))
jobs.append(partial(aws.make_vpc, r))
jobs.append(partial(aws.ensure_keyfile_exists, r))
if azure_regions_to_provision:
jobs.append(azure.create_ssh_key)
Expand All @@ -87,6 +87,7 @@ def provision(
"state": [ServerState.PENDING, ServerState.RUNNING],
}
do_parallel(aws.add_ip_to_security_group, aws_regions_to_provision, progress_bar=True, desc="add IP to aws security groups")
do_parallel(aws.authorize_client, [(r, "0.0.0.0/0") for r in aws_regions_to_provision], progress_bar=True, desc="authorize client")
aws_instances = refresh_instance_list(aws, aws_regions_to_provision, aws_instance_filter)
missing_aws_regions = set(aws_regions_to_provision) - set(aws_instances.keys())
if missing_aws_regions:
Expand Down
37 changes: 19 additions & 18 deletions skylark/cli/cli_helper.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,36 @@
import concurrent.futures
from functools import partial
import json
import logging
import os
import re
import logging
import resource
import signal
import subprocess
from functools import partial
from pathlib import Path
from shutil import copyfile
from threading import Thread
from typing import Dict, List
from sys import platform
from typing import Dict, List

from threading import Thread
from typing import Dict, List, Optional

import boto3
import typer
from skylark import GB, MB
from skylark import exceptions
from skylark import gcp_config_path
from skylark import GB, MB, exceptions, gcp_config_path
from skylark.compute.aws.aws_auth import AWSAuthentication
from skylark.compute.azure.azure_auth import AzureAuthentication
from skylark.compute.gcp.gcp_auth import GCPAuthentication
from skylark.config import SkylarkConfig
from skylark.utils import logger
from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider
from skylark.compute.azure.azure_auth import AzureAuthentication
from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider
from skylark.compute.gcp.gcp_auth import GCPAuthentication
from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skylark.config import SkylarkConfig
from skylark.obj_store.azure_interface import AzureInterface
from skylark.obj_store.gcs_interface import GCSInterface
from skylark.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skylark.obj_store.s3_interface import S3Interface
from skylark.obj_store.gcs_interface import GCSInterface
from skylark.obj_store.azure_interface import AzureInterface
from skylark.obj_store.object_store_interface import ObjectStoreInterface
from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology
from skylark.replicate.replicator_client import ReplicatorClient
from skylark.utils import logger
from skylark.utils.utils import do_parallel
from typing import Optional
from tqdm import tqdm


Expand Down Expand Up @@ -297,7 +291,7 @@ def replicate_helper(
try:
rc.provision_gateways(reuse_gateways)
for node, gw in rc.bound_nodes.items():
typer.secho(f" Realtime logs for {node.region}:{node.instance} at {gw.gateway_log_viewer_url}")
logger.fs.info(f"Realtime logs for {node.region}:{node.instance} at {gw.gateway_log_viewer_url}")
job = rc.run_replication_plan(job)
if random:
total_bytes = n_chunks * random_chunk_size_mb * MB
Expand Down Expand Up @@ -406,6 +400,13 @@ def run():
else:
typer.secho("No instances to deprovision, exiting...", fg="yellow", bold=True)

# remove skylark vpc
if AWSAuthentication().enabled():
aws = AWSCloudProvider()
vpcs = do_parallel(partial(aws.get_vpcs), aws.region_list(), desc="Querying VPCs", spinner=True)
args = [(x[0], vpc.id) for x in vpcs for vpc in x[1]]
do_parallel(lambda args: aws.delete_vpc(*args), args, desc="Deleting VPCs", spinner=True, spinner_persist=True)


def load_aws_config(config: SkylarkConfig) -> SkylarkConfig:
# get AWS credentials from boto3
Expand Down
175 changes: 111 additions & 64 deletions skylark/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,75 +84,85 @@ def get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark")
assert len(sgs) == 1
return sgs[0]

def get_vpc(self, region: str, vpc_name="skylark"):
def get_vpcs(self, region: str, vpc_name="skylark"):
ec2 = self.auth.get_boto3_resource("ec2", region)
vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all())
if len(vpcs) == 0:
return None
return []
else:
return vpcs[0]
return vpcs

def make_vpc(self, region: str, vpc_name="skylark"):
ec2 = self.auth.get_boto3_resource("ec2", region)
ec2client = ec2.meta.client
vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all())

# find matching valid VPC
matching_vpc = None
for vpc in vpcs:
subsets_azs = [subnet.availability_zone for subnet in vpc.subnets.all()]
if (
vpc.cidr_block == "10.0.0.0/16"
and vpc.describe_attribute(Attribute="enableDnsSupport")["EnableDnsSupport"]
and vpc.describe_attribute(Attribute="enableDnsHostnames")["EnableDnsHostnames"]
and all(az in subsets_azs for az in self.auth.get_azs_in_region(region))
):
matching_vpc = vpc
# delete all other vpcs
for vpc in vpcs:
if vpc != matching_vpc:
try:
self.delete_vpc(region, vpc.id)
except botocore.exceptions.ClientError as e:
logger.warning(f"Failed to delete VPC {vpc.id} in {region}: {e}")
break

# make vpc if none found
if matching_vpc is None:
# delete old skylark vpcs
with ILock(f"aws_make_vpc_{region}"):
# find matching valid VPC
matching_vpc = None
for vpc in vpcs:
self.delete_vpc(region, vpc.id)

# enable dns support, enable dns hostnames
matching_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="default")
matching_vpc.modify_attribute(EnableDnsSupport={"Value": True})
matching_vpc.modify_attribute(EnableDnsHostnames={"Value": True})
matching_vpc.create_tags(Tags=[{"Key": "Name", "Value": vpc_name}])
matching_vpc.wait_until_available()

# make subnet for each AZ in region
def make_subnet(az):
subnet_cidr_id = ord(az[-1]) - ord("a")
subnet = self.auth.get_boto3_resource("ec2", region).create_subnet(
CidrBlock=f"10.0.{subnet_cidr_id}.0/24", VpcId=matching_vpc.id, AvailabilityZone=az
subsets_azs = [subnet.availability_zone for subnet in vpc.subnets.all()]
if (
vpc.cidr_block == "10.0.0.0/16"
and vpc.describe_attribute(Attribute="enableDnsSupport")["EnableDnsSupport"]
and vpc.describe_attribute(Attribute="enableDnsHostnames")["EnableDnsHostnames"]
and all(az in subsets_azs for az in self.auth.get_azs_in_region(region))
):
matching_vpc = vpc
# delete all other vpcs
for vpc in vpcs:
if vpc != matching_vpc:
try:
self.delete_vpc(region, vpc.id)
except botocore.exceptions.ClientError as e:
logger.warning(f"Failed to delete VPC {vpc.id} in {region}: {e}")
break

# make vpc if none found
if matching_vpc is None:
# delete old skylark vpcs
for vpc in vpcs:
self.delete_vpc(region, vpc.id)

# enable dns support, enable dns hostnames
matching_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="default")
matching_vpc.modify_attribute(EnableDnsSupport={"Value": True})
matching_vpc.modify_attribute(EnableDnsHostnames={"Value": True})
matching_vpc.create_tags(Tags=[{"Key": "Name", "Value": vpc_name}])
matching_vpc.wait_until_available()

# make subnet for each AZ in region
def make_subnet(az):
subnet_cidr_id = ord(az[-1]) - ord("a")
subnet = self.auth.get_boto3_resource("ec2", region).create_subnet(
CidrBlock=f"10.0.{subnet_cidr_id}.0/24", VpcId=matching_vpc.id, AvailabilityZone=az
)
subnet.meta.client.modify_subnet_attribute(SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True})
return subnet.id

subnet_ids = do_parallel(make_subnet, self.auth.get_azs_in_region(region), return_args=False)

# make internet gateway
igw = ec2.create_internet_gateway()
igw.attach_to_vpc(VpcId=matching_vpc.id)
public_route_table = list(matching_vpc.route_tables.all())[0]

# add a default route, for Public Subnet, pointing to Internet Gateway
ec2client.create_route(RouteTableId=public_route_table.id, DestinationCidrBlock="0.0.0.0/0", GatewayId=igw.id)
for subnet_id in subnet_ids:
public_route_table.associate_with_subnet(SubnetId=subnet_id)

# make security group named "skylark"
tagSpecifications = [
{"Tags": [{"Key": "skylark", "Value": "true"}], "ResourceType": "security-group"},
]
sg = ec2.create_security_group(
GroupName="skylark",
Description="Default security group for Skylark VPC",
VpcId=matching_vpc.id,
TagSpecifications=tagSpecifications,
)
subnet.meta.client.modify_subnet_attribute(SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True})
return subnet.id

subnet_ids = do_parallel(make_subnet, self.auth.get_azs_in_region(region), return_args=False)

# make internet gateway
igw = ec2.create_internet_gateway()
igw.attach_to_vpc(VpcId=matching_vpc.id)
public_route_table = list(matching_vpc.route_tables.all())[0]
# add a default route, for Public Subnet, pointing to Internet Gateway
ec2client.create_route(RouteTableId=public_route_table.id, DestinationCidrBlock="0.0.0.0/0", GatewayId=igw.id)
for subnet_id in subnet_ids:
public_route_table.associate_with_subnet(SubnetId=subnet_id)

# make security group named "default"
sg = ec2.create_security_group(GroupName="skylark", Description="Default security group for Skylark VPC", VpcId=matching_vpc.id)
return matching_vpc
return matching_vpc

def delete_vpc(self, region: str, vpcid: str):
"""Delete VPC, from https://gist.github.com/vernhart/c6a0fc94c0aeaebe84e5cd6f3dede4ce"""
Expand Down Expand Up @@ -212,18 +222,54 @@ def create_iam(self, iam_name: str = "skylark_gateway", attach_policy_arn: Optio
if attach_policy_arn:
iam.attach_role_policy(RoleName=iam_name, PolicyArn=attach_policy_arn)

def add_ip_to_security_group(self, aws_region: str):
"""Add IP to security group. If security group ID is None, use group named skylark (create if not exists)."""
def authorize_client(self, aws_region: str, client_ip: str, port=22):
    """Allow inbound TCP traffic from ``client_ip`` on ``port`` in the region's "skylark" security group.

    The first VPC returned by ``get_vpcs`` is used. If the security group already
    has a TCP or all-protocol rule that covers ``port`` for exactly ``client_ip``,
    nothing is changed.

    :param aws_region: AWS region whose security group is updated.
    :param client_ip: CIDR string compared against, and written to, ``CidrIp`` (callers pass e.g. "0.0.0.0/0").
    :param port: TCP port to open (default 22).
    :raises AssertionError: if no VPC is found, or the VPC does not contain exactly one group named "skylark".
    """
    vpcs = self.get_vpcs(aws_region)
    assert vpcs, f"No VPC found in {aws_region}"
    vpc = vpcs[0]
    sgs = [sg for sg in vpc.security_groups.all() if sg.group_name == "skylark"]
    assert len(sgs) == 1, f"Found {len(sgs)} sgs named skylark, expected 1"
    sg = sgs[0]

    # check if we already have a rule for this security group
    for rule in sg.ip_permissions:
        protocol = rule.get("IpProtocol")
        if protocol == "-1":
            # An all-protocol rule admits TCP on every port.
            covers_port = True
        elif protocol == "tcp":
            covers_port = "FromPort" in rule and "ToPort" in rule and rule["FromPort"] <= port <= rule["ToPort"]
        else:
            # A rule for another protocol (udp, icmp, ...) must not suppress the TCP authorization,
            # even if its port numbers happen to span ``port``.
            covers_port = False
        if covers_port and any(ipr["CidrIp"] == client_ip for ipr in rule.get("IpRanges", [])):
            logger.fs.debug(f"[AWS] Found existing rule for {client_ip}:{port} in {sg.group_name}, not adding again")
            return
    logger.fs.debug(f"[AWS] Authorizing {client_ip}:{port} in {sg.group_name}")
    sg.authorize_ingress(IpPermissions=[{"IpProtocol": "tcp", "FromPort": port, "ToPort": port, "IpRanges": [{"CidrIp": client_ip}]}])

def add_ip_to_security_group(self, aws_region: str, ip: Optional[str] = None):
"""Add IP to security group. If security group ID is None, use group named skylark (create if not exists). If ip is None, authorize all IPs."""
with ILock(f"aws_add_ip_to_security_group_{aws_region}"):
self.make_vpc(aws_region)
sg = self.get_security_group(aws_region)
try:
logger.fs.debug(f"[AWS] Adding IP {ip} to security group {sg.group_name}")
sg.authorize_ingress(
IpPermissions=[{"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": "0.0.0.0/0"}]}]
IpPermissions=[
{"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": f"{ip}/32" if ip else "0.0.0.0/0"}]}
]
)
except botocore.exceptions.ClientError as e:
logger.fs.error(f"[AWS] Error adding IP {ip} to security group {sg.group_name}: {e}")
if not str(e).endswith("already exists"):
raise e
logger.warn("[AWS] Error adding IP to security group, since it already exits")

def remove_ip_from_security_group(self, aws_region: str, ip: str):
    """Revoke the ingress rule for ``ip`` from the region's skylark security group.

    This is best effort: a ``ClientError`` raised by the revoke call is logged and
    not re-raised, so a failed cleanup does not abort the caller.

    :param aws_region: AWS region whose security group is updated.
    :param ip: bare IPv4 address (no mask); ``/32`` is appended here.
    """
    with ILock(f"aws_remove_ip_to_security_group_{aws_region}"):
        # Remove instance IP from security group
        sg = self.get_security_group(aws_region)
        try:
            logger.fs.debug(f"[AWS] Removing IP {ip} from security group {sg.group_name}")
            # The revoked permission has to mirror the one add_ip_to_security_group
            # authorizes (all protocols, <ip>/32); a different protocol/port range
            # does not match that rule and would leave it in place.
            sg.revoke_ingress(
                IpPermissions=[{"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": f"{ip}/32"}]}]
            )
        except botocore.exceptions.ClientError as e:
            # The error code lives in the structured response; the string form of
            # the exception ends with the human-readable message, not the code.
            error_code = e.response.get("Error", {}).get("Code", "")
            if error_code.endswith("NotFound"):
                # Rule is already absent, so there is nothing to clean up.
                logger.fs.debug(f"[AWS] IP {ip} not present in security group {sg.group_name}: {e}")
            else:
                logger.fs.error(f"[AWS] Error removing IP {ip} from security group {sg.group_name}: {e}")
                logger.warn("[AWS] Error removing IP from security group")

def ensure_keyfile_exists(self, aws_region, prefix=key_root / "aws"):
ec2 = self.auth.get_boto3_resource("ec2", aws_region)
Expand Down Expand Up @@ -268,8 +314,9 @@ def provision_instance(
iam = self.auth.get_boto3_client("iam", region)
ec2 = self.auth.get_boto3_resource("ec2", region)
ec2_client = self.auth.get_boto3_client("ec2", region)
vpc = self.get_vpc(region)
assert vpc is not None, "No VPC found"
vpcs = self.get_vpcs(region)
assert vpcs, "No VPC found"
vpc = vpcs[0]

# get subnet list
def instance_class_supported(az):
Expand Down
20 changes: 15 additions & 5 deletions skylark/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import boto3
import paramiko
from skylark.compute.aws.aws_auth import AWSAuthentication
import sshtunnel
from skylark import key_root

from skylark.compute.aws.aws_auth import AWSAuthentication
from skylark.compute.server import Server, ServerState
from skylark.utils.cache import ignore_lru_cache
from skylark.utils import logger


class AWSServer(Server):
Expand Down Expand Up @@ -75,10 +76,19 @@ def get_ssh_client_impl(self):
)
return client

def get_ssh_cmd(self):
return f"ssh -i {self.local_keyfile} ec2-user@{self.public_ip()}"

def get_sftp_client(self):
    """Open an SFTP session to this instance as ``ec2-user`` on port 22.

    Authenticates with the RSA private key at ``self.local_keyfile``.

    :return: a ``paramiko.SFTPClient`` bound to a newly connected transport.
    :raises Exception: whatever key loading, ``connect`` or ``from_transport``
        raises; the transport is closed before the exception propagates.
    """
    t = paramiko.Transport((self.public_ip(), 22))
    try:
        t.connect(username="ec2-user", pkey=paramiko.RSAKey.from_private_key_file(str(self.local_keyfile)))
        return paramiko.SFTPClient.from_transport(t)
    except Exception:
        # Do not leak the already-created transport when key loading or authentication fails.
        t.close()
        raise

def open_ssh_tunnel_impl(self, remote_port) -> sshtunnel.SSHTunnelForwarder:
    """Build an SSH tunnel forwarder to ``remote_port`` on this instance's loopback interface.

    The forwarder is only constructed here; this method does not start it.
    The local end is bound to 127.0.0.1 with port 0 (presumably so a free
    local port is chosen automatically -- confirm against sshtunnel docs).

    :param remote_port: port on the instance's 127.0.0.1 that the tunnel targets.
    :return: the configured ``sshtunnel.SSHTunnelForwarder``.
    """
    loopback = "127.0.0.1"
    ssh_endpoint = (self.public_ip(), 22)
    private_key_path = str(self.local_keyfile)
    return sshtunnel.SSHTunnelForwarder(
        ssh_endpoint,
        ssh_username="ec2-user",
        ssh_pkey=private_key_path,
        local_bind_address=(loopback, 0),
        remote_bind_address=(loopback, remote_port),
    )

def get_ssh_cmd(self):
    """Return the shell command that opens an SSH session to this instance as ec2-user."""
    keyfile = self.local_keyfile
    host = self.public_ip()
    return f"ssh -i {keyfile} ec2-user@{host}"
Loading

0 comments on commit 31f44bb

Please sign in to comment.