Skip to content

Commit

Permalink
added more fallback tests + using a file (#873)
Browse files Browse the repository at this point in the history
* Modified the tests so that they load from an actual quota file instead
of me defining a dictionary.
* Modified planner so that it can accept a file name for the quota
limits (default to the skyplane config quota files)
* Added more tests for error conditions (no quota file is provided +
quota file is provided but the requested region is not included in the
quota file)

---------

Co-authored-by: Sarah Wooders <[email protected]>
Co-authored-by: Asim Biswal <[email protected]>
  • Loading branch information
3 people authored Jun 20, 2023
1 parent 99db23c commit ed52e18
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 51 deletions.
76 changes: 47 additions & 29 deletions skyplane/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
GatewayWriteObjectStore,
GatewayReceive,
GatewaySend,
GatewayGenData,
GatewayWriteLocal,
)

from skyplane.api.transfer_job import TransferJob
Expand All @@ -30,21 +28,27 @@


class Planner:
def __init__(self, transfer_config: TransferConfig):
def __init__(self, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
self.transfer_config = transfer_config
self.config = SkyplaneConfig.load_config(config_path)
self.n_instances = self.config.get_flag("max_instances")

# Loading the quota information, add ibm cloud when it is supported
self.quota_limits = {}
if os.path.exists(aws_quota_path):
with aws_quota_path.open("r") as f:
self.quota_limits["aws"] = json.load(f)
if os.path.exists(azure_standardDv5_quota_path):
with azure_standardDv5_quota_path.open("r") as f:
self.quota_limits["azure"] = json.load(f)
if os.path.exists(gcp_quota_path):
with gcp_quota_path.open("r") as f:
self.quota_limits["gcp"] = json.load(f)
quota_limits = {}
if quota_limits_file is not None:
with open(quota_limits_file, "r") as f:
quota_limits = json.load(f)
else:
if os.path.exists(aws_quota_path):
with aws_quota_path.open("r") as f:
quota_limits["aws"] = json.load(f)
if os.path.exists(azure_standardDv5_quota_path):
with azure_standardDv5_quota_path.open("r") as f:
quota_limits["azure"] = json.load(f)
if os.path.exists(gcp_quota_path):
with gcp_quota_path.open("r") as f:
quota_limits["gcp"] = json.load(f)
self.quota_limits = quota_limits

# Loading the vcpu information - a dictionary of dictionaries
# {"cloud_provider": {"instance_name": vcpu_cost}}
Expand Down Expand Up @@ -83,10 +87,9 @@ def _get_quota_limits_for(self, cloud_provider: str, region: str, spot: bool = F
:param spot: whether to use spot specified by the user config (default: False)
:type spot: bool
"""
quota_limits = self.quota_limits[cloud_provider]
quota_limits = self.quota_limits.get(cloud_provider, None)
if not quota_limits:
# User needs to reinitialize to save the quota information
logger.warning(f"Please run `skyplane init --reinit-{cloud_provider}` to load the quota information")
return None
if cloud_provider == "gcp":
region_family = "-".join(region.split("-")[:2])
Expand Down Expand Up @@ -116,11 +119,16 @@ def _calculate_vm_types(self, region_tag: str) -> Optional[Tuple[str, int]]:
cloud_provider=cloud_provider, region=region, spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances")
)

config_vm_type = getattr(self.transfer_config, f"{cloud_provider}_instance_class")

# No quota limits (quota limits weren't initialized properly during skyplane init)
if quota_limit is None:
return None
logger.warning(
f"Quota limit file not found for {region_tag}. Try running `skyplane init --reinit-{cloud_provider}` to load the quota information"
)
# return default instance type and number of instances
return config_vm_type, self.n_instances

config_vm_type = getattr(self.transfer_config, f"{cloud_provider}_instance_class")
config_vcpus = self._vm_to_vcpus(cloud_provider, config_vm_type)
if config_vcpus <= quota_limit:
return config_vm_type, quota_limit // config_vcpus
Expand All @@ -144,9 +152,7 @@ def _calculate_vm_types(self, region_tag: str) -> Optional[Tuple[str, int]]:
)
return (vm_type, n_instances)

def _get_vm_type_and_instances(
self, src_region_tag: Optional[str] = None, dst_region_tags: Optional[List[str]] = None
) -> Tuple[Dict[str, str], int]:
def _get_vm_type_and_instances(self, src_region_tag: str, dst_region_tags: List[str]) -> Tuple[Dict[str, str], int]:
"""Dynamically calculates the vm type each region can use (both the source region and all destination regions)
based on their quota limits and calculates the number of vms to launch in all regions by conservatively
taking the minimum of all regions to stay consistent.
Expand All @@ -156,10 +162,16 @@ def _get_vm_type_and_instances(
:param dst_region_tags: a list of the destination region tags (default: None)
:type dst_region_tags: Optional[List[str]]
"""

# One of them has to be provided
assert src_region_tag is not None or dst_region_tags is not None, "There needs to be at least one source or destination"
src_tags = [src_region_tag] if src_region_tag is not None else []
dst_tags = dst_region_tags or []
# assert src_region_tag is not None or dst_region_tags is not None, "There needs to be at least one source or destination"
src_tags = [src_region_tag] # if src_region_tag is not None else []
dst_tags = dst_region_tags # or []

assert len(src_region_tag.split(":")) == 2, f"Source region tag {src_region_tag} must be in the form of `cloud_provider:region`"
assert (
len(dst_region_tags[0].split(":")) == 2
), f"Destination region tag {dst_region_tags} must be in the form of `cloud_provider:region`"

# do_parallel returns tuples of (region_tag, (vm_type, n_instances))
vm_info = do_parallel(self._calculate_vm_types, src_tags + dst_tags)
Expand All @@ -172,10 +184,10 @@ def _get_vm_type_and_instances(

class UnicastDirectPlanner(Planner):
# DO NOT USE THIS - broken for single-region transfers
def __init__(self, n_instances: int, n_connections: int, transfer_config: TransferConfig):
def __init__(self, n_instances: int, n_connections: int, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
super().__init__(transfer_config, quota_limits_file)
self.n_instances = n_instances
self.n_connections = n_connections
super().__init__(transfer_config)

def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
# make sure only single destination
Expand All @@ -184,6 +196,12 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:

src_region_tag = jobs[0].src_iface.region_tag()
dst_region_tag = jobs[0].dst_ifaces[0].region_tag()

assert len(src_region_tag.split(":")) == 2, f"Source region tag {src_region_tag} must be in the form of `cloud_provider:region`"
assert (
len(dst_region_tag.split(":")) == 2
), f"Destination region tag {dst_region_tag} must be in the form of `cloud_provider:region`"

# jobs must have same sources and destinations
for job in jobs[1:]:
assert job.src_iface.region_tag() == src_region_tag, "All jobs must have same source region"
Expand Down Expand Up @@ -241,10 +259,10 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:


class MulticastDirectPlanner(Planner):
def __init__(self, n_instances: int, n_connections: int, transfer_config: TransferConfig):
def __init__(self, n_instances: int, n_connections: int, transfer_config: TransferConfig, quota_limits_file: Optional[str] = None):
super().__init__(transfer_config, quota_limits_file)
self.n_instances = n_instances
self.n_connections = n_connections
super().__init__(transfer_config)

def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
src_region_tag = jobs[0].src_iface.region_tag()
Expand Down Expand Up @@ -358,7 +376,7 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
plan = TopologyPlan(src_region_tag=src_region_tag, dest_region_tags=dst_region_tags)

# Dynamically calculate n_instances based on quota limits
vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag=src_region_tag)
vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag, dst_region_tags)

# TODO: support on-sided transfers but not requiring VMs to be created in source/destination regions
for i in range(n_instances):
Expand Down Expand Up @@ -419,7 +437,7 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
plan = TopologyPlan(src_region_tag=src_region_tag, dest_region_tags=dst_region_tags)

# Dynamically calculate n_instances based on quota limits
vm_types, n_instances = self._get_vm_type_and_instances(dst_region_tags=dst_region_tags)
vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag, dst_region_tags)

# TODO: support on-sided transfers but not requiring VMs to be created in source/destination regions
for i in range(n_instances):
Expand Down
Loading

0 comments on commit ed52e18

Please sign in to comment.