Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytest integration tests #874

Merged
15 commits merged on Jun 21, 2023
2 changes: 1 addition & 1 deletion skyplane/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def copy(self, src: str, dst: str, recursive: bool = False):

pipeline = self.pipeline()
pipeline.queue_copy(src, dst, recursive=recursive)
pipeline.start()
pipeline.start(progress=True)

def object_store(self):
return ObjectStore()
4 changes: 4 additions & 0 deletions skyplane/api/dataplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
self.topology = topology
self.provisioner = provisioner
self.transfer_config = transfer_config
# disable for azure
# TODO: remove this
self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
self.provisioning_lock = threading.Lock()
self.provisioned = False
Expand Down Expand Up @@ -235,6 +237,7 @@ def copy_log(instance):
instance.download_file("/tmp/gateway.stdout", out_file)
instance.download_file("/tmp/gateway.stderr", err_file)

print("COPY GATEWAY LOGS")
do_parallel(copy_log, self.bound_nodes.values(), n=-1)

def deprovision(self, max_jobs: int = 64, spinner: bool = False):
Expand Down Expand Up @@ -307,6 +310,7 @@ def run_async(self, jobs: List[TransferJob], hooks: Optional[TransferHook] = Non
"""
if not self.provisioned:
logger.error("Dataplane must be pre-provisioned. Call dataplane.provision() before starting a transfer")
print("discord", jobs)
tracker = TransferProgressTracker(self, jobs, self.transfer_config, hooks)
self.pending_transfers.append(tracker)
tracker.start()
Expand Down
206 changes: 110 additions & 96 deletions skyplane/api/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,130 +135,134 @@ def run(self):
# "src_spot_instance": getattr(self.transfer_config, f"{src_cloud_provider}_use_spot_instances"),
# "dst_spot_instance": getattr(self.transfer_config, f"{dst_cloud_provider}_use_spot_instances"),
}
# TODO: eventually jobs should be able to be concurrently dispatched and executed
# however this will require being able to handle conflicting multipart uploads ids

# initialize everything first
for job_uuid, job in self.jobs.items():
self.job_chunk_requests[job_uuid] = {}
self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

session_start_timestamp_ms = int(time.time() * 1000)
try:
for job_uuid, job in self.jobs.items():
# pre-dispatch chunks to begin pre-buffering chunks
chunk_streams = {
job_uuid: job.dispatch(self.dataplane, transfer_config=self.transfer_config) for job_uuid, job in self.jobs.items()
}
for job_uuid, job in self.jobs.items():
try:
chunk_stream = job.dispatch(self.dataplane, transfer_config=self.transfer_config)
logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
self.job_chunk_requests[job_uuid] = {}
self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

for chunk in chunk_streams[job_uuid]:
for chunk in chunk_stream:
chunks_dispatched = [chunk]
# TODO: check chunk ID
self.job_chunk_requests[job_uuid][chunk.chunk_id] = chunk
assert job_uuid in self.job_chunk_requests and chunk.chunk_id in self.job_chunk_requests[job_uuid]
self.hooks.on_chunk_dispatched(chunks_dispatched)
for region in self.dataplane.topology.dest_region_tags:
self.job_pending_chunk_ids[job_uuid][region].add(chunk.chunk_id)

logger.fs.debug(
f"[TransferProgressTracker] Job {job.uuid} dispatched with {len(self.job_chunk_requests[job_uuid])} chunk requests"
)
except Exception as e:
UsageClient.log_exception(
"dispatch job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0], # TODO: support multiple destinations
session_start_timestamp_ms,
)
raise e

except Exception as e:
UsageClient.log_exception(
"dispatch job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0], # TODO: support multiple destinations
session_start_timestamp_ms,
)
raise e

self.hooks.on_dispatch_end()
self.hooks.on_dispatch_end()

def monitor_single_dst_helper(dst_region):
start_time = time.time()
try:
self.monitor_transfer(dst_region)
except exceptions.SkyplaneGatewayException as err:
reformat_err = Exception(err.pretty_print_str()[37:])
UsageClient.log_exception(
"monitor transfer",
reformat_err,
args,
self.dataplane.topology.src_region_tag,
dst_region,
session_start_timestamp_ms,
)
raise err
except Exception as e:
UsageClient.log_exception(
"monitor transfer", e, args, self.dataplane.topology.src_region_tag, dst_region, session_start_timestamp_ms
)
raise e
end_time = time.time()

runtime_s = end_time - start_time
# transfer successfully completed
transfer_stats = {
"dst_region": dst_region,
"total_runtime_s": round(runtime_s, 4),
}

results = []
dest_regions = self.dataplane.topology.dest_region_tags
with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
e2e_start_time = time.time()
try:
future_list = [executor.submit(monitor_single_dst_helper, dest) for dest in dest_regions]
for future in as_completed(future_list):
results.append(future.result())
except Exception as e:
raise e
e2e_end_time = time.time()
transfer_stats = {
"total_runtime_s": e2e_end_time - e2e_start_time,
"throughput_gbits": self.query_bytes_dispatched() / (e2e_end_time - e2e_start_time) / GB * 8,
}
self.hooks.on_transfer_end()

def monitor_single_dst_helper(dst_region):
start_time = time.time()
start_time = int(time.time())
try:
self.monitor_transfer(dst_region)
except exceptions.SkyplaneGatewayException as err:
reformat_err = Exception(err.pretty_print_str()[37:])
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
job.finalize()
except Exception as e:
UsageClient.log_exception(
"monitor transfer",
reformat_err,
"finalize job",
e,
args,
self.dataplane.topology.src_region_tag,
dst_region,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise err
raise e
end_time = int(time.time())

# verify transfer
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
job.verify()
except Exception as e:
UsageClient.log_exception(
"monitor transfer", e, args, self.dataplane.topology.src_region_tag, dst_region, session_start_timestamp_ms
"verify job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise e
end_time = time.time()

runtime_s = end_time - start_time
# transfer successfully completed
transfer_stats = {
"dst_region": dst_region,
"total_runtime_s": round(runtime_s, 4),
}

results = []
dest_regions = self.dataplane.topology.dest_region_tags
with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
e2e_start_time = time.time()
try:
future_list = [executor.submit(monitor_single_dst_helper, dest) for dest in dest_regions]
for future in as_completed(future_list):
results.append(future.result())
except Exception as e:
raise e
e2e_end_time = time.time()
transfer_stats = {
"total_runtime_s": e2e_end_time - e2e_start_time,
"throughput_gbits": self.query_bytes_dispatched() / (e2e_end_time - e2e_start_time) / GB * 8,
}
self.hooks.on_transfer_end()

start_time = int(time.time())
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
job.finalize()
except Exception as e:
UsageClient.log_exception(
"finalize job",
e,
UsageClient.log_transfer(
transfer_stats,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
self.dataplane.topology.dest_region_tags,
session_start_timestamp_ms,
)
raise e
end_time = int(time.time())

# verify transfer
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
job.verify()
except Exception as e:
UsageClient.log_exception(
"verify job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise e

# transfer successfully completed
UsageClient.log_transfer(
transfer_stats,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags,
session_start_timestamp_ms,
)
print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])
print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])

@imports.inject("pandas")
def monitor_transfer(pd, self, region_tag):
Expand Down Expand Up @@ -299,13 +303,22 @@ def monitor_transfer(pd, self, region_tag):
# update job_complete_chunk_ids and job_pending_chunk_ids
# TODO: do chunk-tracking per-destination
for job_uuid, job in self.jobs.items():
job_complete_chunk_ids = set(chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid)
try:
job_complete_chunk_ids = set(
chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid
)
except Exception as e:
raise e
new_chunk_ids = (
self.job_complete_chunk_ids[job_uuid][region_tag]
.union(job_complete_chunk_ids)
.difference(self.job_complete_chunk_ids[job_uuid][region_tag])
)
completed_chunks = []
for id in new_chunk_ids:
assert (
job_uuid in self.job_chunk_requests and id in self.job_chunk_requests[job_uuid]
), f"Missing chunk id {id} for job {job_uuid}: {self.job_chunk_requests}"
for id in new_chunk_ids:
completed_chunks.append(self.job_chunk_requests[job_uuid][id])
self.hooks.on_chunk_completed(completed_chunks, region_tag)
Expand All @@ -319,7 +332,8 @@ def monitor_transfer(pd, self, region_tag):
time.sleep(0.05)

@property
@functools.lru_cache(maxsize=1)
# TODO: this is a very slow function, but we can't cache it since self.job_chunk_requests changes over time
# do not call it more often than necessary
def _chunk_to_job_map(self):
return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict.keys()}

Expand Down
39 changes: 25 additions & 14 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,11 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
multipart_exit_event = threading.Event()
multipart_chunk_threads = []

# TODO: remove after azure multipart implemented
azure_dest = any([dst_iface.provider == "azure" for dst_iface in self.dst_ifaces])

# start chunking threads
if self.transfer_config.multipart_enabled:
if not azure_dest and self.transfer_config.multipart_enabled:
for _ in range(self.concurrent_multipart_chunk_threads):
t = threading.Thread(
target=self._run_multipart_chunk_thread,
Expand All @@ -338,8 +341,13 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->

# begin chunking loop
for transfer_pair in transfer_pair_generator:
# print("transfer_pair", transfer_pair.src_obj.key, transfer_pair.dst_objs)
src_obj = transfer_pair.src_obj
if self.transfer_config.multipart_enabled and src_obj.size > self.transfer_config.multipart_threshold_mb * MB:
if (
not azure_dest
and self.transfer_config.multipart_enabled
and src_obj.size > self.transfer_config.multipart_threshold_mb * MB
):
multipart_send_queue.put(transfer_pair)
else:
if transfer_pair.src_obj.size == 0:
Expand Down Expand Up @@ -460,13 +468,16 @@ def __init__(
dst_paths: List[str] or str,
recursive: bool = False,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
job_id: Optional[str] = None,
):
self.src_path = src_path
self.dst_paths = dst_paths
self.recursive = recursive
self.requester_pays = requester_pays
self.uuid = uuid
if job_id is None:
self.uuid = str(uuid.uuid4())
else:
self.uuid = job_id

@property
def transfer_type(self) -> str:
Expand Down Expand Up @@ -559,9 +570,9 @@ def __init__(
dst_paths: List[str] or str,
recursive: bool = False,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
job_id: Optional[str] = None,
):
super().__init__(src_path, dst_paths, recursive, requester_pays, uuid)
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down Expand Up @@ -645,6 +656,9 @@ def dispatch(

# send chunk requests to source gateways
chunk_batch = [cr.chunk for cr in batch if cr.chunk is not None]
# TODO: allow multiple partition ids per chunk
for chunk in chunk_batch: # assign job UUID as partition ID
chunk.partition_id = self.uuid
min_idx = queue_size.index(min(queue_size))
n_added = 0
while n_added < len(chunk_batch):
Expand Down Expand Up @@ -701,6 +715,9 @@ def complete_fn(batch):

do_parallel(complete_fn, batches, n=8)

# TODO: Do NOT do this if we are pipelining multiple transfers - remove just what was completed
self.multipart_transfer_list = []

def verify(self):
"""Verify the integrity of the transfered destination objects"""

Expand Down Expand Up @@ -750,14 +767,8 @@ def size_gb(self):
class SyncJob(CopyJob):
"""sync job that copies the source objects that does not exist in the destination bucket to the destination"""

def __init__(
self,
src_path: str,
dst_paths: List[str] or str,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
):
super().__init__(src_path, dst_paths, True, requester_pays, uuid)
def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id)
self.transfer_list = []
self.multipart_transfer_list = []

Expand Down
Loading