From ac44c2a3a8f15b56f2e2545ea85c433e72a50788 Mon Sep 17 00:00:00 2001
From: Gahyun Suh <132245153+gahyusuh@users.noreply.github.com>
Date: Wed, 10 May 2023 16:20:03 -0500
Subject: [PATCH] feat(job_attachments)!: add mechanism to cancel file
 download (#3)

BREAKING CHANGE:
- `on_downloading_files` must now return a bool indicating whether to continue
  the download(s) (True to continue, False to cancel) for the functions listed
  below:
  - AssetSync.sync_inputs(), and
  - download_files_in_directory(), download_files_from_manifest(),
    download_job_output(), and mount_vfs_from_manifest() in download.py
---
 examples/download_cancel_test.py              | 168 ++++++++++++++++++
 hatch.toml                                    |   4 +-
 src/bealine_job_attachments/asset_sync.py     |  18 +-
 src/bealine_job_attachments/download.py       | 101 ++++++++---
 src/bealine_job_attachments/errors.py         |   9 +-
 .../progress_tracker.py                       |   4 +-
 src/bealine_job_attachments/upload.py         |   5 +-
 .../integ/test_job_attachments.py             |   6 +-
 .../unit/test_download.py                     |  12 +-
 9 files changed, 287 insertions(+), 40 deletions(-)
 create mode 100644 examples/download_cancel_test.py

diff --git a/examples/download_cancel_test.py b/examples/download_cancel_test.py
new file mode 100644
index 000000000..a96d3a92a
--- /dev/null
+++ b/examples/download_cancel_test.py
@@ -0,0 +1,168 @@
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+#! /usr/bin/env python3
+import argparse
+import pathlib
+from tempfile import TemporaryDirectory
+import time
+from threading import Thread
+
+from bealine_job_attachments.asset_sync import AssetSync
+from bealine_job_attachments.aws.bealine import get_job, get_queue
+from bealine_job_attachments.download import download_job_output
+from bealine_job_attachments.errors import AssetSyncCancelledError
+
+# A test script to simulate cancellation of (1) syncing inputs, and (2) downloading outputs.
+#
+# How to test:
+# 1. Run the script with the command for the test you want:
+#   (1) To test cancelling syncing inputs:
+#     $ python3 download_cancel_test.py sync_inputs -f <farm-id> -q <queue-id> -j <job-id>
+#   (2) To test cancelling downloading outputs:
+#     $ python3 download_cancel_test.py download_outputs -f <farm-id> -q <queue-id> -j <job-id>
+# 2. While files are downloading, send a cancel signal by pressing the 'k' key
+#    and then the 'Enter' key. Confirm that cancellation works as expected.
+
+MESSAGE_HOW_TO_CANCEL = (
+    "To stop the download process, press the 'k' key and then the 'Enter' key.\n"
+)
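+
+# `continue_reporting` is the value the progress callback returns; the input-watcher
+# thread flips it to False to request cancellation. `main_terminated` tells that
+# thread to stop once the test run has finished.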
+continue_reporting = True
+main_terminated = False
+
+
+def run():
+    print(MESSAGE_HOW_TO_CANCEL)
+    parser = argparse.ArgumentParser(description=MESSAGE_HOW_TO_CANCEL)
+    parser.add_argument(
+        "test_to_run",
+        choices=["sync_inputs", "download_outputs"],
+        help="Test to run. ('sync_inputs' or 'download_outputs')",
+    )
+    parser.add_argument(
+        "-f", "--farm-id", type=str, help="Bealine Farm to download assets from.", required=True
+    )
+    parser.add_argument(
+        "-q", "--queue-id", type=str, help="Bealine Queue to download assets from.", required=True
+    )
+    parser.add_argument(
+        "-j", "--job-id", type=str, help="Bealine Job to download assets from.", required=True
+    )
+    args = parser.parse_args()
+
+    test_to_run = args.test_to_run
+    farm_id = args.farm_id
+    queue_id = args.queue_id
+    job_id = args.job_id
+
+    if test_to_run == "sync_inputs":
+        test_sync_inputs(farm_id=farm_id, queue_id=queue_id, job_id=job_id)
+    elif test_to_run == "download_outputs":
+        test_download_outputs(farm_id=farm_id, queue_id=queue_id, job_id=job_id)
+
+
+def test_sync_inputs(
+    farm_id: str,
+    queue_id: str,
+    job_id: str,
+):
+    """
+    Tests cancellation during execution of the `sync_inputs` function.
+    """
+    start_time = time.perf_counter()
+
+    with TemporaryDirectory() as temp_root_dir:
+        print(f"Created a temporary directory for the test: {temp_root_dir}")
+
+        queue = get_queue(farm_id=farm_id, queue_id=queue_id)
+        job = get_job(farm_id=farm_id, queue_id=queue_id, job_id=job_id)
+
+        print("Starting test to sync inputs...")
+        asset_sync = AssetSync()
+
+        try:
+            download_start = time.perf_counter()
+            (summary_statistics, local_roots) = asset_sync.sync_inputs(
+                s3_settings=queue.jobAttachmentSettings,
+                ja_settings=job.attachmentSettings,
+                queue_id=queue_id,
+                job_id=job_id,
+                session_dir=pathlib.Path(temp_root_dir),
+                on_downloading_files=mock_on_downloading_files,
+            )
+            print(f"Download Summary Statistics:\n{summary_statistics}")
+            print(
+                f"Finished downloading after {time.perf_counter() - download_start} seconds, returned:\n{local_roots}"
+            )
+
+        except AssetSyncCancelledError as asce:
+            print(f"AssetSyncCancelledError: {asce}")
+            print(f"payload: {asce.summary_statistics}")
+
+    print(f"\nTotal test runtime: {time.perf_counter() - start_time}")
+
+    print(f"Cleaned up the temporary directory: {temp_root_dir}")
+    global main_terminated
+    main_terminated = True
+
+
+def test_download_outputs(
+    farm_id: str,
+    queue_id: str,
+    job_id: str,
+):
+    """
+    Tests cancellation during execution of the `download_job_output` function.
+    """
+    start_time = time.perf_counter()
+
+    queue = get_queue(farm_id=farm_id, queue_id=queue_id)
+
+    print("Starting test to download outputs...")
+
+    try:
+        download_start = time.perf_counter()
+        summary_statistics = download_job_output(
+            s3_settings=queue.jobAttachmentSettings,
+            job_id=job_id,
+            on_downloading_files=mock_on_downloading_files,
+        )
+        print(f"Download Summary Statistics:\n{summary_statistics}")
+        print(f"Finished downloading after {time.perf_counter() - download_start} seconds")
+
+    except AssetSyncCancelledError as asce:
+        print(f"AssetSyncCancelledError: {asce}")
+        print(f"payload: {asce.summary_statistics}")
+
+    print(f"\nTotal test runtime: {time.perf_counter() - start_time}")
+
+    global main_terminated
+    main_terminated = True
+
+
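+# The callback contract introduced by this change: the callback receives a
+# ProgressReportMetadata object and returns True to continue the download(s),
+# or False to cancel them.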
+def mock_on_downloading_files(metadata):
+    print(metadata)
+    return mock_on_cancellation_check()
+
+
+def mock_on_cancellation_check():
+    return continue_reporting
+
+
+def wait_for_cancellation_input():
+    while not main_terminated:
+        ch = input()
+        if ch == "k":
+            set_cancelled()
+            break
+
+
+def set_cancelled():
+    global continue_reporting
+    continue_reporting = False
+    print("Requested cancellation of the process.")
+
+
+if __name__ == "__main__":
+    t = Thread(target=wait_for_cancellation_input)
+    t.start()
+    run()
diff --git a/hatch.toml b/hatch.toml
index 7f43849e9..7853cdc6c 100644
--- a/hatch.toml
+++ b/hatch.toml
@@ -23,7 +23,6 @@ dependencies = [
 
 [envs.default.scripts]
 test = "pytest --cov-config pyproject.toml {args:test/bealine test/bealine_job_attachments/unit}"
-integtest = "pytest {args:test/bealine_job_attachments/integ}"
 typing = "mypy {args:src test}"
 style = [
   "ruff {args:.}",
@@ -41,6 +40,9 @@ lint = [
 
 [[envs.all.matrix]]
 python = ["3.7", "3.9", "3.10", "3.11"]
 
+[envs.integ.scripts]
+test = "pytest {args:test/bealine_job_attachments/integ} -vvv --numprocesses=1"
+
 [envs.default.env-vars]
 PIP_INDEX_URL="https://aws:{env:CODEARTIFACT_AUTH_TOKEN}@{env:CODEARTIFACT_DOMAIN}-{env:CODEARTIFACT_ACCOUNT_ID}.d.codeartifact.{env:CODEARTIFACT_REGION}.amazonaws.com/pypi/{env:CODEARTIFACT_REPOSITORY}/simple/"
diff --git a/src/bealine_job_attachments/asset_sync.py b/src/bealine_job_attachments/asset_sync.py
index 0a51b0536..348a402cf 100644
--- a/src/bealine_job_attachments/asset_sync.py
+++ b/src/bealine_job_attachments/asset_sync.py
@@ -241,12 +241,24 @@ def sync_inputs(
         queue_id: str,
         job_id: str,
         session_dir: Path,
-        on_downloading_files: Optional[Callable] = None,
+        on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None,
     ) -> Tuple[DownloadSummaryStatistics, List[Dict[str, str]]]:
         """
         Downloads a manifest file and corresponding input files, if found.
-        Returns a tuple of (1) final summary statistics for file downloads, and
-        (2) a list of local roots for each asset root, used for path mapping.
+
+        Args:
+            s3_settings: S3-specific Job Attachment settings.
+            ja_settings: Job Attachment settings.
+            queue_id: the ID of the queue.
+            job_id: the ID of the job.
+            session_dir: the directory that the session is going to use.
+            on_downloading_files: a function that will be called with a ProgressReportMetadata
+                object as files are being downloaded. If the function returns False, the
+                download will be cancelled; if it returns True, it will continue.
+
+        Returns:
+            a tuple of (1) final summary statistics for file downloads, and
+            (2) a list of local roots for each asset root, used for path mapping.
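+
+        Example (a minimal callback sketch; `cancel_requested` is a hypothetical
+        caller-side flag):
+
+            def on_downloading_files(metadata: ProgressReportMetadata) -> bool:
+                print(metadata)
+                return not cancel_requested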
""" if not s3_settings: logger.info( diff --git a/src/bealine_job_attachments/download.py b/src/bealine_job_attachments/download.py index 19997f746..73a67f032 100644 --- a/src/bealine_job_attachments/download.py +++ b/src/bealine_job_attachments/download.py @@ -20,6 +20,7 @@ from bealine_job_attachments.progress_tracker import ( DownloadSummaryStatistics, + ProgressReportMetadata, ProgressStatus, ProgressTracker, ) @@ -28,7 +29,7 @@ from .asset_manifests.decode import decode_manifest from .asset_manifests.versions import ManifestVersion from .aws.aws_clients import get_account_id, get_s3_client -from .errors import JobAttachmentsError, MissingAssetRootError +from .errors import AssetSyncCancelledError, JobAttachmentsError, MissingAssetRootError from .models import JobAttachmentS3Settings, JobSettings from .utils import get_bucket_and_object_key @@ -113,7 +114,6 @@ def get_job_input_paths_by_asset_root( The lists are separated by asset root. Returns a tuple of (total size of assets in bytes, lists of assets) """ - # TODO: Change this to use the manifest file paths when the job attachment browser supports it. assets: DefaultDict[str, list[RelativeFilePath]] = DefaultDict(list) total_bytes = 0 @@ -167,7 +167,6 @@ def get_job_input_output_paths_by_asset_root( of this job. The lists are separated by asset root. Returns a tuple of (total size of files in bytes, lists of files) """ - # TODO: Need to handle colliding paths for files: https://sim.amazon.com/issues/Bea-5786 (total_input_bytes, input_files) = get_job_input_paths_by_asset_root( s3_settings=s3_settings, job_settings=job_settings, session=session ) @@ -188,7 +187,7 @@ def download_files_in_directory( directory_path: str, local_download_dir: str, session: Optional[boto3.Session] = None, - on_downloading_files: Optional[Callable] = None, + on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None, ) -> DownloadSummaryStatistics: """ From a given job's input and output files, downloads all files in @@ -200,7 +199,6 @@ def download_files_in_directory( s3_settings=s3_settings, job_settings=job_settings, job_id=job_id, session=session ) - # TODO: will need to handle multiple roots in the future https://sim.amazon.com/issues/Bea-4507 files_to_download: list[RelativeFilePath] = [] total_bytes = 0 for files in all_paths_hashes.values(): @@ -240,14 +238,21 @@ def download_file( s3_client: Optional[BaseClient] = None, session: Optional[boto3.Session] = None, modified_time_override: Optional[float] = None, - progress_tracker_callback: Optional[Callable] = None, -) -> bool: + progress_tracker: Optional[ProgressTracker] = None, +) -> Tuple[int, Path]: """ Downloads a file from the S3 bucket to the local directory. `modified_time_override` is ignored if the manifest version used supports timestamps. - Returns whether the download progress should be tracked in number of files instead of bytes. - (The manifest containing the file may be of a version that does not provide the size of files.) + Returns a tuple of (size in bytes, filename) of the downloaded file. + (If this file comes from a version of manifest that does not provide file sizes, the size will be 0, + which means the download progress should be tracked by the number of files instead of bytes.) """ + # If it's cancelled, raise an AssetSyncCancelledError. 
+    if progress_tracker and not progress_tracker.continue_reporting:
+        raise AssetSyncCancelledError(
+            "File download cancelled.", progress_tracker.get_summary_statistics()
+        )
+
     if not s3_client:
         s3_client = get_s3_client(session=session)
 
         s3_key,
         str(local_file_name),
         ExtraArgs={"ExpectedBucketOwner": get_account_id(session=session)},
-        Callback=_progress_logger(file_bytes, progress_tracker_callback),
+        Callback=_progress_logger(
+            file_bytes,
+            progress_tracker.track_progress_callback if progress_tracker else None,
+        ),
     )
 
     logger.info(f"Downloaded {file.path} to {str(local_file_name)}")
     os.utime(local_file_name, (modified_time_override, modified_time_override))  # type: ignore[arg-type]
 
-    return file_bytes == 0
+    return (file_bytes, local_file_name)
 
 
 def _progress_logger(
     file_size_in_bytes: int, progress_tracker_callback: Optional[Callable] = None
 ) -> Callable[[int], None]:
-    total_uploaded = 0
+    total_downloaded = 0
 
-    def handler(bytes_uploaded):
+    def handler(bytes_downloaded):
         if progress_tracker_callback is None or file_size_in_bytes == 0:
             return
-        nonlocal total_uploaded
-        total_uploaded += bytes_uploaded
-        progress_tracker_callback(bytes_uploaded, total_uploaded == file_size_in_bytes)
+        nonlocal total_downloaded
+        total_downloaded += bytes_downloaded
+        progress_tracker_callback(bytes_downloaded, total_downloaded == file_size_in_bytes)
 
     return handler
 
@@ -308,6 +316,8 @@ def _download_files_parallel(
     Downloads files in parallel using a thread pool.
     Returns a list of local paths of downloaded files.
     """
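+    # Collected as futures complete so that, on cancellation, the summary statistics
+    # can report which files had already finished downloading.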
+    downloaded_file_names: list[str] = []
+
     # TODO: tune this. max_workers defaults to 5 * number of processors. We can run into issues here
     # if we thread too aggressively on slower internet connections. So for now let's set it to 5,
     # which would be the number of threads with one processor.
 
             s3_client,
             session,
             file_mod_time,
-            progress_tracker.track_progress_callback if progress_tracker else None,
+            progress_tracker,
         ): file
         for file in files
     }
 
     # surfaces any exceptions in the thread
     for future in concurrent.futures.as_completed(futures):
-        track_progress_by_number_of_files = future.result()
-        if track_progress_by_number_of_files and progress_tracker:
+        (file_bytes, local_file_name) = future.result()
+        downloaded_file_names.append(str(local_file_name.resolve()))
+        if file_bytes == 0 and progress_tracker:
             progress_tracker.increase_processed(1, 0)
             progress_tracker.report_progress()
 
-    # to report progress 100% at the end
+    # to report progress 100% at the end, and
+    # to check if the download was cancelled in the middle of processing the last batch of files
     if progress_tracker:
         progress_tracker.report_progress()
+        if not progress_tracker.continue_reporting:
+            raise AssetSyncCancelledError(
+                "File download cancelled.",
+                progress_tracker.get_download_summary_statistics(downloaded_file_names),
+            )
 
-    return [str(Path(local_download_dir).joinpath(file.path).resolve()) for file in files]
+    return downloaded_file_names
 
 
 def download_files_from_manifest(
     s3_bucket: str,
     manifest_path: str,
     local_download_dir: str,
     cas_prefix: Optional[str] = None,
     session: Optional[boto3.Session] = None,
-    on_downloading_files: Optional[Callable] = None,
+    on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None,
 ) -> DownloadSummaryStatistics:
     """
     Given an asset manifest, downloads all files from a CAS in that manifest.
+
+    Args:
+        s3_bucket: The name of the S3 bucket.
+        manifest_path: The path to the manifest file.
+        local_download_dir: The local root directory to download the files to.
+        cas_prefix: The CAS prefix of the files.
+        session: The boto3 session to use.
+        on_downloading_files: a callback that is called periodically to report progress
+            to the caller. It should return True if the operation should continue as
+            normal, or False to cancel it.
+
+    Returns:
+        The download summary statistics.
     """
     s3_client = get_s3_client(session=session)
 
@@ -471,12 +500,24 @@ def download_job_output(
     step_id: Optional[str] = None,
     task_id: Optional[str] = None,
     session: Optional[boto3.Session] = None,
-    on_downloading_files: Optional[Callable] = None,
+    on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None,
 ) -> DownloadSummaryStatistics:
     """
     Convenience function to download all output files from the given job, with optional
     step- and task-level granularity. Automatically downloads to the asset root(s) set
     in the output manifest S3 metadata.
+
+    Args:
+        s3_settings: The S3 settings to use.
+        job_id: The ID of the job to download.
+        step_id: The ID of the step to download.
+        task_id: The ID of the task to download.
+        session: The boto3 session to use.
+        on_downloading_files: a callback that is called periodically to report progress
+            to the caller. It should return True if the operation should continue as
+            normal, or False to cancel it.
+
+    Returns:
+        The download summary statistics.
+
     TODO: The download location is OS-specific to the *submitting machine* matching the
     profile of job["attachmentSettings"]["submissionProfileName"]. The OS of the
     *downloading machine* might be different, so we need to check that
@@ -516,10 +557,22 @@ def mount_vfs_from_manifest(
     local_download_dir: str,
     cas_prefix: Optional[str] = None,
     session: Optional[boto3.Session] = None,
-    on_downloading_files: Optional[Callable] = None,
+    on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None,
 ) -> DownloadSummaryStatistics:
     """
     Given an asset manifest, downloads all files from a CAS in that manifest.
+
+    Args:
+        s3_bucket: The name of the S3 bucket.
+        manifest_path: The path to the manifest file.
+        local_download_dir: The local root directory to download the files to.
+        cas_prefix: The CAS prefix of the files.
+        session: The boto3 session to use.
+        on_downloading_files: a callback that is called periodically to report progress
+            to the caller. It should return True if the operation should continue as
+            normal, or False to cancel it.
+
+    Returns:
+        The download summary statistics.
     """
 
     logger.info("Successfully triggered vfs from manifests!")
diff --git a/src/bealine_job_attachments/errors.py b/src/bealine_job_attachments/errors.py
index b1426d07e..dc7012e20 100644
--- a/src/bealine_job_attachments/errors.py
+++ b/src/bealine_job_attachments/errors.py
@@ -5,6 +5,8 @@
 """
 
+from typing import Optional
+
 from bealine_job_attachments.progress_tracker import SummaryStatistics
 
@@ -55,8 +57,9 @@ class AssetSyncCancelledError(JobAttachmentsError):
     Exception thrown when an operation (syncing files to/from S3) has been cancelled.
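+
+    If an operation was cancelled partway through, the `summary_statistics` attribute,
+    when set, holds the statistics accumulated up to the point of cancellation.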
""" - summary_statistics: SummaryStatistics + summary_statistics: Optional[SummaryStatistics] = None - def __init__(self, message, summary_statistics): + def __init__(self, message, summary_statistics: Optional[SummaryStatistics] = None): super().__init__(message) - self.summary_statistics = summary_statistics + if summary_statistics: + self.summary_statistics = summary_statistics diff --git a/src/bealine_job_attachments/progress_tracker.py b/src/bealine_job_attachments/progress_tracker.py index da4910a9b..bc5c838ce 100644 --- a/src/bealine_job_attachments/progress_tracker.py +++ b/src/bealine_job_attachments/progress_tracker.py @@ -6,7 +6,7 @@ from dataclasses import asdict, dataclass, field, fields from enum import Enum from threading import Lock -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional from bealine_job_attachments.utils import human_readable_file_size @@ -109,7 +109,7 @@ class ProgressTracker: def __init__( self, status: ProgressStatus, - on_progress_callback: Optional[Callable[[Any], bool]] = None, + on_progress_callback: Optional[Callable[[ProgressReportMetadata], bool]] = None, interval: int = DURATION_BETWEEN_CALLS, files_in_chunk: int = FILES_IN_CHUNK, ) -> None: diff --git a/src/bealine_job_attachments/upload.py b/src/bealine_job_attachments/upload.py index add7b8ea4..23d19d180 100644 --- a/src/bealine_job_attachments/upload.py +++ b/src/bealine_job_attachments/upload.py @@ -46,7 +46,7 @@ ) from .utils import get_bealine_formatted_os, get_os_pure_path, hash_data, hash_file, join_s3_paths -# TODO: full performance analysis to determine the ideal threshold https://sim.amazon.com/issues/Bea-5551 +# TODO: full performance analysis to determine the ideal threshold LIST_OBJECT_THRESHOLD: int = 100 @@ -125,7 +125,6 @@ def upload_input_files( TODO: There is a known performance bottleneck if the bucket has a large number of files, but there isn't currently any way of knowing the size of the bucket without iterating through the contents of a prefix. For now, we'll just head-object when we have a small number of files. 
-    Research to be done in https://sim.amazon.com/issues/Bea-5551
     """
     files_to_upload: list[base_manifest.Path] = manifest.paths
     check_if_in_s3 = True
@@ -332,7 +331,7 @@ def handler(bytes_uploaded):
         total_uploaded += bytes_uploaded
         percentage = round(total_uploaded / file_size * 100)
-        # TODO: https://sim.amazon.com/issues/Bea-2613 This log is too long for INFO
+        # TODO: This log is too long for INFO
         if last_reported_percentage < percentage:
             logger.info(f"Uploading {path}: uploaded {percentage}%")
             last_reported_percentage = percentage
diff --git a/test/bealine_job_attachments/integ/test_job_attachments.py b/test/bealine_job_attachments/integ/test_job_attachments.py
index 467f632bc..9722964b3 100644
--- a/test/bealine_job_attachments/integ/test_job_attachments.py
+++ b/test/bealine_job_attachments/integ/test_job_attachments.py
@@ -657,7 +657,8 @@ def sync_outputs(
     )
 
     task_ids = {
-        task["parameterSet"]["frame"]: task["taskId"] for task in list_tasks_response["tasks"]
+        task["parameterSet"]["frame"]["int"]: task["taskId"]
+        for task in list_tasks_response["tasks"]
     }
     step0_task0_id = task_ids["0"]
 
@@ -1328,7 +1329,8 @@ def test_sync_outputs_bucket_wrong_account(
     )
 
     task_ids = {
-        task["parameterSet"]["frame"]: task["taskId"] for task in list_tasks_response["tasks"]
+        task["parameterSet"]["frame"]["int"]: task["taskId"]
+        for task in list_tasks_response["tasks"]
     }
     step0_task0_id = task_ids["0"]
 
diff --git a/test/bealine_job_attachments/unit/test_download.py b/test/bealine_job_attachments/unit/test_download.py
index ee8bb733b..991e593d4 100644
--- a/test/bealine_job_attachments/unit/test_download.py
+++ b/test/bealine_job_attachments/unit/test_download.py
@@ -3,7 +3,7 @@
 """Tests for downloading files from the Job Attachment CAS."""
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from io import BytesIO
 from pathlib import Path
 from typing import Any, Callable, List
@@ -375,7 +375,15 @@ def assert_progress_tracker_values(
     )
     mock_on_downloading_files.assert_called_with(expected_last_progress_report)
 
-    assert summary_statistics == expected_summary_statistics
+    # Compare field by field: downloaded file paths can arrive in any order, so
+    # they are compared as sets rather than as ordered lists.
+    for attr in fields(expected_summary_statistics):
+        if attr.name == "downloaded_files_paths":
+            assert set(getattr(summary_statistics, attr.name)) == set(
+                str(file) for file in expected_files
+            )
+        else:
+            assert getattr(summary_statistics, attr.name) == getattr(
+                expected_summary_statistics, attr.name
+            )
 
 
 def assert_download_job_output_with_task_id_and_no_step_id_throws_error(
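
---
Usage note for the breaking change: `on_downloading_files` must now return a bool
(True to continue, False to cancel). A minimal caller-side sketch, assuming a
hypothetical `cancel_event` (threading.Event) set by a UI or signal handler, and
`queue`/`job_id` values obtained as in examples/download_cancel_test.py:

    import threading

    from bealine_job_attachments.download import download_job_output
    from bealine_job_attachments.errors import AssetSyncCancelledError

    cancel_event = threading.Event()

    def on_downloading_files(metadata) -> bool:
        # Returning False makes the library raise AssetSyncCancelledError,
        # carrying the summary statistics accumulated so far.
        return not cancel_event.is_set()

    try:
        stats = download_job_output(
            s3_settings=queue.jobAttachmentSettings,
            job_id=job_id,
            on_downloading_files=on_downloading_files,
        )
        print(stats)
    except AssetSyncCancelledError as e:
        print(f"Cancelled: {e.summary_statistics}")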