diff --git a/setup.py b/setup.py
index 0d0d9390d..d53ac4e42 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
         "halo",
         "pandas",
         "questionary",
+        "sshtunnel",
         "typer",
         # shared dependencies
         "cachetools",
diff --git a/skylark/benchmark/utils.py b/skylark/benchmark/utils.py
index 871d9a14a..d9facfdd4 100644
--- a/skylark/benchmark/utils.py
+++ b/skylark/benchmark/utils.py
@@ -19,9 +19,9 @@ def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] =
     results = do_parallel(
         lambda region: provider.get_matching_instances(region=region, **instance_filter),
         region_list,
-        progress_bar=True,
+        spinner=True,
         n=n,
-        desc="Refreshing instance list",
+        desc="Querying clouds for active instances",
     )
     return {r: ilist for r, ilist in results if ilist}
@@ -67,7 +67,7 @@ def provision(
         jobs.append(partial(aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess"))
     if aws_regions_to_provision:
         for r in set(aws_regions_to_provision):
-            jobs.append(partial(aws.add_ip_to_security_group, r))
+            jobs.append(partial(aws.make_vpc, r))
             jobs.append(partial(aws.ensure_keyfile_exists, r))
     if azure_regions_to_provision:
         jobs.append(azure.create_ssh_key)
@@ -87,6 +87,7 @@
             "state": [ServerState.PENDING, ServerState.RUNNING],
         }
         do_parallel(aws.add_ip_to_security_group, aws_regions_to_provision, progress_bar=True, desc="add IP to aws security groups")
+        do_parallel(lambda a: aws.authorize_client(*a), [(r, "0.0.0.0/0") for r in aws_regions_to_provision], progress_bar=True, desc="authorize client")
         aws_instances = refresh_instance_list(aws, aws_regions_to_provision, aws_instance_filter)
         missing_aws_regions = set(aws_regions_to_provision) - set(aws_instances.keys())
         if missing_aws_regions:
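Note on the authorize_client dispatch above: do_parallel passes each work item to the callable as a single argument (the deprovision path in cli_helper.py unpacks tuples through a lambda for the same reason), so the (region, CIDR) pairs must be unpacked explicitly. A minimal sketch of the assumed calling convention — do_parallel here is a stand-in, not the real helper from skylark.utils.utils:

    # Sketch of the do_parallel convention this patch relies on (illustrative).
    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, Iterable, List, Tuple, TypeVar

    A = TypeVar("A")
    R = TypeVar("R")

    def do_parallel(fn: Callable[[A], R], args: Iterable[A], n: int = 8) -> List[Tuple[A, R]]:
        """Run fn over args in a thread pool; each element is passed as ONE argument."""
        arg_list = list(args)
        with ThreadPoolExecutor(max_workers=n) as pool:
            return list(zip(arg_list, pool.map(fn, arg_list)))

    # Multi-argument calls therefore unpack inside a lambda:
    # do_parallel(lambda a: aws.authorize_client(*a), [(r, "0.0.0.0/0") for r in regions])
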
diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py
index d2d4963b4..53d0e5ba1 100644
--- a/skylark/cli/cli_helper.py
+++ b/skylark/cli/cli_helper.py
@@ -1,42 +1,36 @@
 import concurrent.futures
-from functools import partial
 import json
+import logging
 import os
 import re
-import logging
 import resource
 import signal
 import subprocess
+from functools import partial
 from pathlib import Path
 from shutil import copyfile
-from threading import Thread
-from typing import Dict, List
 from sys import platform
-from typing import Dict, List
-
+from threading import Thread
+from typing import Dict, List, Optional

 import boto3
 import typer
-from skylark import GB, MB
-from skylark import exceptions
-from skylark import gcp_config_path
+from skylark import GB, MB, exceptions, gcp_config_path
 from skylark.compute.aws.aws_auth import AWSAuthentication
-from skylark.compute.azure.azure_auth import AzureAuthentication
-from skylark.compute.gcp.gcp_auth import GCPAuthentication
-from skylark.config import SkylarkConfig
-from skylark.utils import logger
 from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider
+from skylark.compute.azure.azure_auth import AzureAuthentication
 from skylark.compute.azure.azure_cloud_provider import AzureCloudProvider
+from skylark.compute.gcp.gcp_auth import GCPAuthentication
 from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider
+from skylark.config import SkylarkConfig
+from skylark.obj_store.azure_interface import AzureInterface
+from skylark.obj_store.gcs_interface import GCSInterface
 from skylark.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
 from skylark.obj_store.s3_interface import S3Interface
-from skylark.obj_store.gcs_interface import GCSInterface
-from skylark.obj_store.azure_interface import AzureInterface
-from skylark.obj_store.object_store_interface import ObjectStoreInterface
 from skylark.replicate.replication_plan import ReplicationJob, ReplicationTopology
 from skylark.replicate.replicator_client import ReplicatorClient
+from skylark.utils import logger
 from skylark.utils.utils import do_parallel
-from typing import Optional
 from tqdm import tqdm
@@ -297,7 +291,7 @@ def replicate_helper(
     try:
         rc.provision_gateways(reuse_gateways)
         for node, gw in rc.bound_nodes.items():
-            typer.secho(f"  Realtime logs for {node.region}:{node.instance} at {gw.gateway_log_viewer_url}")
+            logger.fs.info(f"Realtime logs for {node.region}:{node.instance} at {gw.gateway_log_viewer_url}")
         job = rc.run_replication_plan(job)
         if random:
             total_bytes = n_chunks * random_chunk_size_mb * MB
@@ -406,6 +400,13 @@ def run():
     else:
         typer.secho("No instances to deprovision, exiting...", fg="yellow", bold=True)

+    # remove the skylark VPCs
+    if AWSAuthentication().enabled():
+        aws = AWSCloudProvider()
+        vpcs = do_parallel(aws.get_vpcs, aws.region_list(), desc="Querying VPCs", spinner=True)
+        vpc_args = [(region, vpc.id) for region, region_vpcs in vpcs for vpc in region_vpcs]
+        do_parallel(lambda a: aws.delete_vpc(*a), vpc_args, desc="Deleting VPCs", spinner=True, spinner_persist=True)
+

 def load_aws_config(config: SkylarkConfig) -> SkylarkConfig:
     # get AWS credentials from boto3
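For reference, the shape of the VPC-cleanup data flow added above: do_parallel returns (argument, result) pairs, so the region-to-VPC listing comes back nested and is flattened before deletion. A self-contained sketch with placeholder IDs:

    # do_parallel yields (region, [vpc, ...]) pairs; flatten to (region, vpc_id) args.
    vpcs = [("us-east-1", ["vpc-0a"]), ("us-west-2", ["vpc-1b", "vpc-1c"])]  # placeholders
    vpc_args = [(region, vpc_id) for region, region_vpcs in vpcs for vpc_id in region_vpcs]
    assert vpc_args == [("us-east-1", "vpc-0a"), ("us-west-2", "vpc-1b"), ("us-west-2", "vpc-1c")]
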
diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py
index b26d6c426..9fa347a27 100644
--- a/skylark/compute/aws/aws_cloud_provider.py
+++ b/skylark/compute/aws/aws_cloud_provider.py
@@ -84,75 +84,85 @@ def get_security_group(self, region: str, vpc_name="skylark", sg_name="skylark")
         assert len(sgs) == 1
         return sgs[0]

-    def get_vpc(self, region: str, vpc_name="skylark"):
+    def get_vpcs(self, region: str, vpc_name="skylark"):
         ec2 = self.auth.get_boto3_resource("ec2", region)
         vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all())
         if len(vpcs) == 0:
-            return None
+            return []
         else:
-            return vpcs[0]
+            return vpcs

     def make_vpc(self, region: str, vpc_name="skylark"):
         ec2 = self.auth.get_boto3_resource("ec2", region)
         ec2client = ec2.meta.client
         vpcs = list(ec2.vpcs.filter(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]).all())

-        # find matching valid VPC
-        matching_vpc = None
-        for vpc in vpcs:
-            subsets_azs = [subnet.availability_zone for subnet in vpc.subnets.all()]
-            if (
-                vpc.cidr_block == "10.0.0.0/16"
-                and vpc.describe_attribute(Attribute="enableDnsSupport")["EnableDnsSupport"]
-                and vpc.describe_attribute(Attribute="enableDnsHostnames")["EnableDnsHostnames"]
-                and all(az in subsets_azs for az in self.auth.get_azs_in_region(region))
-            ):
-                matching_vpc = vpc
-                # delete all other vpcs
-                for vpc in vpcs:
-                    if vpc != matching_vpc:
-                        try:
-                            self.delete_vpc(region, vpc.id)
-                        except botocore.exceptions.ClientError as e:
-                            logger.warning(f"Failed to delete VPC {vpc.id} in {region}: {e}")
-                break
-
-        # make vpc if none found
-        if matching_vpc is None:
-            # delete old skylark vpcs
-            for vpc in vpcs:
-                self.delete_vpc(region, vpc.id)
-
-            # enable dns support, enable dns hostnames
-            matching_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="default")
-            matching_vpc.modify_attribute(EnableDnsSupport={"Value": True})
-            matching_vpc.modify_attribute(EnableDnsHostnames={"Value": True})
-            matching_vpc.create_tags(Tags=[{"Key": "Name", "Value": vpc_name}])
-            matching_vpc.wait_until_available()
-
-            # make subnet for each AZ in region
-            def make_subnet(az):
-                subnet_cidr_id = ord(az[-1]) - ord("a")
-                subnet = self.auth.get_boto3_resource("ec2", region).create_subnet(
-                    CidrBlock=f"10.0.{subnet_cidr_id}.0/24", VpcId=matching_vpc.id, AvailabilityZone=az
-                )
-                subnet.meta.client.modify_subnet_attribute(SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True})
-                return subnet.id
-
-            subnet_ids = do_parallel(make_subnet, self.auth.get_azs_in_region(region), return_args=False)
-
-            # make internet gateway
-            igw = ec2.create_internet_gateway()
-            igw.attach_to_vpc(VpcId=matching_vpc.id)
-            public_route_table = list(matching_vpc.route_tables.all())[0]
-            # add a default route, for Public Subnet, pointing to Internet Gateway
-            ec2client.create_route(RouteTableId=public_route_table.id, DestinationCidrBlock="0.0.0.0/0", GatewayId=igw.id)
-            for subnet_id in subnet_ids:
-                public_route_table.associate_with_subnet(SubnetId=subnet_id)
-
-            # make security group named "default"
-            sg = ec2.create_security_group(GroupName="skylark", Description="Default security group for Skylark VPC", VpcId=matching_vpc.id)
-        return matching_vpc
+        with ILock(f"aws_make_vpc_{region}"):
+            # find matching valid VPC
+            matching_vpc = None
+            for vpc in vpcs:
+                subnet_azs = [subnet.availability_zone for subnet in vpc.subnets.all()]
+                if (
+                    vpc.cidr_block == "10.0.0.0/16"
+                    and vpc.describe_attribute(Attribute="enableDnsSupport")["EnableDnsSupport"]
+                    and vpc.describe_attribute(Attribute="enableDnsHostnames")["EnableDnsHostnames"]
+                    and all(az in subnet_azs for az in self.auth.get_azs_in_region(region))
+                ):
+                    matching_vpc = vpc
+                    # delete all other vpcs
+                    for vpc in vpcs:
+                        if vpc != matching_vpc:
+                            try:
+                                self.delete_vpc(region, vpc.id)
+                            except botocore.exceptions.ClientError as e:
+                                logger.warning(f"Failed to delete VPC {vpc.id} in {region}: {e}")
+                    break
+
+            # make vpc if none found
+            if matching_vpc is None:
+                # delete old skylark vpcs
+                for vpc in vpcs:
+                    self.delete_vpc(region, vpc.id)
+
+                # enable dns support, enable dns hostnames
+                matching_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="default")
+                matching_vpc.modify_attribute(EnableDnsSupport={"Value": True})
+                matching_vpc.modify_attribute(EnableDnsHostnames={"Value": True})
+                matching_vpc.create_tags(Tags=[{"Key": "Name", "Value": vpc_name}])
+                matching_vpc.wait_until_available()
+
+                # make subnet for each AZ in region
+                def make_subnet(az):
+                    subnet_cidr_id = ord(az[-1]) - ord("a")
+                    subnet = self.auth.get_boto3_resource("ec2", region).create_subnet(
+                        CidrBlock=f"10.0.{subnet_cidr_id}.0/24", VpcId=matching_vpc.id, AvailabilityZone=az
+                    )
+                    subnet.meta.client.modify_subnet_attribute(SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True})
+                    return subnet.id
+
+                subnet_ids = do_parallel(make_subnet, self.auth.get_azs_in_region(region), return_args=False)
+
+                # make internet gateway
+                igw = ec2.create_internet_gateway()
+                igw.attach_to_vpc(VpcId=matching_vpc.id)
+                public_route_table = list(matching_vpc.route_tables.all())[0]
+
+                # add a default route, for Public Subnet, pointing to Internet Gateway
+                ec2client.create_route(RouteTableId=public_route_table.id, DestinationCidrBlock="0.0.0.0/0", GatewayId=igw.id)
+                for subnet_id in subnet_ids:
+                    public_route_table.associate_with_subnet(SubnetId=subnet_id)
+
+                # make security group named "skylark"
+                tag_specifications = [
+                    {"Tags": [{"Key": "skylark", "Value": "true"}], "ResourceType": "security-group"},
+                ]
+                sg = ec2.create_security_group(
+                    GroupName="skylark",
+                    Description="Default security group for Skylark VPC",
+                    VpcId=matching_vpc.id,
+                    TagSpecifications=tag_specifications,
+                )
+            return matching_vpc

     def delete_vpc(self, region: str, vpcid: str):
         """Delete VPC, from https://gist.github.com/vernhart/c6a0fc94c0aeaebe84e5cd6f3dede4ce"""
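The per-AZ subnets above derive their /24 CIDR from the AZ suffix letter, so every AZ in a region gets a disjoint block inside the VPC's 10.0.0.0/16. A worked example (assumes single-letter AZ suffixes, as in us-east-1a):

    for az in ["us-east-1a", "us-east-1b", "us-east-1c"]:
        subnet_cidr_id = ord(az[-1]) - ord("a")  # a -> 0, b -> 1, c -> 2
        print(az, f"10.0.{subnet_cidr_id}.0/24")
    # us-east-1a 10.0.0.0/24
    # us-east-1b 10.0.1.0/24
    # us-east-1c 10.0.2.0/24
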
@@ -212,18 +222,54 @@ def create_iam(self, iam_name: str = "skylark_gateway", attach_policy_arn: Optio
         if attach_policy_arn:
             iam.attach_role_policy(RoleName=iam_name, PolicyArn=attach_policy_arn)

-    def add_ip_to_security_group(self, aws_region: str):
-        """Add IP to security group. If security group ID is None, use group named skylark (create if not exists)."""
+    def authorize_client(self, aws_region: str, client_ip: str, port=22):
+        vpcs = self.get_vpcs(aws_region)
+        assert vpcs, f"No VPC found in {aws_region}"
+        vpc = vpcs[0]
+        sgs = [sg for sg in vpc.security_groups.all() if sg.group_name == "skylark"]
+        assert len(sgs) == 1, f"Found {len(sgs)} security groups named skylark, expected 1"
+        sg = sgs[0]
+
+        # check if an existing rule already covers this client IP and port
+        for rule in sg.ip_permissions:
+            if "FromPort" in rule and rule["FromPort"] <= port and "ToPort" in rule and rule["ToPort"] >= port:
+                for ipr in rule["IpRanges"]:
+                    if ipr["CidrIp"] == client_ip:
+                        logger.fs.debug(f"[AWS] Found existing rule for {client_ip}:{port} in {sg.group_name}, not adding again")
+                        return
+        logger.fs.debug(f"[AWS] Authorizing {client_ip}:{port} in {sg.group_name}")
+        sg.authorize_ingress(IpPermissions=[{"IpProtocol": "tcp", "FromPort": port, "ToPort": port, "IpRanges": [{"CidrIp": client_ip}]}])
+
+    def add_ip_to_security_group(self, aws_region: str, ip: Optional[str] = None):
+        """Add an IP to the skylark security group for the region. If ip is None, authorize all IPs."""
         with ILock(f"aws_add_ip_to_security_group_{aws_region}"):
-            self.make_vpc(aws_region)
             sg = self.get_security_group(aws_region)
             try:
+                logger.fs.debug(f"[AWS] Adding IP {ip} to security group {sg.group_name}")
                 sg.authorize_ingress(
-                    IpPermissions=[{"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": "0.0.0.0/0"}]}]
+                    IpPermissions=[
+                        {"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": f"{ip}/32" if ip else "0.0.0.0/0"}]}
+                    ]
                 )
             except botocore.exceptions.ClientError as e:
-                if not str(e).endswith("already exists"):
-                    raise e
+                if str(e).endswith("already exists"):
+                    logger.fs.warning(f"[AWS] Rule for IP {ip} already exists in security group {sg.group_name}")
+                else:
+                    logger.fs.error(f"[AWS] Error adding IP {ip} to security group {sg.group_name}: {e}")
+                    raise e
+
+    def remove_ip_from_security_group(self, aws_region: str, ip: str):
+        """Remove an IP from the skylark security group for the region."""
+        with ILock(f"aws_remove_ip_from_security_group_{aws_region}"):
+            # remove the instance IP from the security group
+            sg = self.get_security_group(aws_region)
+            try:
+                logger.fs.debug(f"[AWS] Removing IP {ip} from security group {sg.group_name}")
+                sg.revoke_ingress(
+                    IpPermissions=[{"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": f"{ip}/32"}]}]
+                )
+            except botocore.exceptions.ClientError as e:
+                if "NotFound" in str(e):
+                    logger.fs.warning(f"[AWS] Rule for IP {ip} not found in security group {sg.group_name}")
+                else:
+                    logger.fs.error(f"[AWS] Error removing IP {ip} from security group {sg.group_name}: {e}")
+                    raise e

     def ensure_keyfile_exists(self, aws_region, prefix=key_root / "aws"):
         ec2 = self.auth.get_boto3_resource("ec2", aws_region)
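authorize_client dedupes before calling authorize_ingress because EC2 rejects duplicate rules. A standalone restatement of the matching predicate — the rule dicts mirror boto3's SecurityGroup.ip_permissions shape, and the values are illustrative:

    def rule_covers(rules, client_ip: str, port: int) -> bool:
        """True if any existing ingress rule already covers (client_ip, port)."""
        for rule in rules:
            if "FromPort" in rule and "ToPort" in rule and rule["FromPort"] <= port <= rule["ToPort"]:
                if any(ipr["CidrIp"] == client_ip for ipr in rule.get("IpRanges", [])):
                    return True
        return False

    rules = [{"IpProtocol": "tcp", "FromPort": 22, "ToPort": 22, "IpRanges": [{"CidrIp": "0.0.0.0/0"}]}]
    assert rule_covers(rules, "0.0.0.0/0", 22)
    assert not rule_covers(rules, "1.2.3.4/32", 22)
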
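open_ssh_tunnel_impl returns an sshtunnel.SSHTunnelForwarder bound to an ephemeral local port, which is how the client now reaches gateway services without exposing their ports publicly. A hedged usage sketch — the IP, key path, and remote port are placeholders:

    import sshtunnel

    tunnel = sshtunnel.SSHTunnelForwarder(
        ("203.0.113.10", 22),                 # placeholder gateway public IP
        ssh_username="ec2-user",
        ssh_pkey="/path/to/keyfile.pem",      # placeholder key path
        local_bind_address=("127.0.0.1", 0),  # port 0: the OS picks a free local port
        remote_bind_address=("127.0.0.1", 8080),
    )
    tunnel.start()
    print(f"gateway API at http://127.0.0.1:{tunnel.local_bind_port}/api/v1/status")
    tunnel.stop()
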
diff --git a/skylark/compute/azure/azure_server.py b/skylark/compute/azure/azure_server.py
index 3dc053745..b8902551a 100644
--- a/skylark/compute/azure/azure_server.py
+++ b/skylark/compute/azure/azure_server.py
@@ -1,11 +1,13 @@
 from pathlib import Path

 import paramiko
+import sshtunnel

 from skylark import key_root
 from skylark.compute.azure.azure_auth import AzureAuthentication
 from skylark.compute.server import Server, ServerState
 from skylark.utils.cache import ignore_lru_cache
 from skylark.utils.utils import PathLike

 import azure.core.exceptions
@@ -177,11 +179,21 @@ def get_ssh_client_impl(self, uname="skylark", ssh_key_password="skylark"):
         )
         return ssh_client

-    def get_ssh_cmd(self, uname="skylark", ssh_key_password="skylark"):
-        return f"ssh -i {self.ssh_private_key} {uname}@{self.public_ip()}"
-
     def get_sftp_client(self, uname="skylark", ssh_key_password="skylark"):
         t = paramiko.Transport((self.public_ip(), 22))
         pkey = paramiko.RSAKey.from_private_key_file(str(self.ssh_private_key), password=ssh_key_password)
         t.connect(username=uname, pkey=pkey)
         return paramiko.SFTPClient.from_transport(t)
+
+    def open_ssh_tunnel_impl(self, remote_port, uname="skylark", ssh_key_password="skylark") -> sshtunnel.SSHTunnelForwarder:
+        return sshtunnel.SSHTunnelForwarder(
+            (self.public_ip(), 22),
+            ssh_username=uname,
+            ssh_pkey=str(self.ssh_private_key),
+            ssh_private_key_password=ssh_key_password,
+            local_bind_address=("127.0.0.1", 0),
+            remote_bind_address=("127.0.0.1", remote_port),
+        )
+
+    def get_ssh_cmd(self, uname="skylark", ssh_key_password="skylark"):
+        return f"ssh -i {self.ssh_private_key} {uname}@{self.public_ip()}"
diff --git a/skylark/compute/cloud_providers.py b/skylark/compute/cloud_providers.py
index ce2d74f1a..d082cd3db 100644
--- a/skylark/compute/cloud_providers.py
+++ b/skylark/compute/cloud_providers.py
@@ -6,6 +6,10 @@


 class CloudProvider:
+
+    logging_enabled = True  # for the Dozzle log viewer
+    log_viewer_port = 8888
+
     @property
     def name(self):
         raise NotImplementedError
diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py
index 75e9df3eb..6808044d8 100644
--- a/skylark/compute/gcp/gcp_cloud_provider.py
+++ b/skylark/compute/gcp/gcp_cloud_provider.py
@@ -1,8 +1,9 @@
 import os
 import time
 import uuid
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 import googleapiclient
 import paramiko
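The new CloudProvider class attributes record where the Dozzle log viewer listens on each gateway. A hypothetical consumer, to show how they compose with the tunnel-local port (log_viewer_url is not a function defined in this patch):

    from skylark.compute.cloud_providers import CloudProvider

    def log_viewer_url(local_tunnel_port: int, container_hash: str) -> str:
        # Dozzle serves per-container logs under /container/<hash> on
        # CloudProvider.log_viewer_port (8888), reached via the SSH tunnel.
        assert CloudProvider.logging_enabled
        return f"http://127.0.0.1:{local_tunnel_port}/container/{container_hash}"
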
+ return f"ssh -i {self.ssh_private_key} -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no {uname}@{self.public_ip()}" diff --git a/skylark/compute/server.py b/skylark/compute/server.py index ec9c75264..45a8a6c6f 100644 --- a/skylark/compute/server.py +++ b/skylark/compute/server.py @@ -1,4 +1,5 @@ import json +import logging import socket from contextlib import closing from enum import Enum, auto @@ -12,6 +13,8 @@ from skylark.utils.net import retry_requests from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for +import sshtunnel + class ServerState(Enum): PENDING = auto() @@ -69,6 +72,7 @@ def __init__(self, region_tag, log_dir=None): self.region_tag = region_tag # format provider:region self.command_log = [] self.init_log_files(log_dir) + self.ssh_tunnels: Dict[int, sshtunnel.SSHTunnelForwarder] = {} def __repr__(self): return f"Server({self.uuid()})" @@ -93,6 +97,9 @@ def get_sftp_client(self): def get_ssh_client_impl(self): raise NotImplementedError() + def open_ssh_tunnel_impl(self, remote_port) -> sshtunnel.SSHTunnelForwarder: + raise NotImplementedError() + def get_ssh_cmd(self) -> str: raise NotImplementedError() @@ -103,6 +110,13 @@ def ssh_client(self): self._ssh_client = self.get_ssh_client_impl() return self._ssh_client + def tunnel_port(self, remote_port: int) -> int: + """Returns a local port that tunnels to the remote port.""" + if remote_port not in self.ssh_tunnels: + self.ssh_tunnels[remote_port] = self.open_ssh_tunnel_impl(remote_port) + self.ssh_tunnels[remote_port].start() + return self.ssh_tunnels[remote_port].local_bind_port + @property def provider(self) -> str: """Format provider""" @@ -152,6 +166,10 @@ def is_up(): wait_for(is_up, timeout=timeout, interval=interval, desc=f"Waiting for {self.uuid()} to be ready") def close_server(self): + if hasattr(self, "_ssh_client"): + self._ssh_client.close() + for tunnel in self.ssh_tunnels.values(): + tunnel.stop() self.flush_command_log() def flush_command_log(self): @@ -176,15 +194,13 @@ def run_command(self, command): def download_file(self, remote_path, local_path): """Download a file from the server""" - sftp_client = self.get_sftp_client() - sftp_client.get(remote_path, local_path) - sftp_client.close() + self.get_sftp_client().get(remote_path, local_path) + self.get_sftp_client().close() def upload_file(self, local_path, remote_path): """Upload a file to the server""" - sftp_client = self.get_sftp_client() - sftp_client.put(local_path, remote_path) - sftp_client.close() + self.get_sftp_client().put(local_path, remote_path) + self.get_sftp_client().close() def copy_public_key(self, pub_key_path: PathLike): """Append public key to authorized_keys file on server.""" @@ -255,29 +271,32 @@ def check_stderr(tup): logger.fs.debug(desc_prefix + f": Gateway started {start_out.strip()}") assert not start_err.strip(), f"Error starting gateway: {start_err.strip()}" - # load URLs gateway_container_hash = start_out.strip().split("\n")[-1][:12] - self.gateway_api_url = f"http://{self.public_ip()}:8080/api/v1" - self.gateway_log_viewer_url = f"http://{self.public_ip()}:8888/container/{gateway_container_hash}" + self.gateway_log_viewer_url = f"http://127.0.0.1:{self.tunnel_port(8888)}/container/{gateway_container_hash}" # wait for gateways to start (check status API) - def is_ready(): - api_url = f"http://{self.public_ip()}:8080/api/v1/status" + def is_api_ready(): try: + api_url = f"http://127.0.0.1:{self.tunnel_port(8080)}/api/v1/status" status_val = retry_requests().get(api_url) is_up = 
diff --git a/skylark/compute/server.py b/skylark/compute/server.py
index ec9c75264..45a8a6c6f 100644
--- a/skylark/compute/server.py
+++ b/skylark/compute/server.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import socket
 from contextlib import closing
 from enum import Enum, auto
@@ -12,6 +13,8 @@
 from skylark.utils.net import retry_requests
 from skylark.utils.utils import PathLike, Timer, retry_backoff, wait_for

+import sshtunnel
+

 class ServerState(Enum):
     PENDING = auto()
@@ -69,6 +72,7 @@ def __init__(self, region_tag, log_dir=None):
         self.region_tag = region_tag  # format provider:region
         self.command_log = []
         self.init_log_files(log_dir)
+        self.ssh_tunnels: Dict[int, sshtunnel.SSHTunnelForwarder] = {}

     def __repr__(self):
         return f"Server({self.uuid()})"
@@ -93,6 +97,9 @@ def get_sftp_client(self):
     def get_ssh_client_impl(self):
         raise NotImplementedError()

+    def open_ssh_tunnel_impl(self, remote_port) -> sshtunnel.SSHTunnelForwarder:
+        raise NotImplementedError()
+
     def get_ssh_cmd(self) -> str:
         raise NotImplementedError()

@@ -103,6 +110,13 @@ def ssh_client(self):
             self._ssh_client = self.get_ssh_client_impl()
         return self._ssh_client

+    def tunnel_port(self, remote_port: int) -> int:
+        """Returns a local port that tunnels to the remote port, opening and caching the tunnel on first use."""
+        if remote_port not in self.ssh_tunnels:
+            self.ssh_tunnels[remote_port] = self.open_ssh_tunnel_impl(remote_port)
+            self.ssh_tunnels[remote_port].start()
+        return self.ssh_tunnels[remote_port].local_bind_port
+
     @property
     def provider(self) -> str:
         """Format provider"""
@@ -152,6 +166,10 @@ def is_up():
         wait_for(is_up, timeout=timeout, interval=interval, desc=f"Waiting for {self.uuid()} to be ready")

     def close_server(self):
+        if hasattr(self, "_ssh_client"):
+            self._ssh_client.close()
+        for tunnel in self.ssh_tunnels.values():
+            tunnel.stop()
         self.flush_command_log()

     def flush_command_log(self):
@@ -176,15 +194,13 @@ def run_command(self, command):

     def download_file(self, remote_path, local_path):
         """Download a file from the server"""
-        sftp_client = self.get_sftp_client()
-        sftp_client.get(remote_path, local_path)
-        sftp_client.close()
+        with closing(self.get_sftp_client()) as sftp_client:
+            sftp_client.get(remote_path, local_path)

     def upload_file(self, local_path, remote_path):
         """Upload a file to the server"""
-        sftp_client = self.get_sftp_client()
-        sftp_client.put(local_path, remote_path)
-        sftp_client.close()
+        with closing(self.get_sftp_client()) as sftp_client:
+            sftp_client.put(local_path, remote_path)

     def copy_public_key(self, pub_key_path: PathLike):
         """Append public key to authorized_keys file on server."""
@@ -255,29 +271,32 @@ def check_stderr(tup):
             logger.fs.debug(desc_prefix + f": Gateway started {start_out.strip()}")
         assert not start_err.strip(), f"Error starting gateway: {start_err.strip()}"

-        # load URLs
         gateway_container_hash = start_out.strip().split("\n")[-1][:12]
-        self.gateway_api_url = f"http://{self.public_ip()}:8080/api/v1"
-        self.gateway_log_viewer_url = f"http://{self.public_ip()}:8888/container/{gateway_container_hash}"
+        self.gateway_log_viewer_url = f"http://127.0.0.1:{self.tunnel_port(8888)}/container/{gateway_container_hash}"

         # wait for gateways to start (check status API)
-        def is_ready():
-            api_url = f"http://{self.public_ip()}:8080/api/v1/status"
+        def is_api_ready():
             try:
+                api_url = f"http://127.0.0.1:{self.tunnel_port(8080)}/api/v1/status"
                 status_val = retry_requests().get(api_url)
                 is_up = status_val.json().get("status") == "ok"
                 return is_up
-            except Exception:
+            except Exception as e:
+                logger.error(f"{desc_prefix}: Failed to check gateway status: {e}")
                 return False

         try:
+            logging.disable(logging.CRITICAL)
+            wait_for(is_api_ready, timeout=5, interval=0.1, desc=f"Waiting for gateway {self.uuid()} to start", leave_pbar=False)
-            wait_for(is_ready, timeout=10, interval=0.1, desc=f"Waiting for gateway {self.uuid()} to start", leave_pbar=False)
         except TimeoutError as e:
-            logger.error(f"Gateway {self.instance_name()} is not ready {e}")
-            logger.warning(desc_prefix + " gateway launch command: " + docker_launch_cmd)
+            logger.fs.error(f"Gateway {self.instance_name()} is not ready {e}")
+            logger.fs.warning(desc_prefix + " gateway launch command: " + docker_launch_cmd)
             logs, err = self.run_command(f"sudo docker logs skylark_gateway --tail=100")
-            logger.error(f"Docker logs: {logs}\nerr: {err}")
+            logger.fs.error(f"Docker logs: {logs}\nerr: {err}")
             out, err = self.run_command(docker_launch_cmd.replace(" -d ", " "))
-            logger.error(f"Relaunching gateway in foreground\nout: {out}\nerr: {err}")
+            logger.fs.error(f"Relaunching gateway in foreground\nout: {out}\nerr: {err}")
+            logger.fs.exception(e)
             raise e
+        finally:
+            logging.disable(logging.NOTSET)
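tunnel_port lazily opens one forwarder per remote port and caches it on the Server, so repeated API calls share a tunnel. An illustrative call pattern (assumes server is a provisioned AWSServer, AzureServer, or GCPServer instance):

    local_port = server.tunnel_port(8080)          # first call opens and caches the tunnel
    api_url = f"http://127.0.0.1:{local_port}/api/v1/status"
    assert server.tunnel_port(8080) == local_port  # later calls reuse the cached forwarder
    server.close_server()                          # stops every cached tunnel
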
diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py
index 5839905d3..609513fca 100644
--- a/skylark/replicate/replicator_client.py
+++ b/skylark/replicate/replicator_client.py
@@ -79,7 +79,12 @@ def provision_gateways(
             jobs = []
             jobs.append(partial(self.aws.create_iam, attach_policy_arn="arn:aws:iam::aws:policy/AmazonS3FullAccess"))
+
+            def init_aws_vpc(region: str):
+                self.aws.make_vpc(region)
+                self.aws.authorize_client(region, "0.0.0.0/0")
+
             for r in set(aws_regions_to_provision):
-                jobs.append(partial(self.aws.add_ip_to_security_group, r.split(":")[1]))
+                jobs.append(partial(init_aws_vpc, r.split(":")[1]))
                 jobs.append(partial(self.aws.ensure_keyfile_exists, r.split(":")[1]))
         if azure_regions_to_provision:
             jobs.append(self.azure.create_ssh_key)
@@ -193,6 +198,14 @@ def provision_gateway_instance(region: str) -> Server:
             self.bound_nodes[node] = instance
             self.temp_nodes.remove(instance)

+        # firewall rules: allow the gateways to reach each other
+        # todo add firewall rules for Azure and GCP
+        public_ips = [self.bound_nodes[n].public_ip() for n in self.topology.nodes]
+        aws_jobs = [
+            partial(self.aws.add_ip_to_security_group, r.split(":")[1], ip) for r in set(aws_regions_to_provision) for ip in public_ips
+        ]
+        do_parallel(lambda fn: fn(), aws_jobs, spinner=True, desc="Applying firewall rules")
+
         # setup instances
         def setup(args: Tuple[Server, Dict[str, int]]):
             server, outgoing_ports = args
@@ -208,11 +221,20 @@ def setup(args: Tuple[Server, Dict[str, int]]):
         do_parallel(setup, args, n=-1, spinner=True, spinner_persist=True, desc="Install gateway package on instances")

     def deprovision_gateways(self):
+        # Tear down security group rules here as well, since deprovisioning is also invoked directly from the CLI.
         def deprovision_gateway_instance(server: Server):
             if server.instance_state() == ServerState.RUNNING:
                 server.terminate_instance()
                 logger.fs.warning(f"Deprovisioned {server.uuid()}")

+        # clear gateway IPs from the security groups
+        # todo remove firewall rules for Azure and GCP
+        public_ips = [i.public_ip() for i in self.bound_nodes.values()] + [i.public_ip() for i in self.temp_nodes]
+        aws_regions = [node.region for node in self.topology.nodes if node.region.startswith("aws:")]
+        aws_jobs = [partial(self.aws.remove_ip_from_security_group, r.split(":")[1], ip) for r in set(aws_regions) for ip in public_ips]
+        do_parallel(lambda fn: fn(), aws_jobs)
+
+        # terminate instances
         instances = self.bound_nodes.values()
         logger.fs.warning(f"Deprovisioning {len(instances)} instances")
         do_parallel(deprovision_gateway_instance, instances, n=-1, spinner=True, spinner_persist=True, desc="Deprovisioning instances")
@@ -384,8 +406,11 @@ def partition(items: List[Chunk], n_batches: int) -> List[List[Chunk]]:
         def send_chunk_requests(args: Tuple[Server, List[ChunkRequest]]):
             hop_instance, chunk_requests = args
             ip = gateway_ips[hop_instance]
-            logger.fs.debug(f"Sending {len(chunk_requests)} chunk requests to {ip}")
-            reply = retry_requests().post(f"http://{ip}:8080/api/v1/chunk_requests", json=[cr.as_dict() for cr in chunk_requests])
+            tunnel_port = hop_instance.tunnel_port(8080)
+            logger.fs.debug(f"Sending {len(chunk_requests)} chunk requests to {ip} (via 127.0.0.1:{tunnel_port})")
+            reply = retry_requests().post(
+                f"http://127.0.0.1:{tunnel_port}/api/v1/chunk_requests", json=[cr.as_dict() for cr in chunk_requests]
+            )
             if reply.status_code != 200:
                 raise Exception(f"Failed to send chunk requests to gateway instance {hop_instance.instance_name()}: {reply.text}")

@@ -398,7 +423,7 @@ def send_chunk_requests(args: Tuple[Server, List[ChunkRequest]]):
     def get_chunk_status_log_df(self) -> pd.DataFrame:
         def get_chunk_status(args):
             node, instance = args
-            reply = retry_requests().get(f"http://{instance.public_ip()}:8080/api/v1/chunk_status_log")
+            reply = retry_requests().get(f"http://127.0.0.1:{instance.tunnel_port(8080)}/api/v1/chunk_status_log")
             if reply.status_code != 200:
                 raise Exception(f"Failed to get chunk status from gateway instance {instance.instance_name()}: {reply.text}")
             logs = []
@@ -561,7 +586,7 @@ def copy_log(instance):

         def fn(s: Server):
             try:
-                retry_requests().post(f"http://{s.public_ip()}:8080/api/v1/shutdown")
+                retry_requests().post(f"http://127.0.0.1:{s.tunnel_port(8080)}/api/v1/shutdown")
             except:
                 return  # ignore connection errors since server may be shutting down
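One invariant worth calling out: the revoke_ingress on deprovision must mirror the authorize_ingress from provisioning exactly, since EC2 only removes a rule whose permission dict matches the one that was added. A sketch of the symmetry the two security-group helpers preserve (sg is assumed to be a boto3 ec2.SecurityGroup resource; the IP is a placeholder):

    ip = "198.51.100.7"  # placeholder gateway public IP
    rule = {"IpProtocol": "-1", "FromPort": -1, "ToPort": -1, "IpRanges": [{"CidrIp": f"{ip}/32"}]}
    sg.authorize_ingress(IpPermissions=[rule])  # add_ip_to_security_group (provision)
    sg.revoke_ingress(IpPermissions=[rule])     # remove_ip_from_security_group (deprovision)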