Skip to content

Commit

Permalink
Fix support for whole bucket replication in skylark cp (#302)
Browse files (browse the repository at this point in the history)
  • Loading branch information
parasj authored Apr 27, 2022
1 parent bf1dacd commit f59ecb2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
40 changes: 29 additions & 11 deletions skylark/cli/cli_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ def replicate_helper(
n_chunks: int = 512,
random: bool = False,
# bucket options
source_bucket: str = typer.Option(None),
dest_bucket: str = typer.Option(None),
src_key_prefix: str = "/",
dest_key_prefix: str = "/",
source_bucket: Optional[str] = None,
dest_bucket: Optional[str] = None,
src_key_prefix: str = "",
dest_key_prefix: str = "",
# gateway provisioning options
reuse_gateways: bool = False,
gateway_docker_image: str = os.environ.get("SKYLARK_DOCKER_IMAGE", "ghcr.io/skyplane-project/skyplane:main"),
Expand All @@ -231,8 +231,8 @@ def replicate_helper(
source_bucket=None,
dest_region=topo.sink_region(),
dest_bucket=None,
src_objs=[f"/{i}" for i in range(n_chunks)],
dest_objs=[f"/{i}" for i in range(n_chunks)],
src_objs=[str(i) for i in range(n_chunks)],
dest_objs=[str(i) for i in range(n_chunks)],
random_chunk_size_mb=chunk_size_mb,
)
else:
Expand All @@ -241,17 +241,35 @@ def replicate_helper(
if not src_objs:
logger.error("Specified object does not exist.")
raise exceptions.MissingObjectException()
dest_is_directory = False
if dest_key_prefix.endswith("/"):
dest_is_directory = True

# map objects to destination object paths
# todo isolate this logic and test independently
src_objs_job = []
dest_objs_job = []
# if only one object exists, replicate it
if len(src_objs) == 1 and src_objs[0].key == src_key_prefix:
src_objs_job.append(src_objs[0].key)
if dest_key_prefix.endswith("/"):
dest_objs_job.append(dest_key_prefix + src_objs[0].key.split("/")[-1])
else:
dest_objs_job.append(dest_key_prefix)
# multiple objects to replicate
else:
for src_obj in src_objs:
src_objs_job.append(src_obj.key)
src_path_no_prefix = src_obj.key.lstrip(src_key_prefix if src_key_prefix.endswith("/") else src_key_prefix + "/")
if dest_key_prefix.endswith("/") or len(dest_key_prefix) == 0:
dest_objs_job.append(dest_key_prefix + src_path_no_prefix)
else:
dest_objs_job.append(dest_key_prefix + "/" + src_path_no_prefix)

job = ReplicationJob(
source_region=topo.source_region(),
source_bucket=source_bucket,
dest_region=topo.sink_region(),
dest_bucket=dest_bucket,
src_objs=[obj.key for obj in src_objs],
dest_objs=[dest_key_prefix + obj.key if dest_is_directory else dest_key_prefix for obj in src_objs],
src_objs=src_objs_job,
dest_objs=dest_objs_job,
obj_sizes={obj.key: obj.size for obj in src_objs},
)

Expand Down
2 changes: 1 addition & 1 deletion skylark/gateway/gateway_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, region: str, outgoing_ports: Dict[str, int], chunk_dir: PathL
self.gateway_receiver = GatewayReceiver(chunk_store=self.chunk_store, max_pending_chunks=max_incoming_ports, use_tls=use_tls)
self.gateway_sender = GatewaySender(chunk_store=self.chunk_store, outgoing_ports=outgoing_ports, use_tls=use_tls)

self.obj_store_conn = GatewayObjStoreConn(chunk_store=self.chunk_store, max_conn=8)
self.obj_store_conn = GatewayObjStoreConn(chunk_store=self.chunk_store, max_conn=64)

# Download thread pool
self.dl_pool_semaphore = BoundedSemaphore(value=128)
Expand Down
2 changes: 1 addition & 1 deletion skylark/gateway/gateway_obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ObjStoreRequest:


class GatewayObjStoreConn:
def __init__(self, chunk_store: ChunkStore, max_conn=32):
def __init__(self, chunk_store: ChunkStore, max_conn=1):
self.chunk_store = chunk_store
self.n_processes = max_conn
self.processes = []
Expand Down
2 changes: 2 additions & 0 deletions skylark/obj_store/s3_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def exists(self, obj_name):
def download_object(self, src_object_name, dst_file_path):
    """Download one object from this interface's bucket to a local file.

    Args:
        src_object_name: Key of the object to fetch; coerced with ``str``.
        dst_file_path: Local path to write to; coerced with ``str``.

    Raises:
        ValueError: If the (stringified) object key is empty.
    """
    src_object_name, dst_file_path = str(src_object_name), str(dst_file_path)
    # Validate with an explicit raise rather than ``assert`` (asserts are
    # stripped under ``python -O``), and do it before requesting a client so
    # bad input fails fast.
    if not src_object_name:
        raise ValueError(f"Source object name must be non-empty: '{src_object_name}'")
    s3_client = self.auth.get_boto3_client("s3", self.aws_region)
    # use_threads=False disables boto3's transfer thread pool for this call.
    s3_client.download_file(self.bucket_name, src_object_name, dst_file_path, Config=TransferConfig(use_threads=False))

def upload_object(self, src_file_path, dst_object_name):
    """Upload a local file to this interface's bucket under the given key.

    Args:
        src_file_path: Local path of the file to send; coerced with ``str``.
        dst_object_name: Destination object key; coerced with ``str``.

    Raises:
        ValueError: If the (stringified) destination key is empty.
    """
    dst_object_name, src_file_path = str(dst_object_name), str(src_file_path)
    # Validate with an explicit raise rather than ``assert`` (asserts are
    # stripped under ``python -O``), and do it before requesting a client so
    # bad input fails fast.
    if not dst_object_name:
        raise ValueError(f"Destination object name must be non-empty: '{dst_object_name}'")
    s3_client = self.auth.get_boto3_client("s3", self.aws_region)
    # use_threads=False disables boto3's transfer thread pool for this call.
    s3_client.upload_file(src_file_path, self.bucket_name, dst_object_name, Config=TransferConfig(use_threads=False))
5 changes: 3 additions & 2 deletions skylark/test/test_s3_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from skylark.utils.utils import Timer
from skylark.utils import logger


def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
logger.debug("creating interfaces...")
interface = S3Interface(region, bucket)
Expand All @@ -31,7 +32,7 @@ def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
logger.debug("uploading...")
with Timer() as t:
interface.upload_object(fpath, obj_name)
debug_time("uploaded", file_size_mb, t.elapsed)
debug_time("uploaded", file_size_mb, t.elapsed)

assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)

Expand All @@ -44,7 +45,7 @@ def test_s3_interface(region="us-east-1", bucket="sky-us-east-1"):
logger.debug("downloading...")
with Timer() as t:
interface.download_object(obj_name, fpath)
debug_time("downloaded", file_size_mb, t.elapsed)
debug_time("downloaded", file_size_mb, t.elapsed)

assert interface.get_obj_size(obj_name) == os.path.getsize(fpath)

Expand Down

0 comments on commit f59ecb2

Please sign in to comment.