diff --git a/skylark/cli/cli_helper.py b/skylark/cli/cli_helper.py
index 088a9442f..cb0742aa2 100644
--- a/skylark/cli/cli_helper.py
+++ b/skylark/cli/cli_helper.py
@@ -203,10 +203,10 @@ def replicate_helper(
     n_chunks: int = 512,
     random: bool = False,
     # bucket options
-    source_bucket: str = typer.Option(None),
-    dest_bucket: str = typer.Option(None),
-    src_key_prefix: str = "/",
-    dest_key_prefix: str = "/",
+    source_bucket: Optional[str] = None,
+    dest_bucket: Optional[str] = None,
+    src_key_prefix: str = "",
+    dest_key_prefix: str = "",
     # gateway provisioning options
     reuse_gateways: bool = False,
     gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/skyplane-project/skyplane:main"),
@@ -231,8 +231,8 @@ def replicate_helper(
             source_bucket=None,
             dest_region=topo.sink_region(),
             dest_bucket=None,
-            src_objs=[f"/{i}" for i in range(n_chunks)],
-            dest_objs=[f"/{i}" for i in range(n_chunks)],
+            src_objs=[str(i) for i in range(n_chunks)],
+            dest_objs=[str(i) for i in range(n_chunks)],
             random_chunk_size_mb=chunk_size_mb,
         )
     else:
@@ -241,17 +241,35 @@ def replicate_helper(
         if not src_objs:
             logger.error("Specified object does not exist.")
             raise exceptions.MissingObjectException()
-        dest_is_directory = False
-        if dest_key_prefix.endswith("/"):
-            dest_is_directory = True
+
+        # map objects to destination object paths
+        # todo isolate this logic and test independently
+        src_objs_job = []
+        dest_objs_job = []
+        # if only one object exists, replicate it
+        if len(src_objs) == 1 and src_objs[0].key == src_key_prefix:
+            src_objs_job.append(src_objs[0].key)
+            if dest_key_prefix.endswith("/"):
+                dest_objs_job.append(dest_key_prefix + src_objs[0].key.split("/")[-1])
+            else:
+                dest_objs_job.append(dest_key_prefix)
+        # multiple objects to replicate
+        else:
+            for src_obj in src_objs:
+                src_objs_job.append(src_obj.key)
+                src_path_no_prefix = src_obj.key.lstrip(src_key_prefix if src_key_prefix.endswith("/") else src_key_prefix + "/")
+                if dest_key_prefix.endswith("/") or len(dest_key_prefix) == 0:
+                    dest_objs_job.append(dest_key_prefix + src_path_no_prefix)
+                else:
+                    dest_objs_job.append(dest_key_prefix + "/" + src_path_no_prefix)
 
         job = ReplicationJob(
             source_region=topo.source_region(),
             source_bucket=source_bucket,
             dest_region=topo.sink_region(),
             dest_bucket=dest_bucket,
-            src_objs=[obj.key for obj in src_objs],
-            dest_objs=[dest_key_prefix + obj.key if dest_is_directory else dest_key_prefix for obj in src_objs],
+            src_objs=src_objs_job,
+            dest_objs=dest_objs_job,
             obj_sizes={obj.key: obj.size for obj in src_objs},
         )
 
diff --git a/skylark/gateway/gateway_daemon.py b/skylark/gateway/gateway_daemon.py
index 1898c849b..c59aa69ff 100644
--- a/skylark/gateway/gateway_daemon.py
+++ b/skylark/gateway/gateway_daemon.py
@@ -33,7 +33,7 @@ def __init__(self, region: str, outgoing_ports: Dict[str, int], chunk_dir: PathL
 
         self.gateway_receiver = GatewayReceiver(chunk_store=self.chunk_store, max_pending_chunks=max_incoming_ports, use_tls=use_tls)
         self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, outgoing_ports=outgoing_ports, use_tls=use_tls)
-        self.obj_store_conn = GatewayObjStoreConn(chunk_store=self.chunk_store, max_conn=8)
+        self.obj_store_conn = GatewayObjStoreConn(chunk_store=self.chunk_store, max_conn=64)
 
         # Download thread pool
         self.dl_pool_semaphore = BoundedSemaphore(value=128)
diff --git a/skylark/gateway/gateway_obj_store.py b/skylark/gateway/gateway_obj_store.py
index f24b65bb2..d8340b81d 100644
--- a/skylark/gateway/gateway_obj_store.py
+++ b/skylark/gateway/gateway_obj_store.py
@@ -21,7 +21,7 @@ class ObjStoreRequest:
 
 
 class GatewayObjStoreConn:
-    def __init__(self, chunk_store: ChunkStore, max_conn=32):
+    def __init__(self, chunk_store: ChunkStore, max_conn=1):
         self.chunk_store = chunk_store
         self.n_processes = max_conn
         self.processes = []
diff --git a/skylark/obj_store/s3_interface.py b/skylark/obj_store/s3_interface.py
index 9fa533cee..1e00cdab2 100644
--- a/skylark/obj_store/s3_interface.py
+++ b/skylark/obj_store/s3_interface.py
@@ -81,9 +81,11 @@ def exists(self, obj_name):
     def download_object(self, src_object_name, dst_file_path):
         src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
         s3_client = self.auth.get_boto3_client("s3", self.aws_region)
+        assert len(src_object_name) > 0, f"Source object name must be non-empty: '{src_object_name}'"
         s3_client.download_file(self.bucket_name, src_object_name, dst_file_path, Config=TransferConfig(use_threads=False))
 
     def upload_object(self, src_file_path, dst_object_name):
         dst_object_name, src_file_path = str(dst_object_name), str(src_file_path)
         s3_client = self.auth.get_boto3_client("s3", self.aws_region)
+        assert len(dst_object_name) > 0, f"Destination object name must be non-empty: '{dst_object_name}'"
         s3_client.upload_file(src_file_path, self.bucket_name, dst_object_name, Config=TransferConfig(use_threads=False))
diff --git a/skylark/test/test_s3_interface.py b/skylark/test/test_s3_interface.py
index 843da6c7b..cffd48c4c 100644
--- a/skylark/test/test_s3_interface.py
+++ b/skylark/test/test_s3_interface.py
@@ -7,6 +7,7 @@
 from skylark.utils.utils import Timer
 from skylark.utils import logger
 
+
 def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
     logger.debug("creating interfaces...")
     interface = S3Interface(region, bucket)
@@ -31,7 +32,7 @@ def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
     logger.debug("uploading...")
     with Timer() as t:
         interface.upload_object(fpath, obj_name)
-        debug_time("uploaded", file_size_mb, t.elapsed)
+    debug_time("uploaded", file_size_mb, t.elapsed)
 
     assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)
 
@@ -44,7 +45,7 @@ def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
     logger.debug("downloading...")
     with Timer() as t:
         interface.download_object(obj_name, fpath)
-        debug_time("downloaded", file_size_mb, t.elapsed)
+    debug_time("downloaded", file_size_mb, t.elapsed)
 
     assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)
 