Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit cbae100
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 23:09:06 2022 -0700

    Update merge

commit c60c535
Merge: 3731242 0321bff
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 23:07:21 2022 -0700

    Merge branch 'main' into skycamp/prefix_aws_keys

commit 3731242
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 23:03:08 2022 -0700

    Better rename setup

commit 10a9699
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 23:00:28 2022 -0700

    Tags is a function

commit a357d24
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 22:58:19 2022 -0700

    Update key prefix

commit 3e6b2f2
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 22:48:45 2022 -0700

    Tag VMs with a client ID when provisioning

commit d47ad76
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 21:29:07 2022 -0700

    Add option to deprovision tagged instances

commit ec096b9
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 22:46:21 2022 -0700

    Clean up leaked resources upon VM provisioning errors for AWS, GCP and Azure (#617)

commit 44eb38f
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 22:44:14 2022 -0700

    Add flag before deprovisioning networks (#615)

commit 961e64d
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 21:32:08 2022 -0700

    Tag VMs with a client ID when provisioning

commit d39af7e
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 21:29:07 2022 -0700

    Add option to deprovision tagged instances

commit fab0356
Author: Paras Jain <[email protected]>
Date:   Mon Oct 17 20:39:44 2022 -0700

    Prefix AWS keys with UUID
  • Loading branch information
parasj committed Oct 18, 2022
1 parent 0321bff commit 3be0485
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 37 deletions.
51 changes: 26 additions & 25 deletions skyplane/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,55 @@
"""CLI for the Skyplane object store"""
import os
import subprocess
import time
from functools import partial
import traceback
from pathlib import Path
from shlex import split
import traceback
from urllib import request
import uuid
import os

from rich import print as rprint

import skyplane.cli
import skyplane.cli.usage.definitions
import skyplane.cli.usage.client
from skyplane import GB
from skyplane.cli.usage.client import UsageClient, UsageStatsStatus
from skyplane.compute.azure.azure_auth import AzureAuthentication
from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
from skyplane.compute.gcp.gcp_auth import GCPAuthentication
from skyplane.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skyplane.replicate.replicator_client import ReplicatorClient, TransferStats
from typing import Optional

import typer
from rich import print as rprint
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import IntPrompt

import skyplane.cli
import skyplane.cli
import skyplane.cli.cli_aws
import skyplane.cli.cli_azure
import skyplane.cli.cli_config
import skyplane.cli.cli_internal as cli_internal
import skyplane.cli.experiments
from skyplane import cloud_config, config_path, exceptions, skyplane_root
from skyplane.cli.common import print_header, console, print_stats_completed
import skyplane.cli.usage.client
import skyplane.cli.usage.client
import skyplane.cli.usage.definitions
import skyplane.cli.usage.definitions
from skyplane import GB, cloud_config, config_path, exceptions, skyplane_root
from skyplane.cli.cli_impl.cp_replicate import (
confirm_transfer,
enrich_dest_objs,
generate_full_transferobjlist,
generate_topology,
confirm_transfer,
launch_replication_job,
)
from skyplane.cli.cli_impl.cp_replicate_fallback import (
get_usage_gbits,
replicate_onprem_cp_cmd,
replicate_onprem_sync_cmd,
replicate_small_cp_cmd,
replicate_small_sync_cmd,
get_usage_gbits,
)
from skyplane.replicate.replication_plan import ReplicationJob
from skyplane.cli.cli_impl.init import load_aws_config, load_azure_config, load_gcp_config
from skyplane.cli.common import parse_path, query_instances
from skyplane.cli.common import console, parse_path, print_header, print_stats_completed, query_instances
from skyplane.cli.usage.client import UsageClient, UsageStatsStatus
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
from skyplane.compute.azure.azure_auth import AzureAuthentication
from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
from skyplane.compute.gcp.gcp_auth import GCPAuthentication
from skyplane.compute.gcp.gcp_cloud_provider import GCPCloudProvider
from skyplane.config import SkyplaneConfig
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.replicate.replication_plan import ReplicationJob
from skyplane.replicate.replicator_client import ReplicatorClient, TransferStats
from skyplane.utils import logger
from skyplane.utils.fn import do_parallel

Expand Down Expand Up @@ -248,6 +244,7 @@ def cp(
multipart_chunk_size_mb=cloud_config.get_flag("multipart_chunk_size_mb"),
multipart_max_chunks=cloud_config.get_flag("multipart_max_chunks"),
error_reporting_args=args,
host_uuid=cloud_config.anon_clientid,
)
if cloud_config.get_flag("verify_checksums"):
provider_dst = topo.sink_region().split(":")[0]
Expand Down Expand Up @@ -444,6 +441,7 @@ def sync(
multipart_chunk_size_mb=cloud_config.get_flag("multipart_chunk_size_mb"),
multipart_max_chunks=cloud_config.get_flag("multipart_max_chunks"),
error_reporting_args=args,
host_uuid=cloud_config.anon_clientid,
)
if cloud_config.get_flag("verify_checksums"):
provider_dst = topo.sink_region().split(":")[0]
Expand Down Expand Up @@ -475,9 +473,12 @@ def sync(
@app.command()
def deprovision(
all: bool = typer.Option(False, "--all", "-a", help="Deprovision all resources including networks."),
filter_client_id: Optional[str] = typer.Option(None, help="Only deprovision instances with this client ID under the instance tag."),
):
"""Deprovision all resources created by skyplane."""
instances = query_instances()
if filter_client_id:
instances = [instance for instance in instances if instance.tags().get("skyplaneclientid") == filter_client_id]

if instances:
typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True)
Expand Down
2 changes: 2 additions & 0 deletions skyplane/cli/cli_impl/cp_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def launch_replication_job(
time_limit_seconds: Optional[int] = None,
log_interval_s: float = 1.0,
error_reporting_args: Optional[Dict] = None,
host_uuid: Optional[str] = None,
):
if "SKYPLANE_DOCKER_IMAGE" in os.environ:
rprint(f"[bright_black]Using overridden docker image: {gateway_docker_image}[/bright_black]")
Expand All @@ -297,6 +298,7 @@ def launch_replication_job(
azure_instance_class=azure_instance_class,
gcp_instance_class=gcp_instance_class,
gcp_use_premium_network=gcp_use_premium_network,
host_uuid=host_uuid,
)
typer.secho(f"Storing debug information for transfer in {rc.transfer_dir / 'client.log'}", fg="yellow", err=True)
(rc.transfer_dir / "topology.json").write_text(topo.to_json())
Expand Down
13 changes: 11 additions & 2 deletions skyplane/cli/cli_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ def replicate_random(
)
confirm_transfer(topo=topo, job=job, ask_to_confirm_transfer=False)
stats = launch_replication_job(
topo=topo, job=job, debug=debug, reuse_gateways=reuse_gateways, use_bbr=use_bbr, use_compression=False, use_e2ee=True
topo=topo,
job=job,
debug=debug,
reuse_gateways=reuse_gateways,
use_bbr=use_bbr,
use_compression=False,
use_e2ee=True,
host_uuid=None,
)
print(stats)
return 0 if stats.monitor_status == "completed" else 1
Expand Down Expand Up @@ -127,5 +134,7 @@ def replicate_random_solve(
random_chunk_size_mb=total_transfer_size_mb // n_chunks,
)
confirm_transfer(topo=topo, job=job, ask_to_confirm_transfer=False)
stats = launch_replication_job(topo=topo, job=job, debug=debug, reuse_gateways=reuse_gateways, use_bbr=use_bbr, use_compression=False)
stats = launch_replication_job(
topo=topo, job=job, debug=debug, reuse_gateways=reuse_gateways, use_bbr=use_bbr, use_compression=False, host_uuid=None
)
return 0 if stats.monitor_status == "completed" else 1
1 change: 1 addition & 0 deletions skyplane/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import time
import uuid
from multiprocessing import BoundedSemaphore
Expand Down
2 changes: 1 addition & 1 deletion skyplane/compute/aws/aws_key_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make_key(self, aws_region: str, key_name: str) -> Path:
local_key_file = self.local_key_dir / f"{key_name}.pem"
local_key_file.parent.mkdir(parents=True, exist_ok=True)
logger.fs.debug(f"[AWS] Creating keypair {key_name} in {aws_region}")
key_pair = ec2.create_key_pair(KeyName=f"skyplane-{aws_region}", KeyType="rsa")
key_pair = ec2.create_key_pair(KeyName=key_name, KeyType="rsa")
with local_key_file.open("w") as f:
key_str = key_pair.key_material
if not key_str.endswith("\n"):
Expand Down
19 changes: 16 additions & 3 deletions skyplane/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

from cryptography.utils import CryptographyDeprecationWarning

from skyplane.compute.aws.aws_key_manager import AWSKeyManager

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import paramiko

from skyplane import key_root, exceptions
from skyplane import exceptions, key_root
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.server import Server, ServerState
from skyplane.utils.cache import ignore_lru_cache
from skyplane.utils import imports
from skyplane.utils.cache import ignore_lru_cache


class AWSServer(Server):
Expand All @@ -22,9 +24,9 @@ def __init__(self, region_tag, instance_id, log_dir=None):
super().__init__(region_tag, log_dir=log_dir)
assert self.region_tag.split(":")[0] == "aws"
self.auth = AWSAuthentication()
self.key_manager = AWSKeyManager(self.auth)
self.aws_region = self.region_tag.split(":")[1]
self.instance_id = instance_id
self.local_keyfile = key_root / "aws" / f"skyplane-{self.aws_region}.pem"

@property
@imports.inject("boto3", pip_extra="aws")
Expand Down Expand Up @@ -71,6 +73,17 @@ def region(self):
def instance_state(self):
return ServerState.from_aws_state(self.get_boto3_instance_resource().state["Name"])

@property
@ignore_lru_cache()
def local_keyfile(self):
key_name = self.get_boto3_instance_resource().key_name
if self.key_manager.key_exists_local(key_name):
return self.key_manager.get_key(key_name)
else:
raise exceptions.BadConfigException(
f"Failed to connect to AWS server {self.uuid()}. Delete local AWS keys and retry: `rm -rf {key_root / 'aws'}`"
)

def __repr__(self):
return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})"

