[ux] catch top level exceptions with pretty printing #397 #414

Merged · 7 commits · Jun 16, 2022
38 changes: 24 additions & 14 deletions skyplane/cli/cli.py
@@ -3,6 +3,7 @@
 from functools import partial
 from pathlib import Path
 from shlex import split
+import traceback
 
 import questionary
 import typer
@@ -13,8 +14,8 @@
 import skyplane.cli.cli_internal as cli_internal
 import skyplane.cli.cli_solver
 import skyplane.cli.experiments
-from skyplane import GB, config_path, skyplane_root, cloud_config
-from skyplane.cli.common import print_header
+from skyplane import config_path, exceptions, skyplane_root, cloud_config
+from skyplane.cli.common import print_header, console
 from skyplane.cli.cli_impl.cp_local import (
     copy_azure_local,
     copy_gcs_local,
@@ -27,11 +28,10 @@
 from skyplane.cli.cli_impl.cp_replicate import (
     generate_full_transferobjlist,
     generate_topology,
-    generate_transfer_obj_list,
     confirm_transfer,
     launch_replication_job,
 )
-from skyplane.replicate.replication_plan import TransferObjectList, ReplicationJob
+from skyplane.replicate.replication_plan import ReplicationJob
 from skyplane.cli.cli_impl.init import load_aws_config, load_azure_config, load_gcp_config
 from skyplane.cli.cli_impl.ls import ls_local, ls_objstore
 from skyplane.cli.common import check_ulimit, parse_path, query_instances
@@ -166,11 +166,16 @@ def cp(
         account_name, container_name = bucket_dst
         copy_azure_local(account_name, container_name, path_src, Path(path_dst))
     elif provider_src in clouds and provider_dst in clouds:
-        src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
-        src_region = src_client.region_tag()
-        dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
-        dst_region = dst_client.region_tag()
-        transfer_pairs = generate_full_transferobjlist(src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst)
+        try:
+            src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
+            src_region = src_client.region_tag()
+            dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
+            dst_region = dst_client.region_tag()
+            transfer_pairs = generate_full_transferobjlist(src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst)
+        except exceptions.SkyplaneException as e:
+            console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
+            console.print(e.pretty_print_str())
+            raise typer.Exit(1)
         topo = generate_topology(
             src_region,
             dst_region,
@@ -278,11 +283,16 @@ def sync(
 
     clouds = {"s3": "aws:infer", "gs": "gcp:infer", "azure": "azure:infer"}
 
-    src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
-    src_region = src_client.region_tag()
-    dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
-    dst_region = dst_client.region_tag()
-    full_transfer_pairs = generate_full_transferobjlist(src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst)
+    try:
+        src_client = ObjectStoreInterface.create(clouds[provider_src], bucket_src)
+        src_region = src_client.region_tag()
+        dst_client = ObjectStoreInterface.create(clouds[provider_dst], bucket_dst)
+        dst_region = dst_client.region_tag()
+        full_transfer_pairs = generate_full_transferobjlist(src_region, bucket_src, path_src, dst_region, bucket_dst, path_dst)
+    except exceptions.SkyplaneException as e:
+        console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
+        console.print(e.pretty_print_str())
+        raise typer.Exit(1)
 
     # filter out any transfer pairs that are already in the destination
     transfer_pairs = []
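Both cp and sync now wrap their object-store setup in the same guard: the full traceback is printed dimmed for debugging, the exception's pretty_print_str() summary is printed in color, and the CLI exits with a nonzero status. A minimal sketch of the pattern, where setup_transfer is a hypothetical stand-in for the ObjectStoreInterface setup and generate_full_transferobjlist calls above:

import traceback

import typer

from skyplane import exceptions
from skyplane.cli.common import console


def setup_transfer():
    # hypothetical stand-in for the object-store setup wrapped in the diff above
    raise exceptions.MissingObjectException("no objects found under the given prefix")


def guarded_setup():
    try:
        return setup_transfer()
    except exceptions.SkyplaneException as e:
        # dim the full traceback first, then the user-facing summary
        console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
        console.print(e.pretty_print_str())
        raise typer.Exit(1)  # nonzero exit code so scripts and CI see the failure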
80 changes: 10 additions & 70 deletions skyplane/cli/cli_impl/cp_replicate.py
@@ -2,6 +2,7 @@
 import os
 import pathlib
 import signal
+import traceback
 from typing import List, Optional, Tuple
 
 import typer
@@ -10,7 +11,7 @@
 from skyplane import exceptions, GB, format_bytes, skyplane_root
 from skyplane.compute.cloud_providers import CloudProvider
 from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
-from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob, TransferObjectList
+from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
 from skyplane.replicate.replicator_client import ReplicatorClient
 from skyplane.utils import logger
 from skyplane.utils.timer import Timer
@@ -73,72 +74,6 @@ def generate_topology(
     return topo
 
 
-def generate_transfer_obj_list(
-    src_region: str,
-    dst_region: str,
-    source_bucket: str,
-    dest_bucket: str,
-    src_key_prefix: str = "",
-    dest_key_prefix: str = "",
-    cached_src_objs: Optional[List[ObjectStoreObject]] = None,
-) -> TransferObjectList:
-
-    if cached_src_objs:
-        src_objs = cached_src_objs
-    else:
-        source_iface = ObjectStoreInterface.create(src_region, source_bucket)
-        logger.fs.debug(f"Querying objects in {source_bucket}")
-        with console.status(f"Querying objects in {source_bucket}") as status:
-            src_objs = []
-            for obj in source_iface.list_objects(src_key_prefix):
-                src_objs.append(obj)
-                status.update(f"Querying objects in {source_bucket} (found {len(src_objs)} objects so far)")
-
-    if not src_objs:
-        logger.error("Specified object does not exist.")
-        raise exceptions.MissingObjectException()
-
-    # map objects to destination object paths
-    # todo isolate this logic and test independently
-    logger.fs.debug(f"Mapping objects to destination paths")
-    src_objs_job = []
-    dest_objs_job = []
-    # if only one object exists, replicate it
-    if len(src_objs) == 1 and src_objs[0].key == src_key_prefix:
-        src_objs_job.append(src_objs[0].key)
-        if dest_key_prefix.endswith("/"):
-            dest_objs_job.append(dest_key_prefix + src_objs[0].key.split("/")[-1])
-        else:
-            dest_objs_job.append(dest_key_prefix)
-    # multiple objects to replicate
-    else:
-        for src_obj in src_objs:
-            src_objs_job.append(src_obj.key)
-            # remove prefix from object key
-            src_path_no_prefix = src_obj.key[len(src_key_prefix) :] if src_obj.key.startswith(src_key_prefix) else src_obj.key
-            # remove single leading slash if present
-            src_path_no_prefix = src_path_no_prefix[1:] if src_path_no_prefix.startswith("/") else src_path_no_prefix
-            if len(dest_key_prefix) == 0:
-                dest_objs_job.append(src_path_no_prefix)
-            elif dest_key_prefix.endswith("/"):
-                dest_objs_job.append(dest_key_prefix + src_path_no_prefix)
-            else:
-                dest_objs_job.append(dest_key_prefix + "/" + src_path_no_prefix)
-
-    obj_sizes = {obj.key: obj.size for obj in src_objs}
-
-    dst_iface = ObjectStoreInterface.create(dst_region, dest_bucket)
-    logger.fs.debug(f"Querying objects in {dest_bucket}")
-    with console.status(f"Querying objects in {dest_bucket}") as status:
-        dst_objs = []
-        for obj in dst_iface.list_objects(dest_key_prefix):
-            if obj.key in dest_objs_job:
-                dst_objs.append(obj)
-            status.update(f"Querying objects in {dest_bucket} (found {len(dst_objs)} objects so far)")
-
-    return TransferObjectList(src_objs_job, dest_objs_job, obj_sizes, src_objs, dst_objs)
-
-
 def map_object_key_prefix(
     source_prefix: str,
     dest_prefix: str,
@@ -183,7 +118,7 @@ def generate_full_transferobjlist(
             status.update(f"Querying objects in {source_bucket} (found {len(source_objs)} objects so far)")
     if not source_objs:
         logger.error("Specified object does not exist.")
-        raise exceptions.MissingObjectException()
+        raise exceptions.MissingObjectException(f"No objects were found in the specified prefix {source_prefix} in {source_bucket}")
 
     # map objects to destination object paths
     for source_obj in source_objs:
@@ -296,9 +231,14 @@ def launch_replication_job(
             write_socket_profile=debug,
             copy_gateway_logs=debug,
         )
-    except KeyboardInterrupt:
+    except (KeyboardInterrupt, exceptions.SkyplaneException) as e:
+        if isinstance(e, KeyboardInterrupt):
+            rprint("\n[bold red]Transfer cancelled by user. Exiting.[/bold red]")
+        elif isinstance(e, exceptions.SkyplaneException):
+            console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
+            console.print(e.pretty_print_str())
         if not reuse_gateways:
-            logger.fs.warning("Deprovisioning gateways then exiting...")
+            logger.fs.warning("Deprovisioning gateways then exiting. Please wait...")
             # disable sigint to prevent repeated KeyboardInterrupts
             s = signal.signal(signal.SIGINT, signal.SIG_IGN)
             rc.deprovision_gateways()
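launch_replication_job now funnels both Ctrl-C and SkyplaneException through one handler, so gateways are deprovisioned before exit and SIGINT is masked during cleanup to keep a second Ctrl-C from interrupting it. A sketch of that shape, where run_transfer is a hypothetical callable standing in for the ReplicatorClient transfer call:

import signal
import traceback

from rich import print as rprint

from skyplane import exceptions
from skyplane.cli.common import console


def run_guarded(rc, run_transfer, reuse_gateways=False):
    # rc is a provisioned ReplicatorClient; run_transfer is illustrative
    try:
        return run_transfer()
    except (KeyboardInterrupt, exceptions.SkyplaneException) as e:
        if isinstance(e, KeyboardInterrupt):
            rprint("\n[bold red]Transfer cancelled by user. Exiting.[/bold red]")
        else:
            console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
            console.print(e.pretty_print_str())
        if not reuse_gateways:
            # mask SIGINT so repeated Ctrl-C cannot interrupt deprovisioning
            prev = signal.signal(signal.SIGINT, signal.SIG_IGN)
            rc.deprovision_gateways()
            signal.signal(signal.SIGINT, prev)
        raise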
77 changes: 43 additions & 34 deletions skyplane/compute/azure/azure_cloud_provider.py
@@ -7,8 +7,9 @@
 
 import paramiko
 from azure.mgmt.compute.models import ResourceIdentityType
+from azure.core.exceptions import HttpResponseError
 
-from skyplane import key_root
+from skyplane import exceptions, key_root
 from skyplane.compute.azure.azure_auth import AzureAuthentication
 from skyplane.compute.azure.azure_server import AzureServer
 from skyplane.compute.cloud_providers import CloudProvider
@@ -316,42 +317,50 @@ def provision_instance(self, location: str, vm_size: str, name: Optional[str] =
         # Create the VM
         with Timer("Creating Azure VM"):
             with self.provisioning_semaphore:
-                poller = compute_client.virtual_machines.begin_create_or_update(
-                    resource_group,
-                    AzureServer.vm_name(name),
-                    {
-                        "location": location,
-                        "tags": {"skyplane": "true"},
-                        "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)},
-                        "storage_profile": {
-                            # "image_reference": {
-                            #     "publisher": "canonical",
-                            #     "offer": "0001-com-ubuntu-server-focal",
-                            #     "sku": "20_04-lts",
-                            #     "version": "latest",
-                            # },
-                            "image_reference": {
-                                "publisher": "microsoft-aks",
-                                "offer": "aks",
-                                "sku": "aks-engine-ubuntu-1804-202112",
-                                "version": "latest",
-                            },
-                            "os_disk": {"create_option": "FromImage", "delete_option": "Delete"},
-                        },
-                        "os_profile": {
-                            "computer_name": AzureServer.vm_name(name),
-                            "admin_username": uname,
-                            "linux_configuration": {
-                                "disable_password_authentication": True,
-                                "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]},
-                            },
-                        },
-                        "network_profile": {"network_interfaces": [{"id": nic_result.id}]},
-                        # give VM managed identity w/ system assigned identity
-                        "identity": {"type": ResourceIdentityType.system_assigned},
-                    },
-                )
-                vm_result = poller.result()
+                try:
+                    poller = compute_client.virtual_machines.begin_create_or_update(
+                        resource_group,
+                        AzureServer.vm_name(name),
+                        {
+                            "location": location,
+                            "tags": {"skyplane": "true"},
+                            "hardware_profile": {"vm_size": self.lookup_valid_instance(location, vm_size)},
+                            "storage_profile": {
+                                # "image_reference": {
+                                #     "publisher": "canonical",
+                                #     "offer": "0001-com-ubuntu-server-focal",
+                                #     "sku": "20_04-lts",
+                                #     "version": "latest",
+                                # },
+                                "image_reference": {
+                                    "publisher": "microsoft-aks",
+                                    "offer": "aks",
+                                    "sku": "aks-engine-ubuntu-1804-202112",
+                                    "version": "latest",
+                                },
+                                "os_disk": {"create_option": "FromImage", "delete_option": "Delete"},
+                            },
+                            "os_profile": {
+                                "computer_name": AzureServer.vm_name(name),
+                                "admin_username": uname,
+                                "linux_configuration": {
+                                    "disable_password_authentication": True,
+                                    "ssh": {"public_keys": [{"path": f"/home/{uname}/.ssh/authorized_keys", "key_data": pub_key}]},
+                                },
+                            },
+                            "network_profile": {"network_interfaces": [{"id": nic_result.id}]},
+                            # give VM managed identity w/ system assigned identity
+                            "identity": {"type": ResourceIdentityType.system_assigned},
+                        },
+                    )
+                    vm_result = poller.result()
+                except HttpResponseError as e:
+                    if "ResourceQuotaExceeded" in str(e):
+                        raise exceptions.InsufficientVCPUException(f"Got ResourceQuotaExceeded error in Azure region {location}") from e
+                    elif "QuotaExceeded" in str(e):
+                        raise exceptions.InsufficientVCPUException(f"Got QuotaExceeded error in Azure region {location}") from e
+                    else:
+                        raise
 
         server = AzureServer(name)
         server.wait_for_ssh_ready()
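Azure reports quota exhaustion through a generic HttpResponseError rather than a dedicated exception type, so the handler above matches on the error text. A condensed sketch of the same mapping as a standalone helper; note that since "ResourceQuotaExceeded" contains "QuotaExceeded", one substring test covers both branches:

from azure.core.exceptions import HttpResponseError

from skyplane import exceptions


def raise_if_azure_quota_error(e: HttpResponseError, location: str) -> None:
    # map Azure quota failures onto Skyplane's typed exception hierarchy
    if "QuotaExceeded" in str(e):  # also matches "ResourceQuotaExceeded"
        raise exceptions.InsufficientVCPUException(f"Azure quota exceeded in region {location}") from e
    raise e  # anything else propagates unchanged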
19 changes: 16 additions & 3 deletions skyplane/compute/gcp/gcp_cloud_provider.py
@@ -7,7 +7,7 @@
 import googleapiclient
 import paramiko
 
-from skyplane import key_root
+from skyplane import exceptions, key_root
 from skyplane.compute.azure.azure_cloud_provider import AzureCloudProvider
 from skyplane.compute.cloud_providers import CloudProvider
 from skyplane.compute.gcp.gcp_auth import GCPAuthentication
@@ -257,8 +257,21 @@ def provision_instance(
             "scheduling": {"onHostMaintenance": "TERMINATE", "automaticRestart": False},
             "deletionProtection": False,
         }
-        result = compute.instances().insert(project=self.auth.project_id, zone=region, body=req_body).execute()
-        self.wait_for_operation_to_complete(region, result["name"])
+        try:
+            result = compute.instances().insert(project=self.auth.project_id, zone=region, body=req_body).execute()
+            self.wait_for_operation_to_complete(region, result["name"])
+        except googleapiclient.errors.HttpError as e:
+            if e.resp.status == 409:
+                if "ZONE_RESOURCE_POOL_EXHAUSTED" in e.content:
+                    raise exceptions.InsufficientVCPUException(f"Got ZONE_RESOURCE_POOL_EXHAUSTED in region {region}") from e
+                elif "RESOURCE_EXHAUSTED" in e.content:
+                    raise exceptions.InsufficientVCPUException(f"Got RESOURCE_EXHAUSTED in region {region}") from e
+                elif "QUOTA_EXCEEDED" in e.content:
+                    raise exceptions.InsufficientVCPUException(f"Got QUOTA_EXCEEDED in region {region}") from e
+                elif "QUOTA_LIMIT" in e.content:
+                    raise exceptions.InsufficientVCPUException(f"Got QUOTA_LIMIT in region {region}") from e
+                else:
+                    raise e
 
         # wait for server to reach RUNNING state
         server = GCPServer(f"gcp:{region}", name)
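GCP surfaces capacity and quota failures as an HTTP 409 whose response body names the reason, so the handler dispatches on substrings of e.content. One caveat in the hunk above: a non-409 HttpError is not re-raised by the handler. The sketch below re-raises everything it does not translate, and assumes e.content is the raw bytes body that googleapiclient attaches to HttpError:

import googleapiclient.errors

from skyplane import exceptions

# 409 response bodies that indicate the zone or project is out of capacity/quota
EXHAUSTION_MARKERS = (b"ZONE_RESOURCE_POOL_EXHAUSTED", b"RESOURCE_EXHAUSTED", b"QUOTA_EXCEEDED", b"QUOTA_LIMIT")


def raise_for_gcp_error(e: googleapiclient.errors.HttpError, region: str) -> None:
    if e.resp.status == 409 and any(m in e.content for m in EXHAUSTION_MARKERS):
        raise exceptions.InsufficientVCPUException(f"Resource exhaustion or quota error in GCP region {region}") from e
    raise e  # propagate everything else, including non-409 errors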
42 changes: 27 additions & 15 deletions skyplane/exceptions.py
@@ -1,33 +1,45 @@
 class SkyplaneException(Exception):
-    pass
+    def pretty_print_str(self):
+        err = f"[bold][red]SkyplaneException: {str(self)}[/red][/bold]"
+        return err
 
 
 class MissingBucketException(SkyplaneException):
-    pass
+    def pretty_print_str(self):
+        err = f"[red][bold]:x: MissingBucketException:[/bold] {str(self)}[/red]"
+        err += "\n[bold][red]Please ensure that the bucket exists and is accessible.[/red][/bold]"
+        return err
 
 
 class MissingObjectException(SkyplaneException):
-    pass
+    def pretty_print_str(self):
+        err = f"[red][bold]:x: MissingObjectException:[/bold] {str(self)}[/red]"
+        err += "\n[bold][red]Please ensure that the object exists and is accessible.[/red][/bold]"
+        return err
 
 
-class InsufficientVCPUException(SkyplaneException):
-    pass
+class ChecksumMismatchException(SkyplaneException):
+    def pretty_print_str(self):
+        err = f"[red][bold]:x: ChecksumMismatchException:[/bold] {str(self)}[/red]"
+        err += "\n[bold][red]Please re-run the transfer due to checksum mismatch at the destination object store.[/red][/bold]"
+        return err
 
 
-class ObjectStoreException(SkyplaneException):
-    pass
+class InsufficientVCPUException(SkyplaneException):
+    def pretty_print_str(self):
+        err = f"[red][bold]:x: InsufficientVCPUException:[/bold] {str(self)}[/red]"
+        err += "\n[bold][red]Please ensure that you have enough vCPUs in the given region.[/red][/bold]"
+        # todo print link to a documentation page to request more vCPUs
+        return err
 
 
-class TransferFailedException(Exception):
+class TransferFailedException(SkyplaneException):
     def __init__(self, message, failed_objects=None):
         super().__init__(message)
         self.failed_objects = failed_objects
 
-    def __str__(self):
+    def pretty_print_str(self):
+        err = f"[red][bold]:x: TransferFailedException:[/bold] {str(self)}[/red]"
         if self.failed_objects and len(self.failed_objects) > 0:
-            failed_obj_str = (
-                str(self.failed_objects)
-                if len(self.failed_objects) <= 16
-                else str(self.failed_objects[:16]) + f" and {len(self.failed_objects) - 16} more"
-            )
-            return super().__str__() + "\nFailed objects: " + failed_obj_str
+            err += "\n[bold][red]Failed objects:[/red][/bold] " + str(self.failed_objects)
+        return err
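Each exception now renders its own rich-markup summary, so any caller holding a rich Console produces consistent output without knowing the concrete subclass. A usage sketch (the bucket name and message are illustrative):

from rich.console import Console

from skyplane import exceptions

console = Console()

try:
    raise exceptions.MissingBucketException("bucket 'demo-bucket' does not exist")
except exceptions.SkyplaneException as e:
    # polymorphic: each subclass formats its own name, message, and hint line
    console.print(e.pretty_print_str())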
2 changes: 1 addition & 1 deletion skyplane/obj_store/azure_interface.py
@@ -269,7 +269,7 @@ def upload_object(self, src_file_path, dst_object_name, part_number=None, upload
         b64_md5sum = base64.b64encode(check_md5).decode("utf-8") if check_md5 else None
         blob_md5 = blob_client.get_blob_properties().properties.content_settings.content_md5
         if b64_md5sum != blob_md5:
-            raise exceptions.ObjectStoreException(
+            raise exceptions.ChecksumMismatchException(
                 f"Checksum mismatch for object {dst_object_name} in bucket {self.bucket_name}, "
                 + f"expected {b64_md5sum}, got {blob_md5}"
             )
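With the typed hierarchy in place, the MD5 guard in upload_object raises ChecksumMismatchException instead of the removed generic ObjectStoreException, so the CLI can pretty-print it like any other SkyplaneException. A self-contained sketch of the comparison; verify_blob_md5 and its arguments are illustrative, while the real code checks the content_md5 that Azure records against a digest computed during upload:

import base64
import hashlib

from skyplane import exceptions


def verify_blob_md5(data: bytes, blob_md5_b64: str, object_name: str, bucket: str) -> None:
    # Azure stores content_md5 as a base64-encoded MD5 digest
    local_md5_b64 = base64.b64encode(hashlib.md5(data).digest()).decode("utf-8")
    if local_md5_b64 != blob_md5_b64:
        raise exceptions.ChecksumMismatchException(
            f"Checksum mismatch for object {object_name} in bucket {bucket}, "
            f"expected {local_md5_b64}, got {blob_md5_b64}"
        )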