diff --git a/skyplane/cli/cli.py b/skyplane/cli/cli.py
index dfd508f50..f72f085df 100644
--- a/skyplane/cli/cli.py
+++ b/skyplane/cli/cli.py
@@ -4,8 +4,10 @@
 from pathlib import Path
 from shlex import split
 import traceback
 
+from skyplane.replicate.replicator_client import ReplicatorClient
 import typer
+from rich.progress import Progress, SpinnerColumn, TextColumn
 
 import skyplane.cli.cli_aws
 import skyplane.cli.cli_azure
@@ -154,6 +156,7 @@ def error_local():
             job=job,
             ask_to_confirm_transfer=not confirm,
         )
+
         stats = launch_replication_job(
             topo=topo,
             job=job,
@@ -163,7 +166,6 @@ def error_local():
             use_compression=cloud_config.get_flag("compress") if src_region != dst_region else False,
             use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
             use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
-            verify_checksums=cloud_config.get_flag("verify_checksums"),
            aws_instance_class=cloud_config.get_flag("aws_instance_class"),
             azure_instance_class=cloud_config.get_flag("azure_instance_class"),
             gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
@@ -171,7 +173,20 @@ def error_local():
             multipart_enabled=multipart,
             multipart_max_chunk_size_mb=cloud_config.get_flag("multipart_max_chunk_size_mb"),
         )
-        return 0 if stats["success"] else 1
+
+        if cloud_config.get_flag("verify_checksums"):
+            provider_dst = topo.sink_region().split(":")[0]
+            if provider_dst == "azure":
+                typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
+            else:
+                with Progress(
+                    SpinnerColumn(),
+                    TextColumn("Verifying all files were copied{task.description}"),
+                ) as progress:
+                    progress.add_task("", total=None)
+                    ReplicatorClient.verify_transfer_prefix(dest_prefix=path_dst, job=job)
+
+        return 0 if stats["success"] else 1
     else:
         raise NotImplementedError(f"{provider_src} to {provider_dst} not supported yet")
 
@@ -301,7 +316,6 @@ def sync(
         use_compression=cloud_config.get_flag("compress") if src_region != dst_region else False,
         use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
         use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
-        verify_checksums=cloud_config.get_flag("verify_checksums"),
         aws_instance_class=cloud_config.get_flag("aws_instance_class"),
         azure_instance_class=cloud_config.get_flag("azure_instance_class"),
         gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
@@ -309,6 +323,20 @@ def sync(
         multipart_enabled=multipart,
         multipart_max_chunk_size_mb=cloud_config.get_flag("multipart_max_chunk_size_mb"),
     )
+
+    if cloud_config.get_flag("verify_checksums"):
+        provider_dst = topo.sink_region().split(":")[0]
+        if provider_dst == "azure":
+            typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
+        else:
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("Verifying all files were copied{task.description}"),
+                transient=True,
+            ) as progress:
+                progress.add_task("", total=None)
+                ReplicatorClient.verify_transfer_prefix(dest_prefix=path_dst, job=job)
+
     return 0 if stats["success"] else 1
 
 
diff --git a/skyplane/cli/cli_impl/cp_replicate.py b/skyplane/cli/cli_impl/cp_replicate.py
index b781228bd..59ef28098 100644
--- a/skyplane/cli/cli_impl/cp_replicate.py
+++ b/skyplane/cli/cli_impl/cp_replicate.py
@@ -11,6 +11,9 @@
 from skyplane import exceptions, GB, format_bytes, gateway_docker_image, skyplane_root
 from skyplane.compute.cloud_providers import CloudProvider
 from skyplane.obj_store.object_store_interface import ObjectStoreInterface, ObjectStoreObject
+from skyplane.obj_store.s3_interface import S3Object
+from skyplane.obj_store.gcs_interface import GCSObject
+from skyplane.obj_store.azure_interface import AzureObject
 from skyplane.replicate.replication_plan import ReplicationTopology, ReplicationJob
 from skyplane.replicate.replicator_client import ReplicatorClient
 from skyplane.utils import logger