Expand Down
5 changes: 2 additions & 3 deletions skyplane/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ class SkyplaneConfig:
aws_enabled: bool
azure_enabled: bool
gcp_enabled: bool
anon_clientid: str
azure_principal_id: Optional[str] = None
azure_subscription_id: Optional[str] = None
azure_client_id: Optional[str] = None
gcp_project_id: Optional[str] = None
anon_clientid: Optional[str] = None

@staticmethod
def generate_machine_id() -> str:
Expand Down Expand Up @@ -186,8 +186,7 @@ def to_config_file(self, path):

if "client" not in config:
config.add_section("client")
if self.anon_clientid:
config.set("client", "anon_clientid", self.anon_clientid)
config.set("client", "anon_clientid", self.anon_clientid)

if "flags" not in config:
config.add_section("flags")
Expand Down
14 changes: 11 additions & 3 deletions skyplane/replicate/replicator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
azure_instance_class: Optional[str] = "Standard_D2_v5", # set to None to disable Azure
gcp_instance_class: Optional[str] = "n2-standard-16", # set to None to disable GCP
gcp_use_premium_network: bool = True,
host_uuid: Optional[str] = None,
):
self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
self.topology = topology
Expand All @@ -80,9 +81,10 @@ def __init__(
self.azure_instance_class = azure_instance_class
self.gcp_instance_class = gcp_instance_class
self.gcp_use_premium_network = gcp_use_premium_network
self.host_uuid = host_uuid

# provisioning
self.aws = AWSCloudProvider()
self.aws = AWSCloudProvider(key_prefix=f"skyplane-{host_uuid.replace('-', '') if host_uuid else ''}")
self.azure = AzureCloudProvider()
self.gcp = GCPCloudProvider()
self.bound_nodes: Dict[ReplicationTopologyGateway, Server] = {}
Expand Down Expand Up @@ -192,12 +194,17 @@ def provision_gateways(
# provision instances
def provision_gateway_instance(region: str) -> Server:
provider, subregion = region.split(":")
tags = {"skyplane": "true", "skyplaneclientid": self.host_uuid} if self.host_uuid else {"skyplane": "true"}
if provider == "aws":
assert self.aws.auth.enabled()
server = self.aws.provision_instance(subregion, self.aws_instance_class, use_spot_instances=aws_use_spot_instances)
server = self.aws.provision_instance(
subregion, self.aws_instance_class, use_spot_instances=aws_use_spot_instances, tags=tags
)
elif provider == "azure":
assert self.azure.auth.enabled()
server = self.azure.provision_instance(subregion, self.azure_instance_class, use_spot_instances=azure_use_spot_instances)
server = self.azure.provision_instance(
subregion, self.azure_instance_class, use_spot_instances=azure_use_spot_instances, tags=tags
)
elif provider == "gcp":
assert self.gcp.auth.enabled()
# todo specify network tier in ReplicationTopology
Expand All @@ -206,6 +213,7 @@ def provision_gateway_instance(region: str) -> Server:
self.gcp_instance_class,
use_spot_instances=gcp_use_spot_instances,
gcp_premium_network=self.gcp_use_premium_network,
tags=tags,
)
else:
raise NotImplementedError(f"Unknown provider {provider}")
Expand Down

0 comments on commit 3be0485

Please sign in to comment.