diff --git a/skyplane/cli/cli.py b/skyplane/cli/cli.py
index 361818c4a..700e20be6 100644
--- a/skyplane/cli/cli.py
+++ b/skyplane/cli/cli.py
@@ -23,7 +23,6 @@
 )
 from skyplane.replicate.replication_plan import ReplicationJob
 from skyplane.cli.cli_impl.init import load_aws_config, load_azure_config, load_gcp_config
-from skyplane.cli.cli_impl.ls import ls_local, ls_objstore
 from skyplane.cli.common import check_ulimit, parse_path, query_instances
 from skyplane.compute.aws.aws_auth import AWSAuthentication
 from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
@@ -117,8 +116,8 @@ def error_local():
     if provider_src in clouds and provider_dst in clouds:
         try:
             src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
-            src_region = src_client.region_tag()
             dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
+            src_region = src_client.region_tag()
             dst_region = dst_client.region_tag()
             transfer_pairs = generate_full_transferobjlist(
                 src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst, recursive=recursive
diff --git a/skyplane/cli/cli_aws.py b/skyplane/cli/cli_aws.py
index b7b62bc9d..e513280cd 100644
--- a/skyplane/cli/cli_aws.py
+++ b/skyplane/cli/cli_aws.py
@@ -41,9 +41,8 @@ def get_service_quota(region):
 @app.command()
 def cp_datasync(src_bucket: str, dst_bucket: str, path: str):
     aws_auth = AWSAuthentication()
-    s3_interface = S3Interface(None, aws_region="us-east-1")
-    src_region = s3_interface.infer_s3_region(src_bucket)
-    dst_region = s3_interface.infer_s3_region(dst_bucket)
+    src_region = S3Interface(src_bucket, aws_region="infer").aws_region
+    dst_region = S3Interface(dst_bucket, aws_region="infer").aws_region
     iam_client = aws_auth.get_boto3_client("iam", "us-east-1")
     try:
diff --git a/skyplane/cli/cli_impl/cp_replicate.py b/skyplane/cli/cli_impl/cp_replicate.py
index 7646453c3..72e784821 100644
--- a/skyplane/cli/cli_impl/cp_replicate.py
+++ b/skyplane/cli/cli_impl/cp_replicate.py
@@ -118,6 +118,13 @@ def generate_full_transferobjlist(
     """Query source region and destination region buckets and return list of objects to transfer."""
     source_iface = ObjectStoreInterface.create(source_region, source_bucket)
     dest_iface = ObjectStoreInterface.create(dest_region, dest_bucket)
+
+    # ensure buckets exist
+    if not source_iface.bucket_exists():
+        raise exceptions.MissingBucketException(f"Source bucket {source_bucket} does not exist")
+    if not dest_iface.bucket_exists():
+        raise exceptions.MissingBucketException(f"Destination bucket {dest_bucket} does not exist")
+
     source_objs, dest_objs = [], []

     # query all source region objects
diff --git a/skyplane/cli/cli_internal.py b/skyplane/cli/cli_internal.py
index 038c11548..b9462c30f 100644
--- a/skyplane/cli/cli_internal.py
+++ b/skyplane/cli/cli_internal.py
@@ -5,7 +5,7 @@
 import typer

 from skyplane.cli.common import print_header
-from skyplane import skyplane_root, MB
+from skyplane import skyplane_root
 from skyplane.cli.cli_impl.cp_replicate import confirm_transfer, launch_replication_job
 from skyplane.obj_store.object_store_interface import ObjectStoreObject
 from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
diff --git a/skyplane/cli/common.py b/skyplane/cli/common.py
index 21c85c3d5..d4173553c 100644
--- a/skyplane/cli/common.py
+++ b/skyplane/cli/common.py
@@ -1,10 +1,7 @@
-import os
 import re
-import resource
 import subprocess
 from functools import partial
 from pathlib import Path
-from sys import platform

 import typer
 from rich.console import Console
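With `bucket_exists` now on the interface, `generate_full_transferobjlist` fails fast on a missing bucket instead of erroring mid-transfer. A minimal sketch of the resulting caller-side behavior (bucket names and prefixes are hypothetical):

```python
from skyplane import exceptions
from skyplane.cli.cli_impl.cp_replicate import generate_full_transferobjlist

try:
    # bucket names and prefixes below are hypothetical
    transfer_pairs = generate_full_transferobjlist(
        "aws:us-east-1", "example-src-bucket", "/data",
        "gcp:us-central1-a", "example-dst-bucket", "/data",
        recursive=True,
    )
except exceptions.MissingBucketException as e:
    print(f"aborted before provisioning any gateways: {e}")
```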
diff --git a/skyplane/compute/gcp/gcp_auth.py b/skyplane/compute/gcp/gcp_auth.py
index de897e94a..70c6308d5 100644
--- a/skyplane/compute/gcp/gcp_auth.py
+++ b/skyplane/compute/gcp/gcp_auth.py
@@ -1,8 +1,6 @@
-from re import I
 from pathlib import Path
 from typing import Optional
 import base64
-import json
 import os

 import google.auth
@@ -12,7 +10,6 @@
 from skyplane import cloud_config, config_path, gcp_config_path, key_root
 from skyplane.config import SkyplaneConfig
 from skyplane.utils import logger
-from google.oauth2 import service_account


 class GCPAuthentication:
diff --git a/skyplane/compute/gcp/gcp_cloud_provider.py b/skyplane/compute/gcp/gcp_cloud_provider.py
index 7f7c7da8b..15899c6ec 100644
--- a/skyplane/compute/gcp/gcp_cloud_provider.py
+++ b/skyplane/compute/gcp/gcp_cloud_provider.py
@@ -2,7 +2,7 @@
 import time
 import uuid
 from pathlib import Path
-from typing import List, Optional
+from typing import List

 import googleapiclient
 import paramiko
diff --git a/skyplane/compute/server.py b/skyplane/compute/server.py
index e5b80faa5..9e3626fa6 100644
--- a/skyplane/compute/server.py
+++ b/skyplane/compute/server.py
@@ -345,7 +345,7 @@ def is_api_ready():
                 status_val = json.loads(http_pool.request("GET", api_url).data.decode("utf-8"))
                 is_up = status_val.get("status") == "ok"
                 return is_up
-            except Exception as e:
+            except Exception:
                 return False

         try:
diff --git a/skyplane/obj_store/azure_interface.py b/skyplane/obj_store/azure_interface.py
index c4d2972f3..e46c74e09 100644
--- a/skyplane/obj_store/azure_interface.py
+++ b/skyplane/obj_store/azure_interface.py
@@ -26,38 +26,13 @@ def full_path(self):


 class AzureInterface(ObjectStoreInterface):
-    def __init__(self, account_name, container_name, region="infer", create_bucket=False, max_concurrency=1):
+    def __init__(self, account_name: str, container_name: str, region: str = "infer", max_concurrency=1):
         self.auth = AzureAuthentication()
         self.account_name = account_name
         self.container_name = container_name
         self.account_url = f"https://{self.account_name}.blob.core.windows.net"
         self.max_concurrency = max_concurrency  # parallel upload/downloads, seems to cause issues if too high
-
-        # check container exists
-        if not self.storage_account_exists():
-            if create_bucket:
-                self.create_storage_account()
-                logger.info(f"Created Azure storage account {self.account_name}")
-            else:
-                # print available storage accounts from azure API
-                avail_storage_accounts = [account.name for account in self.storage_management_client.storage_accounts.list()]
-                token = self.auth.credential.get_token("https://management.azure.com/")
-                raise exceptions.MissingBucketException(
-                    f"Azure storage account {self.account_name} not found, found the following storage accounts: {avail_storage_accounts} with token {token}"
-                )
-        if not self.container_exists():
-            if create_bucket:
-                self.create_container()
-                logger.info(f"Created Azure container {self.container_name}")
-            else:
-                raise exceptions.MissingBucketException(f"Azure container {self.container_name} not found")
-
-        # infer region
-        if region == "infer":
-            self.storage_account = self.query_storage_account(self.account_name)
-            self.azure_region = self.storage_account.location
-        else:
-            self.azure_region = region
+        self.azure_region = self.query_storage_account(self.account_name).location if region == "infer" else region

     @property
     def blob_service_client(self):
@@ -95,6 +70,9 @@ def container_exists(self):
         except ResourceNotFoundError:
             return False

+    def bucket_exists(self):
+        return self.storage_account_exists() and self.container_exists()
+
     def exists(self, obj_name):
         return self.blob_service_client.get_blob_client(container=self.container_name, blob=obj_name).exists()
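For Azure, `bucket_exists` is the conjunction of the storage-account and container checks, so callers can probe both with one call now that the constructor no longer raises eagerly. A usage sketch (account, container, and region names are hypothetical; passing an explicit region skips the `query_storage_account` lookup):

```python
from skyplane.obj_store.azure_interface import AzureInterface

# hypothetical storage account / container; explicit region avoids the infer path
iface = AzureInterface("examplestorageacct", "example-container", region="eastus")
if not iface.bucket_exists():  # true only if the account AND the container exist
    print("storage account or container missing; create them explicitly if desired")
```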
diff --git a/skyplane/obj_store/gcs_interface.py b/skyplane/obj_store/gcs_interface.py
index e3b816d71..f04737105 100644
--- a/skyplane/obj_store/gcs_interface.py
+++ b/skyplane/obj_store/gcs_interface.py
@@ -11,7 +11,6 @@
 from skyplane import exceptions
 from skyplane.compute.gcp.gcp_auth import GCPAuthentication
 from skyplane.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject
-from skyplane.utils import logger


 class GCSObject(ObjectStoreObject):
@@ -20,24 +19,12 @@ def full_path(self):


 class GCSInterface(ObjectStoreInterface):
-    def __init__(self, bucket_name, gcp_region="infer", create_bucket=False):
+    def __init__(self, bucket_name: str, gcp_region: str = "infer"):
         self.bucket_name = bucket_name
         self.auth = GCPAuthentication()
-        # self.auth.set_service_account_credentials("skyplane1")  # use service account credentials
         self._gcs_client = self.auth.get_storage_client()
         self._requests_session = requests.Session()
-        try:
-            self.gcp_region = self.infer_gcp_region(bucket_name) if gcp_region is None or gcp_region == "infer" else gcp_region
-            if not self.bucket_exists():
-                raise exceptions.MissingBucketException()
-        except exceptions.MissingBucketException:
-            if create_bucket:
-                assert gcp_region is not None and gcp_region != "infer", "Must specify AWS region when creating bucket"
-                self.gcp_region = gcp_region
-                self.create_bucket()
-                logger.info(f"Created GCS bucket {self.bucket_name} in region {self.gcp_region}")
-            else:
-                raise
+        self.gcp_region = self.infer_gcp_region(bucket_name) if gcp_region == "infer" else gcp_region

     def region_tag(self):
         return "gcp:" + self.gcp_region
@@ -69,6 +56,13 @@ def bucket_exists(self):
         except Exception:
             return False

+    def exists(self, obj_name):
+        try:
+            self.get_obj_metadata(obj_name)
+            return True
+        except NoSuchObjectException:
+            return False
+
     def create_bucket(self, premium_tier=True):
         if not self.bucket_exists():
             bucket = self._gcs_client.bucket(self.bucket_name)
@@ -105,13 +99,6 @@ def get_obj_size(self, obj_name):
     def get_obj_last_modified(self, obj_name):
         return self.get_obj_metadata(obj_name).updated

-    def exists(self, obj_name):
-        try:
-            self.get_obj_metadata(obj_name)
-            return True
-        except NoSuchObjectException:
-            return False
-
     def send_xml_request(
         self,
         blob_name: str,
@@ -246,7 +233,7 @@ def complete_multipart_upload(self, dst_object_name, upload_id):
             response = self.send_xml_request(
                 dst_object_name, {"uploadId": upload_id}, "POST", data=xml_data, content_type="application/xml"
             )
-        except Exception as e:
+        except Exception:
             # cancel upload
             response = self.send_xml_request(dst_object_name, {"uploadId": upload_id}, "DELETE")
             return False
diff --git a/skyplane/obj_store/object_store_interface.py b/skyplane/obj_store/object_store_interface.py
index 3a38a6077..6408c454d 100644
--- a/skyplane/obj_store/object_store_interface.py
+++ b/skyplane/obj_store/object_store_interface.py
@@ -36,6 +36,9 @@ def delete_bucket(self):
     def list_objects(self, prefix=""):
         raise NotImplementedError()

+    def bucket_exists(self):
+        raise NotImplementedError()
+
     def exists(self):
         raise NotImplementedError()

@@ -85,23 +88,23 @@ def complete_multipart_upload(self, dst_object_name, upload_id):
         return ValueError("Multipart uploads not supported")

     @staticmethod
-    def create(region_tag: str, bucket: str, create_bucket: bool = False):
+    def create(region_tag: str, bucket: str):
         if region_tag.startswith("aws"):
             from skyplane.obj_store.s3_interface import S3Interface

             _, region = region_tag.split(":", 1)
-            return S3Interface(bucket, aws_region=region, create_bucket=create_bucket)
+            return S3Interface(bucket, aws_region=region)
         elif region_tag.startswith("gcp"):
             from skyplane.obj_store.gcs_interface import GCSInterface

             _, region = region_tag.split(":", 1)
-            return GCSInterface(bucket, gcp_region=region, create_bucket=create_bucket)
+            return GCSInterface(bucket, gcp_region=region)
         elif region_tag.startswith("azure"):
             from skyplane.obj_store.azure_interface import AzureInterface

             storage_account, container = bucket.split("/", 1)  # <storage_account>/<container>
             _, region = region_tag.split(":", 1)
-            return AzureInterface(storage_account, container, region=region, create_bucket=create_bucket)
+            return AzureInterface(storage_account, container, region=region)
         else:
             raise ValueError(f"Invalid region_tag {region_tag} - could not create interface")
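Dropping `create_bucket` from the factory turns bucket creation into an explicit two-step pattern, which is exactly what the updated test helper at the end of this diff does. A hedged sketch (region tag and bucket name are hypothetical):

```python
from skyplane.obj_store.object_store_interface import ObjectStoreInterface

iface = ObjectStoreInterface.create("gcp:us-central1-a", "example-bucket")  # hypothetical bucket
if not iface.bucket_exists():
    iface.create_bucket()  # creation is now an explicit call, never a constructor side effect
```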
diff --git a/skyplane/obj_store/s3_interface.py b/skyplane/obj_store/s3_interface.py
index fc845a486..224b80ab6 100644
--- a/skyplane/obj_store/s3_interface.py
+++ b/skyplane/obj_store/s3_interface.py
@@ -18,24 +18,10 @@ def full_path(self):


 class S3Interface(ObjectStoreInterface):
-    def __init__(self, bucket_name, aws_region="infer", create_bucket=False):
+    def __init__(self, bucket_name: str, aws_region: str = "infer"):
         self.auth = AWSAuthentication()
         self.bucket_name = bucket_name
-        try:
-            if bucket_name is not None:
-                self.aws_region = self.infer_s3_region(bucket_name) if aws_region is None or aws_region == "infer" else aws_region
-                if not self.bucket_exists():
-                    raise exceptions.MissingBucketException(f"Bucket {bucket_name} does not exist")
-            else:
-                self.aws_region = None
-        except exceptions.MissingBucketException:
-            if create_bucket:
-                assert aws_region is not None and aws_region != "infer", "Must specify AWS region when creating bucket"
-                self.aws_region = aws_region
-                self.create_bucket()
-                logger.info(f"Created S3 bucket {self.bucket_name} in region {self.aws_region}")
-            else:
-                raise
+        self.aws_region = self.infer_s3_region(bucket_name) if aws_region == "infer" else aws_region

     def region_tag(self):
         return "aws:" + self.aws_region
diff --git a/skyplane/replicate/replicator_client.py b/skyplane/replicate/replicator_client.py
index 2a80e0d48..166a64284 100644
--- a/skyplane/replicate/replicator_client.py
+++ b/skyplane/replicate/replicator_client.py
@@ -269,8 +269,9 @@ def deprovision_gateway_instance(server: Server):
         public_ips = [i.public_ip() for i in self.bound_nodes.values()] + [i.public_ip() for i in self.temp_nodes]
         aws_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("aws:")]
         aws_jobs = [partial(self.aws.remove_ips_from_security_group, r.split(":")[1], public_ips) for r in set(aws_regions)]
-        do_parallel(lambda fn: fn(), aws_jobs)
-        gcp_jobs = self.gcp.remove_ips_from_firewall(public_ips)
+        gcp_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("gcp:")]
+        gcp_jobs = [partial(self.gcp.remove_ips_from_firewall, public_ips)] if gcp_regions else []
+        do_parallel(lambda fn: fn(), aws_jobs + gcp_jobs, desc="Removing firewall rules")

         # Terminate instances
         instances = list(self.bound_nodes.values()) + self.temp_nodes
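After the constructor cleanup, `S3Interface` resolves its region in a single expression and never creates buckets implicitly; the `cli_aws.py` hunk above relies on exactly this. A small sketch (bucket name is hypothetical; `"infer"` triggers an `infer_s3_region` lookup against the live bucket):

```python
from skyplane.obj_store.s3_interface import S3Interface

iface = S3Interface("example-bucket", aws_region="infer")  # hypothetical bucket
print(iface.region_tag())  # e.g. "aws:us-east-1"
```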
diff --git a/tests/interface_util.py b/tests/interface_util.py
index 99a439ca2..d87328036 100644
--- a/tests/interface_util.py
+++ b/tests/interface_util.py
@@ -12,7 +12,8 @@ def interface_test_framework(region, bucket, multipart: bool, test_delete_bucket: bool = False):
     logger.info("creating interfaces...")
-    interface = ObjectStoreInterface.create(region, bucket, create_bucket=True)
+    interface = ObjectStoreInterface.create(region, bucket)
+    interface.create_bucket()
     assert interface.bucket_exists()

     debug_time = lambda n, s, e: logger.info(f"{n} {s}MB in {round(e, 2)}s ({round(s / e, 2)}MB/s)")
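With creation split out of the factory, the test helper provisions its own bucket before asserting it exists. A hedged invocation sketch (region tag and bucket name are hypothetical):

```python
from tests.interface_util import interface_test_framework

# hypothetical bucket; the helper now calls create_bucket() itself before asserting existence
interface_test_framework("aws:us-east-1", "skyplane-test-bucket-example", multipart=False, test_delete_bucket=True)
```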