@@ -140,7 +143,15 @@ def generate_full_transferobjlist(
     # map objects to destination object paths
     for source_obj in source_objs:
         dest_key = map_object_key_prefix(source_prefix, source_obj.key, dest_prefix, recursive=recursive)
-        dest_obj = ObjectStoreObject(dest_region.split(":")[0], dest_bucket, dest_key)
+        if dest_region.startswith("aws"):
+            dest_obj = S3Object(dest_region.split(":")[0], dest_bucket, dest_key)
+        elif dest_region.startswith("gcp"):
+            dest_obj = GCSObject(dest_region.split(":")[0], dest_bucket, dest_key)
+        elif dest_region.startswith("azure"):
+            dest_obj = AzureObject(dest_region.split(":")[0], dest_bucket, dest_key)
+        else:
+            raise ValueError(f"Invalid dest_region {dest_region} - could not create corresponding object")
+        # dest_obj = ObjectStoreObject(dest_region.split(":")[0], dest_bucket, dest_key)
         dest_objs.append(dest_obj)
 
     # query destination at dest_key
@@ -209,7 +220,6 @@ def launch_replication_job(
     use_compression: bool = False,
     use_e2ee: bool = True,
     use_socket_tls: bool = False,
-    verify_checksums: bool = True,
     # multipart
     multipart_enabled: bool = False,
     multipart_max_chunk_size_mb: int = 8,
@@ -306,12 +316,6 @@ def launch_replication_job(
             typer.secho(error, fg="red")
         raise typer.Exit(1)
 
-    if verify_checksums:
-        if any(node.region.startswith("azure") for node in rc.bound_nodes.keys()):
-            typer.secho("Note: Azure post-transfer verification is not yet supported.", fg="yellow", bold=True)
-        else:
-            rc.verify_transfer(job)
-
     # print stats
     if stats["success"]:
         rprint(f"\n:white_check_mark: [bold green]Transfer completed successfully[/bold green]")
diff --git a/skyplane/cli/cli_internal.py b/skyplane/cli/cli_internal.py
index b9462c30f..863a5c4a5 100644
--- a/skyplane/cli/cli_internal.py
+++ b/skyplane/cli/cli_internal.py
@@ -73,7 +73,6 @@ def replicate_random(
         reuse_gateways=reuse_gateways,
         use_bbr=use_bbr,
         use_compression=False,
-        verify_checksums=False,
         use_e2ee=True,
     )
     return 0 if stats["success"] else 1
@@ -159,6 +158,5 @@ def replicate_random_solve(
         reuse_gateways=reuse_gateways,
         use_bbr=use_bbr,
         use_compression=False,
-        verify_checksums=False,
     )
     return 0 if stats["success"] else 1
diff --git a/skyplane/obj_store/s3_interface.py b/skyplane/obj_store/s3_interface.py
index 224b80ab6..01dadacba 100644
--- a/skyplane/obj_store/s3_interface.py
+++ b/skyplane/obj_store/s3_interface.py
@@ -61,7 +61,7 @@ def list_objects(self, prefix="") -> Iterator[S3Object]:
         page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
         for page in page_iterator:
             for obj in page.get("Contents", []):
-                yield S3Object("s3", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"])
+                yield S3Object("aws", self.bucket_name, obj["Key"], obj["Size"], obj["LastModified"])
 
     def delete_objects(self, keys: List[str]):
         s3_client = self.auth.get_boto3_client("s3", self.aws_region)
diff --git a/skyplane/replicate/replicator_client.py b/skyplane/replicate/replicator_client.py
index 6e9a1fb12..99ff58eb1 100644
--- a/skyplane/replicate/replicator_client.py
+++ b/skyplane/replicate/replicator_client.py
@@ -7,6 +7,7 @@
 from typing import Dict, List, Optional, Tuple, Iterable
 
 import nacl.secret
 import nacl.utils
+from rich import print as rprint
 import pandas as pd
 from rich.progress import Progress, SpinnerColumn, TextColumn, TimeRemainingColumn, DownloadColumn, BarColumn, TransferSpeedColumn
@@ -29,6 +30,19 @@
 from skyplane.utils.timer import Timer
 
 
+def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] = (), instance_filter=None, n=-1) -> Dict[str, List[Server]]:
+    if instance_filter is None:
+        instance_filter = {"tags": {"skyplane": "true"}}
+    results = do_parallel(
+        lambda region: provider.get_matching_instances(region=region, **instance_filter),
+        region_list,
+        spinner=True,
+        n=n,
+        desc="Querying clouds for active instances",
+    )
+    return {r: ilist for r, ilist in results if ilist}
+
+
 class ReplicatorClient:
     def __init__(
         self,
@@ -686,42 +700,23 @@ def fn(s: Server):
             do_parallel(fn, self.bound_nodes.values(), n=-1)
             progress.update(cleanup_task, description=": Shutting down gateways")
 
-    def verify_transfer(self, job: ReplicationJob):
+    @staticmethod
+    def verify_transfer_prefix(job: ReplicationJob, dest_prefix: str):
         """Check that all objects to copy are present in the destination"""
-        src_interface = ObjectStoreInterface.create(job.source_region, job.source_bucket)
         dst_interface = ObjectStoreInterface.create(job.dest_region, job.dest_bucket)
 
-        # only check metadata (src.size == dst.size) && (src.modified <= dst.modified)
-        def verify(tup):
-            src_key, dst_key = tup[0].key, tup[1].key
-            try:
-                if src_interface.get_obj_size(src_key) != dst_interface.get_obj_size(dst_key):
-                    return False
-                elif src_interface.get_obj_last_modified(src_key) > dst_interface.get_obj_last_modified(dst_key):
-                    return False
-                else:
-                    return True
-            except NoSuchObjectException:
-                return False
-
-        # verify that all objects in src_interface are present in dst_interface
-        matches = do_parallel(verify, job.transfer_pairs, n=512, spinner=True, spinner_persist=True, desc="Verifying transfer")
-        failed_src_objs = [src_key for (src_key, dst_key), match in matches if not match]
-        if len(failed_src_objs) > 0:
+        # algorithm: check all expected keys are present in the destination
+        # by iteratively removing found keys from list_objects from a
+        # precomputed dictionary of keys to check.
+        dst_keys = {dst_o.key: src_o for src_o, dst_o in job.transfer_pairs}
+        for obj in dst_interface.list_objects(dest_prefix):
+            # check metadata (src.size == dst.size) && (src.modified <= dst.modified)
+            src_obj = dst_keys.get(obj.key)
+            if src_obj and src_obj.size == obj.size and src_obj.last_modified <= obj.last_modified:
+                del dst_keys[obj.key]
+
+        if dst_keys:
             raise exceptions.TransferFailedException(
-                f"{len(failed_src_objs)} objects failed verification",
-                failed_src_objs,
+                f"{len(dst_keys)} objects failed verification",
+                [obj.key for obj in dst_keys.values()],
             )
-
-
-def refresh_instance_list(provider: CloudProvider, region_list: Iterable[str] = (), instance_filter=None, n=-1) -> Dict[str, List[Server]]:
-    if instance_filter is None:
-        instance_filter = {"tags": {"skyplane": "true"}}
-    results = do_parallel(
-        lambda region: provider.get_matching_instances(region=region, **instance_filter),
-        region_list,
-        spinner=True,
-        n=n,
-        desc="Querying clouds for active instances",
-    )
-    return {r: ilist for r, ilist in results if ilist}