Add pytest integration tests (#874)
sarahwooders authored Jun 21, 2023
1 parent ed52e18 commit d6c5430
Showing 16 changed files with 482 additions and 148 deletions.
2 changes: 1 addition & 1 deletion skyplane/api/client.py
@@ -100,7 +100,7 @@ def copy(self, src: str, dst: str, recursive: bool = False):

pipeline = self.pipeline()
pipeline.queue_copy(src, dst, recursive=recursive)
pipeline.start()
pipeline.start(progress=True)

def object_store(self):
return ObjectStore()
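The high-level copy() path above now starts the pipeline with progress reporting. Below is a minimal pytest-style integration test sketch in the spirit of this commit (illustrative, not part of the diff): the SkyplaneClient import path, the default-constructed client, the custom integration marker, and the bucket URLs are all assumptions/placeholders that a real test would replace.

import pytest
from skyplane.api.client import SkyplaneClient  # assumed import path; class name per skyplane's API

@pytest.mark.integration  # hypothetical marker; needs real cloud credentials and buckets
def test_copy_single_object():
    client = SkyplaneClient()  # assumption: default-constructible; real tests may pass credentials
    # copy() queues the (src, dst) pair on a pipeline and starts it with progress=True, as in the diff.
    client.copy("s3://my-src-bucket/data/file.bin", "s3://my-dst-bucket/data/file.bin", recursive=False)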
4 changes: 4 additions & 0 deletions skyplane/api/dataplane.py
@@ -64,6 +64,8 @@ def __init__(
self.topology = topology
self.provisioner = provisioner
self.transfer_config = transfer_config
# disable for azure
# TODO: remove this
self.http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=3))
self.provisioning_lock = threading.Lock()
self.provisioned = False
@@ -235,6 +237,7 @@ def copy_log(instance):
instance.download_file("/tmp/gateway.stdout", out_file)
instance.download_file("/tmp/gateway.stderr", err_file)

print("COPY GATEWAY LOGS")
do_parallel(copy_log, self.bound_nodes.values(), n=-1)

def deprovision(self, max_jobs: int = 64, spinner: bool = False):
@@ -307,6 +310,7 @@ def run_async(self, jobs: List[TransferJob], hooks: Optional[TransferHook] = Non
"""
if not self.provisioned:
logger.error("Dataplane must be pre-provisioned. Call dataplane.provision() before starting a transfer")
print("discord", jobs)
tracker = TransferProgressTracker(self, jobs, self.transfer_config, hooks)
self.pending_transfers.append(tracker)
tracker.start()
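run_async() above requires a provisioned dataplane and returns a TransferProgressTracker that it starts immediately. A hedged sketch of that lower-level flow (illustrative, not part of the commit): dataplane.provision() is taken from the error message in the diff, deprovision(spinner=...) from the signature shown above, and waiting via tracker.join() assumes the tracker is a threading.Thread subclass.

def run_transfer(dataplane, jobs):
    # Provision gateways first; run_async logs an error if this step is skipped.
    dataplane.provision()
    try:
        tracker = dataplane.run_async(jobs)   # returns a started TransferProgressTracker
        tracker.join()                        # assumption: tracker is thread-like, join() waits for completion
    finally:
        dataplane.deprovision(spinner=True)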
206 changes: 110 additions & 96 deletions skyplane/api/tracker.py
@@ -135,130 +135,134 @@ def run(self):
# "src_spot_instance": getattr(self.transfer_config, f"{src_cloud_provider}_use_spot_instances"),
# "dst_spot_instance": getattr(self.transfer_config, f"{dst_cloud_provider}_use_spot_instances"),
}
# TODO: eventually jobs should be able to be concurrently dispatched and executed
# however this will require being able to handle conflicting multipart uploads ids

# initialize everything first
for job_uuid, job in self.jobs.items():
self.job_chunk_requests[job_uuid] = {}
self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

session_start_timestamp_ms = int(time.time() * 1000)
try:
for job_uuid, job in self.jobs.items():
# pre-dispatch chunks to begin pre-buffering chunks
chunk_streams = {
job_uuid: job.dispatch(self.dataplane, transfer_config=self.transfer_config) for job_uuid, job in self.jobs.items()
}
for job_uuid, job in self.jobs.items():
try:
chunk_stream = job.dispatch(self.dataplane, transfer_config=self.transfer_config)
logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
self.job_chunk_requests[job_uuid] = {}
self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

for chunk in chunk_streams[job_uuid]:
for chunk in chunk_stream:
chunks_dispatched = [chunk]
# TODO: check chunk ID
self.job_chunk_requests[job_uuid][chunk.chunk_id] = chunk
assert job_uuid in self.job_chunk_requests and chunk.chunk_id in self.job_chunk_requests[job_uuid]
self.hooks.on_chunk_dispatched(chunks_dispatched)
for region in self.dataplane.topology.dest_region_tags:
self.job_pending_chunk_ids[job_uuid][region].add(chunk.chunk_id)

logger.fs.debug(
f"[TransferProgressTracker] Job {job.uuid} dispatched with {len(self.job_chunk_requests[job_uuid])} chunk requests"
)
except Exception as e:
UsageClient.log_exception(
"dispatch job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0], # TODO: support multiple destinations
session_start_timestamp_ms,
)
raise e

except Exception as e:
UsageClient.log_exception(
"dispatch job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0], # TODO: support multiple destinations
session_start_timestamp_ms,
)
raise e

self.hooks.on_dispatch_end()
self.hooks.on_dispatch_end()

def monitor_single_dst_helper(dst_region):
start_time = time.time()
try:
self.monitor_transfer(dst_region)
except exceptions.SkyplaneGatewayException as err:
reformat_err = Exception(err.pretty_print_str()[37:])
UsageClient.log_exception(
"monitor transfer",
reformat_err,
args,
self.dataplane.topology.src_region_tag,
dst_region,
session_start_timestamp_ms,
)
raise err
except Exception as e:
UsageClient.log_exception(
"monitor transfer", e, args, self.dataplane.topology.src_region_tag, dst_region, session_start_timestamp_ms
)
raise e
end_time = time.time()

runtime_s = end_time - start_time
# transfer successfully completed
transfer_stats = {
"dst_region": dst_region,
"total_runtime_s": round(runtime_s, 4),
}

results = []
dest_regions = self.dataplane.topology.dest_region_tags
with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
e2e_start_time = time.time()
try:
future_list = [executor.submit(monitor_single_dst_helper, dest) for dest in dest_regions]
for future in as_completed(future_list):
results.append(future.result())
except Exception as e:
raise e
e2e_end_time = time.time()
transfer_stats = {
"total_runtime_s": e2e_end_time - e2e_start_time,
"throughput_gbits": self.query_bytes_dispatched() / (e2e_end_time - e2e_start_time) / GB * 8,
}
self.hooks.on_transfer_end()

def monitor_single_dst_helper(dst_region):
start_time = time.time()
start_time = int(time.time())
try:
self.monitor_transfer(dst_region)
except exceptions.SkyplaneGatewayException as err:
reformat_err = Exception(err.pretty_print_str()[37:])
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
job.finalize()
except Exception as e:
UsageClient.log_exception(
"monitor transfer",
reformat_err,
"finalize job",
e,
args,
self.dataplane.topology.src_region_tag,
dst_region,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise err
raise e
end_time = int(time.time())

# verify transfer
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
job.verify()
except Exception as e:
UsageClient.log_exception(
"monitor transfer", e, args, self.dataplane.topology.src_region_tag, dst_region, session_start_timestamp_ms
"verify job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise e
end_time = time.time()

runtime_s = end_time - start_time
# transfer successfully completed
transfer_stats = {
"dst_region": dst_region,
"total_runtime_s": round(runtime_s, 4),
}

results = []
dest_regions = self.dataplane.topology.dest_region_tags
with ThreadPoolExecutor(max_workers=len(dest_regions)) as executor:
e2e_start_time = time.time()
try:
future_list = [executor.submit(monitor_single_dst_helper, dest) for dest in dest_regions]
for future in as_completed(future_list):
results.append(future.result())
except Exception as e:
raise e
e2e_end_time = time.time()
transfer_stats = {
"total_runtime_s": e2e_end_time - e2e_start_time,
"throughput_gbits": self.query_bytes_dispatched() / (e2e_end_time - e2e_start_time) / GB * 8,
}
self.hooks.on_transfer_end()

start_time = int(time.time())
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Finalizing job {job.uuid}")
job.finalize()
except Exception as e:
UsageClient.log_exception(
"finalize job",
e,
UsageClient.log_transfer(
transfer_stats,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
self.dataplane.topology.dest_region_tags,
session_start_timestamp_ms,
)
raise e
end_time = int(time.time())

# verify transfer
try:
for job in self.jobs.values():
logger.fs.debug(f"[TransferProgressTracker] Verifying job {job.uuid}")
job.verify()
except Exception as e:
UsageClient.log_exception(
"verify job",
e,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags[0],
session_start_timestamp_ms,
)
raise e

# transfer successfully completed
UsageClient.log_transfer(
transfer_stats,
args,
self.dataplane.topology.src_region_tag,
self.dataplane.topology.dest_region_tags,
session_start_timestamp_ms,
)
print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])
print_stats_completed(total_runtime_s=transfer_stats["total_runtime_s"], throughput_gbits=transfer_stats["throughput_gbits"])

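The reworked run() computes end-to-end throughput from the bytes dispatched and the wall-clock runtime of the monitoring threads. A minimal sketch of that arithmetic (illustrative numbers; GB = 1024**3 is an assumption matching skyplane.utils.definitions):

GB = 1024**3  # assumption: byte constant as defined in skyplane.utils.definitions

def throughput_gbits(bytes_dispatched, runtime_s):
    # bytes -> GiB -> gigabits per second (the trailing * 8 converts bytes to bits)
    return bytes_dispatched / runtime_s / GB * 8

assert throughput_gbits(4 * GB, 8.0) == 4.0  # 4 GiB moved in 8 s corresponds to 4 Gbit/s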
@imports.inject("pandas")
def monitor_transfer(pd, self, region_tag):
@@ -299,13 +303,22 @@ def monitor_transfer(pd, self, region_tag):
# update job_complete_chunk_ids and job_pending_chunk_ids
# TODO: do chunk-tracking per-destination
for job_uuid, job in self.jobs.items():
job_complete_chunk_ids = set(chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid)
try:
job_complete_chunk_ids = set(
chunk_id for chunk_id in completed_chunk_ids if self._chunk_to_job_map[chunk_id] == job_uuid
)
except Exception as e:
raise e
new_chunk_ids = (
self.job_complete_chunk_ids[job_uuid][region_tag]
.union(job_complete_chunk_ids)
.difference(self.job_complete_chunk_ids[job_uuid][region_tag])
)
completed_chunks = []
for id in new_chunk_ids:
assert (
job_uuid in self.job_chunk_requests and id in self.job_chunk_requests[job_uuid]
), f"Missing chunk id {id} for job {job_uuid}: {self.job_chunk_requests}"
for id in new_chunk_ids:
completed_chunks.append(self.job_chunk_requests[job_uuid][id])
self.hooks.on_chunk_completed(completed_chunks, region_tag)
@@ -319,7 +332,8 @@ def monitor_transfer(pd, self, region_tag):
time.sleep(0.05)

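The per-destination bookkeeping above derives the newly completed chunk IDs for a (job, region) pair with set union/difference before firing on_chunk_completed. A small illustrative example with made-up chunk IDs (not part of the diff):

already_recorded = {"chunk-1", "chunk-2"}              # job_complete_chunk_ids[job_uuid][region_tag]
completed_for_job = {"chunk-1", "chunk-2", "chunk-3"}  # completed IDs mapped to this job
new_chunk_ids = already_recorded.union(completed_for_job).difference(already_recorded)
assert new_chunk_ids == {"chunk-3"}                    # only newly finished chunks trigger callbacks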
@property
@functools.lru_cache(maxsize=1)
# TODO: this is a very slow function, but we can't cache it since self.job_chunk_requests changes over time
# do not call it more often than necessary
def _chunk_to_job_map(self):
return {chunk_id: job_uuid for job_uuid, cr_dict in self.job_chunk_requests.items() for chunk_id in cr_dict.keys()}

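Dropping the functools.lru_cache here is deliberate: _chunk_to_job_map is derived from job_chunk_requests, which keeps growing while chunks are dispatched, so a cached copy would go stale. An illustrative sketch with made-up IDs (not part of the diff):

job_chunk_requests = {"job-a": {"c1": "req", "c2": "req"}}
chunk_to_job = {cid: jid for jid, crs in job_chunk_requests.items() for cid in crs}
assert chunk_to_job == {"c1": "job-a", "c2": "job-a"}

job_chunk_requests["job-b"] = {"c3": "req"}   # chunks dispatched later in the transfer
chunk_to_job = {cid: jid for jid, crs in job_chunk_requests.items() for cid in crs}
assert chunk_to_job["c3"] == "job-b"          # a cached mapping would never see c3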
39 changes: 25 additions & 14 deletions skyplane/api/transfer_job.py
@@ -325,8 +325,11 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->
multipart_exit_event = threading.Event()
multipart_chunk_threads = []

# TODO: remove after azure multipart implemented
azure_dest = any([dst_iface.provider == "azure" for dst_iface in self.dst_ifaces])

# start chunking threads
if self.transfer_config.multipart_enabled:
if not azure_dest and self.transfer_config.multipart_enabled:
for _ in range(self.concurrent_multipart_chunk_threads):
t = threading.Thread(
target=self._run_multipart_chunk_thread,
@@ -338,8 +341,13 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) ->

# begin chunking loop
for transfer_pair in transfer_pair_generator:
# print("transfer_pair", transfer_pair.src_obj.key, transfer_pair.dst_objs)
src_obj = transfer_pair.src_obj
if self.transfer_config.multipart_enabled and src_obj.size > self.transfer_config.multipart_threshold_mb * MB:
if (
not azure_dest
and self.transfer_config.multipart_enabled
and src_obj.size > self.transfer_config.multipart_threshold_mb * MB
):
multipart_send_queue.put(transfer_pair)
else:
if transfer_pair.src_obj.size == 0:
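The chunking loop now routes an object to the multipart path only when multipart is enabled, the object exceeds the size threshold, and no destination is Azure (multipart is not yet implemented there). A sketch of that predicate (illustrative, not part of the diff; the 128 MB threshold stands in for transfer_config.multipart_threshold_mb):

MB = 1024**2

def use_multipart(size_bytes, multipart_enabled, azure_dest, threshold_mb=128):
    # Azure destinations bypass multipart entirely until it is supported there.
    return (not azure_dest) and multipart_enabled and size_bytes > threshold_mb * MB

assert use_multipart(256 * MB, True, azure_dest=False)
assert not use_multipart(256 * MB, True, azure_dest=True)
assert not use_multipart(64 * MB, True, azure_dest=False)   # below threshold -> single chunk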
@@ -460,13 +468,16 @@ def __init__(
dst_paths: List[str] or str,
recursive: bool = False,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
job_id: Optional[str] = None,
):
self.src_path = src_path
self.dst_paths = dst_paths
self.recursive = recursive
self.requester_pays = requester_pays
self.uuid = uuid
if job_id is None:
self.uuid = str(uuid.uuid4())
else:
self.uuid = job_id

@property
def transfer_type(self) -> str:
@@ -559,9 +570,9 @@ def __init__(
dst_paths: List[str] or str,
recursive: bool = False,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
job_id: Optional[str] = None,
):
super().__init__(src_path, dst_paths, recursive, requester_pays, uuid)
super().__init__(src_path, dst_paths, recursive, requester_pays, job_id)
self.transfer_list = []
self.multipart_transfer_list = []

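The constructors above replace the dataclasses.field(...) default with an explicit optional job_id, generating a uuid4 string only when none is given. A hedged usage sketch (illustrative, not part of the diff): the import path is assumed from the file name, CopyJob is the class shown in the diff, and the bucket URLs are placeholders.

from skyplane.api.transfer_job import CopyJob  # assumed import path per skyplane/api/transfer_job.py

auto = CopyJob("s3://src-bucket/prefix/", ["s3://dst-bucket/prefix/"], recursive=True)
pinned = CopyJob("s3://src-bucket/prefix/", ["s3://dst-bucket/prefix/"], recursive=True, job_id="transfer-001")
assert pinned.uuid == "transfer-001"   # explicit job_id is used verbatim as the job UUID
assert auto.uuid != pinned.uuid        # otherwise a fresh uuid4 string is generated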
@@ -645,6 +656,9 @@ def dispatch(

# send chunk requests to source gateways
chunk_batch = [cr.chunk for cr in batch if cr.chunk is not None]
# TODO: allow multiple partition ids per chunk
for chunk in chunk_batch: # assign job UUID as partition ID
chunk.partition_id = self.uuid
min_idx = queue_size.index(min(queue_size))
n_added = 0
while n_added < len(chunk_batch):
@@ -701,6 +715,9 @@ def complete_fn(batch):

do_parallel(complete_fn, batches, n=8)

# TODO: Do NOT do this if we are pipelining multiple transfers - remove just what was completed
self.multipart_transfer_list = []

def verify(self):
"""Verify the integrity of the transfered destination objects"""

@@ -750,14 +767,8 @@ def size_gb(self):
class SyncJob(CopyJob):
"""sync job that copies the source objects that does not exist in the destination bucket to the destination"""

def __init__(
self,
src_path: str,
dst_paths: List[str] or str,
requester_pays: bool = False,
uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())),
):
super().__init__(src_path, dst_paths, True, requester_pays, uuid)
def __init__(self, src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, job_id: Optional[str] = None):
super().__init__(src_path, dst_paths, True, requester_pays, job_id)
self.transfer_list = []
self.multipart_transfer_list = []

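SyncJob keeps the simplified signature above and always passes recursive=True to the CopyJob initializer. A small illustrative sketch (not part of the diff; placeholder URLs, import path assumed as before, and the recursive attribute is taken from the base initializer shown earlier):

from skyplane.api.transfer_job import SyncJob  # assumed import path per the diff header

sync = SyncJob("s3://src-bucket/data/", ["gs://dst-bucket/data/"], job_id="nightly-sync")
assert sync.uuid == "nightly-sync"
assert sync.recursive is True   # sync always copies recursively via the CopyJob initializer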