Skip to content

Commit

Permalink
Remove create_bucket option from ObjStoreInterface (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
parasj committed Jul 11, 2022
1 parent c260d7a commit 200c2f4
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 87 deletions.
3 changes: 1 addition & 2 deletions skyplane/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from skyplane.replicate.replication_plan import ReplicationJob
from skyplane.cli.cli_impl.init import load_aws_config, load_azure_config, load_gcp_config
from skyplane.cli.cli_impl.ls import ls_local, ls_objstore
from skyplane.cli.common import check_ulimit, parse_path, query_instances
from skyplane.compute.aws.aws_auth import AWSAuthentication
from skyplane.compute.aws.aws_cloud_provider import AWSCloudProvider
Expand Down Expand Up @@ -117,8 +116,8 @@ def error_local():
if provider_src in clouds and provider_dst in clouds:
try:
src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
src_region = src_client.region_tag()
dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
src_region = src_client.region_tag()
dst_region = dst_client.region_tag()
transfer_pairs = generate_full_transferobjlist(
src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst, recursive=recursive
Expand Down
5 changes: 2 additions & 3 deletions skyplane/cli/cli_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def get_service_quota(region):
@app.command()
def cp_datasync(src_bucket: str, dst_bucket: str, path: str):
aws_auth = AWSAuthentication()
s3_interface = S3Interface(None, aws_region="us-east-1")
src_region = s3_interface.infer_s3_region(src_bucket)
dst_region = s3_interface.infer_s3_region(dst_bucket)
src_region = S3Interface(src_bucket, aws_region="infer").aws_region
dst_region = S3Interface(dst_bucket, aws_region="infer").aws_region

iam_client = aws_auth.get_boto3_client("iam", "us-east-1")
try:
Expand Down
7 changes: 7 additions & 0 deletions skyplane/cli/cli_impl/cp_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def generate_full_transferobjlist(
"""Query source region and destination region buckets and return list of objects to transfer."""
source_iface = ObjectStoreInterface.create(source_region, source_bucket)
dest_iface = ObjectStoreInterface.create(dest_region, dest_bucket)

# ensure buckets exist
if not source_iface.bucket_exists():
raise exceptions.MissingBucketException(f"Source bucket {source_bucket} does not exist")
if not dest_iface.bucket_exists():
raise exceptions.MissingBucketException(f"Destination bucket {dest_bucket} does not exist")

source_objs, dest_objs = [], []

# query all source region objects
Expand Down
2 changes: 1 addition & 1 deletion skyplane/cli/cli_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typer

from skyplane.cli.common import print_header
from skyplane import skyplane_root, MB
from skyplane import skyplane_root
from skyplane.cli.cli_impl.cp_replicate import confirm_transfer, launch_replication_job
from skyplane.obj_store.object_store_interface import ObjectStoreObject
from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
Expand Down
3 changes: 0 additions & 3 deletions skyplane/cli/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import re
import resource
import subprocess
from functools import partial
from pathlib import Path
from sys import platform

import typer
from rich.console import Console
Expand Down
3 changes: 0 additions & 3 deletions skyplane/compute/gcp/gcp_auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from re import I
from pathlib import Path
from typing import Optional
import base64
import json
import os

import google.auth
Expand All @@ -12,7 +10,6 @@
from skyplane import cloud_config, config_path, gcp_config_path, key_root
from skyplane.config import SkyplaneConfig
from skyplane.utils import logger
from google.oauth2 import service_account


class GCPAuthentication:
Expand Down
2 changes: 1 addition & 1 deletion skyplane/compute/gcp/gcp_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import uuid
from pathlib import Path
from typing import List, Optional
from typing import List

import googleapiclient
import paramiko
Expand Down
2 changes: 1 addition & 1 deletion skyplane/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def is_api_ready():
status_val = json.loads(http_pool.request("GET", api_url).data.decode("utf-8"))
is_up = status_val.get("status") == "ok"
return is_up
except Exception as e:
except Exception:
return False

try:
Expand Down
32 changes: 5 additions & 27 deletions skyplane/obj_store/azure_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,13 @@ def full_path(self):


class AzureInterface(ObjectStoreInterface):
def __init__(self, account_name: str, container_name: str, region: str = "infer", max_concurrency=1):
    """Interface to one Azure blob container within a storage account.

    :param account_name: Azure storage account name.
    :param container_name: blob container name within the storage account.
    :param region: Azure region of the storage account; "infer" looks the
        location up from the storage account metadata.
    :param max_concurrency: parallel upload/downloads, seems to cause issues if too high.
    """
    self.auth = AzureAuthentication()
    self.account_name = account_name
    self.container_name = container_name
    self.account_url = f"https://{self.account_name}.blob.core.windows.net"
    self.max_concurrency = max_concurrency  # parallel upload/downloads, seems to cause issues if too high
    # bucket existence is no longer checked here; callers use bucket_exists()
    self.azure_region = self.query_storage_account(self.account_name).location if region == "infer" else region

@property
def blob_service_client(self):
Expand Down Expand Up @@ -95,6 +70,9 @@ def container_exists(self):
except ResourceNotFoundError:
return False

def bucket_exists(self):
    """Return True only when both the storage account and its container exist."""
    account_present = self.storage_account_exists()
    # short-circuit: don't query the container if the account is missing
    return account_present and self.container_exists()

def exists(self, obj_name):
    """Return True when a blob named `obj_name` exists in this container."""
    blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=obj_name)
    return blob_client.exists()

Expand Down
33 changes: 10 additions & 23 deletions skyplane/obj_store/gcs_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from skyplane import exceptions
from skyplane.compute.gcp.gcp_auth import GCPAuthentication
from skyplane.obj_store.object_store_interface import NoSuchObjectException, ObjectStoreInterface, ObjectStoreObject
from skyplane.utils import logger


class GCSObject(ObjectStoreObject):
Expand All @@ -20,24 +19,12 @@ def full_path(self):


class GCSInterface(ObjectStoreInterface):
def __init__(self, bucket_name: str, gcp_region: str = "infer"):
    """Interface to a GCS bucket.

    :param bucket_name: name of the GCS bucket.
    :param gcp_region: GCP region of the bucket; "infer" looks the region
        up via infer_gcp_region.
    """
    self.bucket_name = bucket_name
    self.auth = GCPAuthentication()
    # self.auth.set_service_account_credentials("skyplane1")  # use service account credentials
    self._gcs_client = self.auth.get_storage_client()
    self._requests_session = requests.Session()
    # bucket existence is no longer checked here; callers use bucket_exists()
    self.gcp_region = self.infer_gcp_region(bucket_name) if gcp_region == "infer" else gcp_region

def region_tag(self):
    """Return this bucket's region as a "gcp:<region>" tag."""
    return f"gcp:{self.gcp_region}"
Expand Down Expand Up @@ -69,6 +56,13 @@ def bucket_exists(self):
except Exception:
return False

def exists(self, obj_name):
    """Return True when an object named `obj_name` has metadata in the bucket."""
    try:
        self.get_obj_metadata(obj_name)
    except NoSuchObjectException:
        # metadata lookup signals a missing object via this exception
        return False
    return True

def create_bucket(self, premium_tier=True):
if not self.bucket_exists():
bucket = self._gcs_client.bucket(self.bucket_name)
Expand Down Expand Up @@ -105,13 +99,6 @@ def get_obj_size(self, obj_name):
def get_obj_last_modified(self, obj_name):
    """Return the `updated` timestamp from the object's metadata."""
    metadata = self.get_obj_metadata(obj_name)
    return metadata.updated

def exists(self, obj_name):
try:
self.get_obj_metadata(obj_name)
return True
except NoSuchObjectException:
return False

def send_xml_request(
self,
blob_name: str,
Expand Down Expand Up @@ -246,7 +233,7 @@ def complete_multipart_upload(self, dst_object_name, upload_id):
response = self.send_xml_request(
dst_object_name, {"uploadId": upload_id}, "POST", data=xml_data, content_type="application/xml"
)
except Exception as e:
except Exception:
# cancel upload
response = self.send_xml_request(dst_object_name, {"uploadId": upload_id}, "DELETE")
return False
Expand Down
11 changes: 7 additions & 4 deletions skyplane/obj_store/object_store_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def delete_bucket(self):
def list_objects(self, prefix=""):
    # Abstract hook: concrete interfaces (S3/GCS/Azure) override to
    # enumerate objects under `prefix`.
    raise NotImplementedError()

def bucket_exists(self):
    # Abstract hook: concrete interfaces override to report whether the
    # backing bucket/container exists.
    raise NotImplementedError()

def exists(self, obj_name):
    # Abstract hook: concrete interfaces override to report whether an
    # object exists. Signature aligned with the implementations, which
    # all define exists(self, obj_name).
    raise NotImplementedError()

Expand Down Expand Up @@ -85,23 +88,23 @@ def complete_multipart_upload(self, dst_object_name, upload_id):
return ValueError("Multipart uploads not supported")

@staticmethod
def create(region_tag: str, bucket: str):
    """Factory: build the object store interface for a provider region tag.

    :param region_tag: "<provider>:<region>" where provider is aws, gcp or azure.
    :param bucket: bucket name; for Azure, "<storage_account>/<container>".
    :raises ValueError: if the provider prefix is not recognized.
    """
    if region_tag.startswith("aws"):
        from skyplane.obj_store.s3_interface import S3Interface

        _, region = region_tag.split(":", 1)
        return S3Interface(bucket, aws_region=region)
    elif region_tag.startswith("gcp"):
        from skyplane.obj_store.gcs_interface import GCSInterface

        _, region = region_tag.split(":", 1)
        return GCSInterface(bucket, gcp_region=region)
    elif region_tag.startswith("azure"):
        from skyplane.obj_store.azure_interface import AzureInterface

        storage_account, container = bucket.split("/", 1)  # <storage_account>/<container>
        _, region = region_tag.split(":", 1)
        return AzureInterface(storage_account, container, region=region)
    else:
        raise ValueError(f"Invalid region_tag {region_tag} - could not create interface")

Expand Down
18 changes: 2 additions & 16 deletions skyplane/obj_store/s3_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,10 @@ def full_path(self):


class S3Interface(ObjectStoreInterface):
def __init__(self, bucket_name: str, aws_region: str = "infer"):
    """Interface to an S3 bucket.

    :param bucket_name: name of the S3 bucket.
    :param aws_region: AWS region of the bucket; "infer" looks the region
        up via infer_s3_region.
    """
    self.auth = AWSAuthentication()
    self.bucket_name = bucket_name
    # bucket existence is no longer checked here; callers use bucket_exists()
    self.aws_region = self.infer_s3_region(bucket_name) if aws_region == "infer" else aws_region

def region_tag(self):
    """Return this bucket's region as an "aws:<region>" tag."""
    return f"aws:{self.aws_region}"
Expand Down
5 changes: 3 additions & 2 deletions skyplane/replicate/replicator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,9 @@ def deprovision_gateway_instance(server: Server):
public_ips = [i.public_ip() for i in self.bound_nodes.values()] + [i.public_ip() for i in self.temp_nodes]
aws_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("aws:")]
aws_jobs = [partial(self.aws.remove_ips_from_security_group, r.split(":")[1], public_ips) for r in set(aws_regions)]
do_parallel(lambda fn: fn(), aws_jobs)
gcp_jobs = self.gcp.remove_ips_from_firewall(public_ips)
gcp_regions = [node.region for node in self.topology.gateway_nodes if node.region.startswith("gcp:")]
gcp_jobs = [self.gcp.remove_ips_from_firewall(public_ips)] if gcp_regions else []
do_parallel(lambda fn: fn(), aws_jobs + gcp_jobs, desc="Removing firewall rules")

# Terminate instances
instances = list(self.bound_nodes.values()) + self.temp_nodes
Expand Down
3 changes: 2 additions & 1 deletion tests/interface_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

def interface_test_framework(region, bucket, multipart: bool, test_delete_bucket: bool = False):
logger.info("creating interfaces...")
interface = ObjectStoreInterface.create(region, bucket, create_bucket=True)
interface = ObjectStoreInterface.create(region, bucket)
interface.create_bucket()
assert interface.bucket_exists()
debug_time = lambda n, s, e: logger.info(f"{n} {s}MB in {round(e, 2)}s ({round(s / e, 2)}MB/s)")

Expand Down

0 comments on commit 200c2f4

Please sign in to comment.