Faster transfer verification via prefixes (#475)
parasj authored Aug 2, 2022
1 parent 0f05219 commit 3cdebe7
Showing 5 changed files with 73 additions and 48 deletions.
34 changes: 31 additions & 3 deletions skyplane/cli/cli.py
@@ -4,8 +4,10 @@
from pathlib import Path
from shlex import split
import traceback
from skyplane.replicate.replicator_client import ReplicatorClient

import typer
from rich.progress import Progress, SpinnerColumn, TextColumn

import skyplane.cli.cli_aws
import skyplane.cli.cli_azure
@@ -154,6 +156,7 @@ def error_local():
job=job,
ask_to_confirm_transfer=not confirm,
)

stats = launch_replication_job(
topo=topo,
job=job,
@@ -163,15 +166,27 @@ def error_local():
use_compression=cloud_config.get_flag("compress") if src_region != dst_region else False,
use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
verify_checksums=cloud_config.get_flag("verify_checksums"),
aws_instance_class=cloud_config.get_flag("aws_instance_class"),
azure_instance_class=cloud_config.get_flag("azure_instance_class"),
gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
gcp_use_premium_network=cloud_config.get_flag("gcp_use_premium_network"),
multipart_enabled=multipart,
multipart_max_chunk_size_mb=cloud_config.get_flag("multipart_max_chunk_size_mb"),
)
return 0 if stats["success"] else 1

if cloud_config.get_flag("verify_checksums"):
provider_dst = topo.sink_region().split(":")[0]
if provider_dst == "azure":
typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
else:
with Progress(
SpinnerColumn(),
TextColumn("Verifying all files were copied{task.description}"),
) as progress:
progress.add_task("", total=None)
ReplicatorClient.verify_transfer_prefix(dest_prefix=path_dst, job=job)

return 0 if stats["success"] else 1
else:
raise NotImplementedError(f"{provider_src} to {provider_dst} not supported yet")

@@ -301,14 +316,27 @@ def sync(
use_compression=cloud_config.get_flag("compress") if src_region != dst_region else False,
use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
verify_checksums=cloud_config.get_flag("verify_checksums"),
aws_instance_class=cloud_config.get_flag("aws_instance_class"),
azure_instance_class=cloud_config.get_flag("azure_instance_class"),
gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
gcp_use_premium_network=cloud_config.get_flag("gcp_use_premium_network"),
multipart_enabled=multipart,
multipart_max_chunk_size_mb=cloud_config.get_flag("multipart_max_chunk_size_mb"),
)

if cloud_config.get_flag("verify_checksums"):
provider_dst = topo.sink_region().split(":")[0]
if provider_dst == "azure":
typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
else:
with Progress(
SpinnerColumn(),
TextColumn("Verifying all files were copied{task.description}"),
transient=True,
) as progress:
progress.add_task("", total=None)
ReplicatorClient.verify_transfer_prefix(dest_prefix=path_dst, job=job)

return 0 if stats["success"] else 1


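The `cp` and `sync` hunks above add the same post-transfer step: when the `verify_checksums` flag is set, verification now runs in the CLI after `launch_replication_job` returns, instead of inside the replication client. A minimal sketch of that flow factored into one helper (the helper name is hypothetical; `topo`, `job`, `path_dst`, and `cloud_config` are the names used in the diff):

```python
import typer
from rich.progress import Progress, SpinnerColumn, TextColumn

from skyplane.replicate.replicator_client import ReplicatorClient


def maybe_verify_transfer(topo, job, path_dst, cloud_config):
    """Hypothetical helper mirroring the verification block added to cp/sync."""
    if not cloud_config.get_flag("verify_checksums"):
        return
    provider_dst = topo.sink_region().split(":")[0]
    if provider_dst == "azure":
        # Azure destinations are skipped: post-transfer verification is not yet supported there.
        typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
        return
    with Progress(
        SpinnerColumn(),
        TextColumn("Verifying all files were copied{task.description}"),
        transient=True,
    ) as progress:
        progress.add_task("", total=None)
        ReplicatorClient.verify_transfer_prefix(dest_prefix=path_dst, job=job)
```
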
20 changes: 12 additions & 8 deletions skyplane/cli/cli_impl/cp_replicate.py
@@ -11,6 +11,9 @@
from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root
from skyplane.compute.cloud_providers import CloudProvider
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
from skyplane.obj_store.s3_interface import S3Object
from skyplane.obj_store.gcs_interface import GCSObject
from skyplane.obj_store.azure_interface import AzureObject
from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
from skyplane.replicate.replicator_client import ReplicatorClient
from skyplane.utils import logger
@@ -140,7 +143,15 @@ def generate_full_transferobjlist(
# map objects to destination object paths
for source_obj in source_objs:
dest_key = map_object_key_prefix(source_prefix, source_obj.key, dest_prefix, recursive=recursive)
dest_obj = ObjectStoreObject(dest_region.split(":")[0], dest_bucket, dest_key)
if dest_region.startswith("aws"):
dest_obj = S3Object(dest_region.split(":")[0], dest_bucket, dest_key)
elif dest_region.startswith("gcp"):
dest_obj = GCSObject(dest_region.split(":")[0], dest_bucket, dest_key)
elif dest_region.startswith("azure"):
dest_obj = AzureObject(dest_region.split(":")[0], dest_bucket, dest_key)
else:
raise ValueError(f"Invalid dest_region {dest_region} - could not create corresponding object")
# dest_obj = ObjectStoreObject(dest_region.split(":")[0], dest_bucket, dest_key)
dest_objs.append(dest_obj)

# query destination at dest_key
@@ -209,7 +220,6 @@ def launch_replication_job(
use_compression: bool = False,
use_e2ee: bool = True,
use_socket_tls: bool = False,
verify_checksums: bool = True,
# multipart
multipart_enabled: bool = False,
multipart_max_chunk_size_mb: int = 8,
@@ -306,12 +316,6 @@
typer.secho(error, fg="red")
raise typer.Exit(1)

if verify_checksums:
if any(node.region.startswith("azure") for node in rc.bound_nodes.keys()):
typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
else:
rc.verify_transfer(job)

# print stats
if stats["success"]:
rprint(f"\n:white_check_mark: [bold green]Transfer completed successfully[/bold green]")
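The `generate_full_transferobjlist` change replaces the generic `ObjectStoreObject` with the concrete per-provider subclass, so destination objects carry the metadata (`size`, `last_modified`) that prefix verification compares. The same dispatch, sketched as a standalone helper (the helper name is illustrative, not part of the commit):

```python
from skyplane.obj_store.object_store_interface import ObjectStoreObject
from skyplane.obj_store.s3_interface import S3Object
from skyplane.obj_store.gcs_interface import GCSObject
from skyplane.obj_store.azure_interface import AzureObject


def make_dest_object(dest_region: str, dest_bucket: str, dest_key: str) -> ObjectStoreObject:
    """Illustrative helper: pick the concrete object class for the destination provider."""
    if dest_region.startswith("aws"):
        return S3Object(dest_region.split(":")[0], dest_bucket, dest_key)
    elif dest_region.startswith("gcp"):
        return GCSObject(dest_region.split(":")[0], dest_bucket, dest_key)
    elif dest_region.startswith("azure"):
        return AzureObject(dest_region.split(":")[0], dest_bucket, dest_key)
    raise ValueError(f"Invalid dest_region {dest_region} - could not create corresponding object")
```
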
2 changes: 0 additions & 2 deletions skyplane/cli/cli_internal.py
@@ -73,7 +73,6 @@ def replicate_random(
reuse_gateways=reuse_gateways,
use_bbr=use_bbr,
use_compression=False,
verify_checksums=False,
use_e2ee=True,
)
return 0 if stats["success"] else 1
@@ -159,6 +158,5 @@ def replicate_random_solve(
reuse_gateways=reuse_gateways,
use_bbr=use_bbr,
use_compression=False,
verify_checksums=False,
)
return 0 if stats["success"] else 1
2 changes: 1 addition & 1 deletion skyplane/obj_store/s3_interface.py
@@ -61,7 +61,7 @@ def list_objects(self, prefix="") -> Iterator[S3Object]:
page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
for page in page_iterator:
for obj in page.get("Contents", []):
yield S3Object("s3", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"])
yield S3Object("aws", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"])

def delete_objects(self, keys: List[str]):
s3_client = self.auth.get_boto3_client("s3", self.aws_region)
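The one-word fix above matters because the provider tag on listed objects is compared against region strings such as `aws:us-east-1` elsewhere in the codebase (e.g. `dest_region.startswith("aws")` in cp_replicate.py); the old tag `"s3"` would never match. A toy check of the invariant the fix restores, assuming `S3Object` stores its first constructor argument in a `provider` field:

```python
from datetime import datetime

from skyplane.obj_store.s3_interface import S3Object

# Listed objects are now tagged "aws", matching region strings like "aws:us-east-1".
obj = S3Object("aws", "my-bucket", "data/part-0001", 1024, datetime.now())
assert "aws:us-east-1".split(":")[0] == obj.provider  # the tag "s3" would never match
```
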
63 changes: 29 additions & 34 deletions skyplane/replicate/replicator_client.py
@@ -7,6 +7,7 @@
from typing import Dict, List, Optional, Tuple, Iterable
import nacl.secret
import nacl.utils
from rich import print as rprint

import pandas as pd
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeRemainingColumn, DownloadColumn, BarColumn, TransferSpeedColumn
@@ -29,6 +30,19 @@
from skyplane.utils.timer import Timer


def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] = (), instance_filter=None, n=-1) -> Dict[str, List[Server]]:
if instance_filter is None:
instance_filter = {"tags": {"skyplane": "true"}}
results = do_parallel(
lambda region: provider.get_matching_instances(region=region, **instance_filter),
region_list,
spinner=True,
n=n,
desc="Querying clouds for active instances",
)
return {r: ilist for r, ilist in results if ilist}


class ReplicatorClient:
def __init__(
self,
@@ -686,42 +700,23 @@ def fn(s: Server):
do_parallel(fn, self.bound_nodes.values(), n=-1)
progress.update(cleanup_task, description=": Shutting down gateways")

def verify_transfer(self, job: ReplicationJob):
@staticmethod
def verify_transfer_prefix(job: ReplicationJob, dest_prefix: str):
"""Check that all objects to copy are present in the destination"""
src_interface = ObjectStoreInterface.create(job.source_region, job.source_bucket)
dst_interface = ObjectStoreInterface.create(job.dest_region, job.dest_bucket)

# only check metadata (src.size == dst.size) && (src.modified <= dst.modified)
def verify(tup):
src_key, dst_key = tup[0].key, tup[1].key
try:
if src_interface.get_obj_size(src_key) != dst_interface.get_obj_size(dst_key):
return False
elif src_interface.get_obj_last_modified(src_key) > dst_interface.get_obj_last_modified(dst_key):
return False
else:
return True
except NoSuchObjectException:
return False

# verify that all objects in src_interface are present in dst_interface
matches = do_parallel(verify, job.transfer_pairs, n=512, spinner=True, spinner_persist=True, desc="Verifying transfer")
failed_src_objs = [src_key for (src_key, dst_key), match in matches if not match]
if len(failed_src_objs) > 0:
# algorithm: check all expected keys are present in the destination
# by iteratively removing found keys from list_objects from a
# precomputed dictionary of keys to check.
dst_keys = {dst_o.key: src_o for src_o, dst_o in job.transfer_pairs}
for obj in dst_interface.list_objects(dest_prefix):
# check metadata (src.size == dst.size) && (src.modified <= dst.modified)
src_obj = dst_keys.get(obj.key)
if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
del dst_keys[obj.key]

if dst_keys:
raise exceptions.TransferFailedException(
f"{len(failed_src_objs)} objects failed verification",
failed_src_objs,
f"{len(dst_keys)} objects failed verification",
[obj.key for obj in dst_keys.values()],
)


def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] = (), instance_filter=None, n=-1) -> Dict[str, List[Server]]:
if instance_filter is None:
instance_filter = {"tags": {"skyplane": "true"}}
results = do_parallel(
lambda region: provider.get_matching_instances(region=region, **instance_filter),
region_list,
spinner=True,
n=n,
desc="Querying clouds for active instances",
)
return {r: ilist for r, ilist in results if ilist}
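`verify_transfer_prefix` is the core of the speedup: the deleted `verify_transfer` issued two metadata requests per object (`get_obj_size` and `get_obj_last_modified`) across up to 512 parallel workers, while the new method makes a single `list_objects` pass over the destination prefix and removes each verified key from a precomputed dictionary. A self-contained toy run of the same algorithm (the `Obj` type here is a stand-in for the real object classes):

```python
from dataclasses import dataclass


@dataclass
class Obj:  # stand-in for ObjectStoreObject
    key: str
    size: int
    last_modified: int


# (source, destination) pairs from a hypothetical job
transfer_pairs = [
    (Obj("a", 10, 1), Obj("prefix/a", 10, 1)),
    (Obj("b", 5, 1), Obj("prefix/b", 5, 1)),
]
# what one list_objects("prefix") pass returns; prefix/b was copied with the wrong size
listed = [Obj("prefix/a", 10, 2), Obj("prefix/b", 4, 2)]

# iteratively remove found keys from the precomputed dict of expected keys
dst_keys = {dst.key: src for src, dst in transfer_pairs}
for obj in listed:
    src_obj = dst_keys.get(obj.key)
    if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
        del dst_keys[obj.key]

assert list(dst_keys) == ["prefix/b"]  # only the mismatched object fails verification
```
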
