From 397a72229a912f2b1d9efaec054c3f5dca1327c8 Mon Sep 17 00:00:00 2001
From: Paras Jain
Date: Tue, 22 Mar 2022 20:34:35 +0000
Subject: [PATCH] Partition batches with greedy algorithm

---
 skylark/replicate/replicator_client.py | 47 ++++++++++----------------
 1 file changed, 18 insertions(+), 29 deletions(-)

diff --git a/skylark/replicate/replicator_client.py b/skylark/replicate/replicator_client.py
index 84becbd49..d6053a6a7 100644
--- a/skylark/replicate/replicator_client.py
+++ b/skylark/replicate/replicator_client.py
@@ -225,33 +225,22 @@ def run_replication_plan(self, job: ReplicationJob) -> ReplicationJob:
             chunks.append(Chunk(key=obj, chunk_id=idx, file_offset_bytes=0, chunk_length_bytes=file_size_bytes))
 
         # partition chunks into roughly equal-sized batches (by bytes)
+        # iteratively adds chunks to the batch with the smallest size
+        def partition(items: List[Chunk], n_batches: int) -> List[List[Chunk]]:
+            batches = [[] for _ in range(n_batches)]
+            items.sort(key=lambda c: c.chunk_length_bytes, reverse=True)
+            for item in items:
+                batch_sizes = [sum(b.chunk_length_bytes for b in bs) for bs in batches]
+                batches[batch_sizes.index(min(batch_sizes))].append(item)
+            return batches
+
         src_instances = [self.bound_nodes[n] for n in self.topology.source_instances()]
-        chunk_lens = [c.chunk_length_bytes for c in chunks]
-        new_chunk_lens = int(len(chunk_lens) / len(src_instances)) * len(src_instances)
-        if len(chunk_lens) != new_chunk_lens:
-            dropped_chunks = len(chunk_lens) - new_chunk_lens
-            logger.warn(f"Dropping {dropped_chunks} chunks to be evenly distributed")
-            chunk_lens = chunk_lens[:new_chunk_lens]
-            chunks = chunks[:new_chunk_lens]
-
-        approx_bytes_per_connection = sum(chunk_lens) / len(src_instances)
-        assert sum(chunk_lens) > 0, f"No chunks to replicate, got {chunk_lens}"
-        batch_bytes = 0
-        chunk_batches = []
-        current_batch = []
-        for chunk in chunks:
-            current_batch.append(chunk)
-            batch_bytes += chunk.chunk_length_bytes
-            if batch_bytes >= approx_bytes_per_connection and len(chunk_batches) < len(src_instances):
-                chunk_batches.append(current_batch)
-                batch_bytes = 0
-                current_batch = []
-        if current_batch:  # add remaining chunks to the smallest batch by total bytes
-            smallest_batch = min(chunk_batches, key=lambda b: sum([c.chunk_length_bytes for c in b]))
-            smallest_batch.extend(current_batch)
+        chunk_batches = partition(chunks, len(src_instances))
         assert (len(chunk_batches) == (len(src_instances) - 1)) or (
             len(chunk_batches) == len(src_instances)
         ), f"{len(chunk_batches)} batches, expected {len(src_instances)}"
+        for batch_idx, batch in enumerate(chunk_batches):
+            logger.info(f"Batch {batch_idx} size: {sum(c.chunk_length_bytes for c in batch)} with {len(batch)} chunks")
 
         # make list of ChunkRequests
         chunk_requests_sharded: Dict[int, List[ChunkRequest]] = {}
@@ -332,14 +321,14 @@ def shutdown_handler():
         if save_log:
             (transfer_dir / "job.pkl").write_bytes(pickle.dumps(job))
         if copy_gateway_logs:
-            for instance in self.bound_nodes.values():
+
+            def copy_log(instance):
                 logger.info(f"Copying gateway logs from {instance.uuid()}")
                 instance.run_command("sudo docker logs -t skylark_gateway 2> /tmp/gateway.stderr > /tmp/gateway.stdout")
-                log_out = transfer_dir / f"gateway_{instance.uuid()}.stdout"
-                log_err = transfer_dir / f"gateway_{instance.uuid()}.stderr"
-                instance.download_file("/tmp/gateway.stdout", log_out)
-                instance.download_file("/tmp/gateway.stderr", log_err)
-            logger.debug(f"Wrote gateway logs to {transfer_dir}")
+                instance.download_file("/tmp/gateway.stdout", transfer_dir / f"gateway_{instance.uuid()}.stdout")
+                instance.download_file("/tmp/gateway.stderr", transfer_dir / f"gateway_{instance.uuid()}.stderr")
+
+            do_parallel(copy_log, self.bound_nodes.values(), n=-1)
         if write_profile:
             chunk_status_df = self.get_chunk_status_log_df()
             (transfer_dir / "chunk_status_df.csv").write_text(chunk_status_df.to_csv(index=False))