From 7e75eeef7015b2ffd87991fd5f55c2a55b497021 Mon Sep 17 00:00:00 2001 From: stangch <171081544+stangch@users.noreply.github.com> Date: Fri, 26 Jul 2024 10:20:13 -0700 Subject: [PATCH] feat(asset-cli): asset download subcommand Signed-off-by: Tang <171081544+stangch@users.noreply.github.com> --- .../client/cli/_groups/asset_group.py | 85 ++++++++++- src/deadline/job_attachments/download.py | 113 +++++++++----- src/deadline/job_attachments/upload.py | 5 +- .../deadline_client/cli/test_cli_asset.py | 141 +++++++++++++++++- 4 files changed, 298 insertions(+), 46 deletions(-) diff --git a/src/deadline/client/cli/_groups/asset_group.py b/src/deadline/client/cli/_groups/asset_group.py index ba4f5d274..4f5a902df 100644 --- a/src/deadline/client/cli/_groups/asset_group.py +++ b/src/deadline/client/cli/_groups/asset_group.py @@ -20,13 +20,15 @@ from deadline.client import api from deadline.job_attachments.upload import FileStatus, S3AssetManager, S3AssetUploader -from deadline.job_attachments.models import ( - JobAttachmentS3Settings, - AssetRootManifest, -) +from deadline.job_attachments.download import download_file_with_s3_key +from deadline.job_attachments.models import JobAttachmentS3Settings, AssetRootManifest from deadline.job_attachments.asset_manifests.decode import decode_manifest from deadline.job_attachments.asset_manifests.base_manifest import BaseAssetManifest from deadline.job_attachments.caches import HashCache +from deadline.job_attachments._aws.aws_clients import ( + get_s3_client, + get_s3_transfer_manager, +) from .._common import _apply_cli_options_to_config, _handle_error, _ProgressBarCallbackManager from ...exceptions import NonValidInputError, ManifestOutdatedError @@ -300,15 +302,82 @@ def asset_diff(root_dir: str, manifest_dir: str, raw: bool, **args): @cli_asset.command(name="download") -@click.option("--farm-id", help="The AWS Deadline Cloud Farm to use.") -@click.option("--queue-id", help="The AWS Deadline Cloud Queue to use.") +@click.option("--farm-id", help="The AWS Deadline Cloud Farm to use. ") +@click.option("--queue-id", help="The AWS Deadline Cloud Queue to use. ") @click.option("--job-id", help="The AWS Deadline Cloud Job to get. ") +@click.option( + "--manifest-out", + required=True, + help="Destination path to directory where manifest is created. ", +) @_handle_error -def asset_download(**args): +def asset_download(manifest_out: str, **args): """ Downloads input manifest of previously submitted job. """ - click.echo("download complete") + if not os.path.isdir(manifest_out): + raise NonValidInputError(f"Specified destination directory {manifest_out} does not exist. ") + + # setup config + config = _apply_cli_options_to_config( + required_options={"farm_id", "queue_id", "job_id"}, **args + ) + deadline: BaseClient = api.get_boto3_client("deadline", config=config) + queue_id: str = get_setting("defaults.queue_id", config=config) + farm_id: str = get_setting("defaults.farm_id", config=config) + job_id = config_file.get_setting("defaults.job_id", config=config) + + queue: dict = deadline.get_queue( + farmId=farm_id, + queueId=queue_id, + ) + + # assume queue role - session permissions + queue_role_session: boto3.Session = api.get_queue_user_boto3_session( + deadline=deadline, + config=config, + farm_id=farm_id, + queue_id=queue_id, + ) + + # get input_manifest_paths from Deadline GetJob API + job: dict = deadline.get_job(farmId=farm_id, queueId=queue_id, jobId=job_id) + attachments: dict = job["attachments"] + input_manifest_paths: list[str] = [ + manifest["inputManifestPath"] for manifest in attachments["manifests"] + ] + + # get s3BucketName from Deadline GetQueue API + bucket_name: str = queue["jobAttachmentSettings"]["s3BucketName"] + + # get S3 prefix + s3Prefix = queue["jobAttachmentSettings"]["rootPrefix"] + "/Manifests/" + + s3_client: BaseClient = get_s3_client(session=queue_role_session) + transfer_manager = get_s3_transfer_manager(s3_client=s3_client) + + # download each input_manifest_path + for input_manifest_path in input_manifest_paths: + local_file_name = Path(manifest_out, job_id + "_manifest") + + result = download_file_with_s3_key( + s3_bucket=bucket_name, + s3_key=s3Prefix + input_manifest_path, + local_file_name=local_file_name, + session=queue_role_session, + transfer_manager=transfer_manager, + ) + + if result is not None: + transfer_path = result.meta.call_args.fileobj # type: ignore[attr-defined] + file_size = result.meta.size # type: ignore[attr-defined] + click.echo( + f"\nDownloaded file to '{transfer_path}' ({file_size} bytes)\nWith S3 key: '{input_manifest_path}'. " + ) + else: + click.echo( + f"\nFailed to download file with S3 key '{input_manifest_path}' from bucket '{bucket_name}'" + ) def read_local_manifest(manifest: str) -> BaseAssetManifest: diff --git a/src/deadline/job_attachments/download.py b/src/deadline/job_attachments/download.py index c12f268b8..8e02a86b6 100644 --- a/src/deadline/job_attachments/download.py +++ b/src/deadline/job_attachments/download.py @@ -372,48 +372,25 @@ def download_files_in_directory( ) -def download_file( - file: RelativeFilePath, - hash_algorithm: HashAlgorithm, - local_download_dir: str, +def download_file_with_s3_key( s3_bucket: str, - cas_prefix: Optional[str], - s3_client: Optional[BaseClient] = None, + s3_key: str, + transfer_manager, + local_file_name: Path, + file_bytes: Optional[int] = 0, session: Optional[boto3.Session] = None, - modified_time_override: Optional[float] = None, progress_tracker: Optional[ProgressTracker] = None, file_conflict_resolution: Optional[FileConflictResolution] = FileConflictResolution.CREATE_COPY, -) -> Tuple[int, Optional[Path]]: +) -> concurrent.futures.Future | None: """ - Downloads a file from the S3 bucket to the local directory. `modified_time_override` is ignored if the manifest - version used supports timestamps. - Returns a tuple of (size in bytes, filename) of the downloaded file. - - The file size of 0 means that this file comes from a manifest version that does not provide file sizes. - - The filename of None indicates that this file has been skipped or has not been downloaded. + Helper to download a file from the S3 bucket with a specified S3 key to the local directory. + Returns the asynchronous result of the downloaded file, and None if otherwise """ - if not s3_client: - s3_client = get_s3_client(session=session) - - transfer_manager = get_s3_transfer_manager(s3_client=s3_client) - - # The modified time in the manifest is in microseconds, but utime requires the time be expressed in seconds. - modified_time_override = file.mtime / 1000000 # type: ignore[attr-defined] - - file_bytes = file.size - - # Python will handle the path separator '/' correctly on every platform. - local_file_name = Path(local_download_dir).joinpath(file.path) - - s3_key = ( - f"{cas_prefix}/{file.hash}.{hash_algorithm.value}" - if cas_prefix - else f"{file.hash}.{hash_algorithm.value}" - ) # If the file name already exists, resolve the conflict based on the file_conflict_resolution if local_file_name.is_file(): if file_conflict_resolution == FileConflictResolution.SKIP: - return (file_bytes, None) + return None elif file_conflict_resolution == FileConflictResolution.OVERWRITE: pass elif file_conflict_resolution == FileConflictResolution.CREATE_COPY: @@ -432,6 +409,7 @@ def download_file( future: concurrent.futures.Future + # provides progress callback for asynchronous download from S3 def handler(bytes_downloaded): nonlocal progress_tracker nonlocal future @@ -450,9 +428,63 @@ def handler(bytes_downloaded): extra_args={"ExpectedBucketOwner": get_account_id(session=session)}, subscribers=subscribers, ) + future.result() + return future + + +def download_file( + file: RelativeFilePath, + hash_algorithm: HashAlgorithm, + local_download_dir: str, + s3_bucket: str, + cas_prefix: Optional[str], + s3_client: Optional[BaseClient] = None, + session: Optional[boto3.Session] = None, + modified_time_override: Optional[float] = None, + progress_tracker: Optional[ProgressTracker] = None, + file_conflict_resolution: Optional[FileConflictResolution] = FileConflictResolution.CREATE_COPY, +) -> Tuple[int, Optional[Path]]: + """ + Downloads a file from the S3 bucket to the local directory. `modified_time_override` is ignored if the manifest + version used supports timestamps. + Returns a tuple of (size in bytes, filename) of the downloaded file. + - The file size of 0 means that this file comes from a manifest version that does not provide file sizes. + - The filename of None indicates that this file has been skipped or has not been downloaded. + """ + if not s3_client: + s3_client = get_s3_client(session=session) + + transfer_manager = get_s3_transfer_manager(s3_client=s3_client) + + file_bytes = file.size + + # The modified time in the manifest is in microseconds, but utime requires the time be expressed in seconds. + modified_time_override = file.mtime / 1000000 # type: ignore[attr-defined] + + # Python will handle the path separator '/' correctly on every platform. + local_file_name = Path(local_download_dir).joinpath(file.path) + + s3_key = ( + f"{cas_prefix}/{file.hash}.{hash_algorithm.value}" + if cas_prefix + else f"{file.hash}.{hash_algorithm.value}" + ) try: - future.result() + future = download_file_with_s3_key( + s3_bucket=s3_bucket, + s3_key=s3_key, + file_bytes=file_bytes, + transfer_manager=transfer_manager, + local_file_name=local_file_name, + session=session, + progress_tracker=progress_tracker, + file_conflict_resolution=file_conflict_resolution, + ) + + if future is None: + return (file_bytes, None) + except concurrent.futures.CancelledError as ce: if progress_tracker and progress_tracker.continue_reporting is False: raise AssetSyncCancelledError("File download cancelled.") @@ -490,6 +522,18 @@ def process_client_error(exc: ClientError, status_code: int): # TODO: Temporary to prevent breaking backwards-compatibility; if file not found, try again without hash alg postfix status_code = int(exc.response["ResponseMetadata"]["HTTPStatusCode"]) if status_code == 404: + + def handler(bytes_downloaded): + nonlocal progress_tracker + nonlocal future + + if progress_tracker: + should_continue = progress_tracker.track_progress_callback(bytes_downloaded) + if not should_continue and future is not None: + future.cancel() + + subscribers = [ProgressCallbackInvoker(handler)] + s3_key = s3_key.rsplit(".", 1)[0] future = transfer_manager.download( bucket=s3_bucket, @@ -499,7 +543,8 @@ def process_client_error(exc: ClientError, status_code: int): subscribers=subscribers, ) try: - future.result() + if future is not None: + future.result() except concurrent.futures.CancelledError as ce: if progress_tracker and progress_tracker.continue_reporting is False: raise AssetSyncCancelledError("File download cancelled.") diff --git a/src/deadline/job_attachments/upload.py b/src/deadline/job_attachments/upload.py index 16b22e52a..88fcbb243 100644 --- a/src/deadline/job_attachments/upload.py +++ b/src/deadline/job_attachments/upload.py @@ -191,7 +191,10 @@ def upload_assets( if manifest_write_dir: self._write_local_manifest( - manifest_write_dir, manifest_name, full_manifest_key, manifest + manifest_write_dir, + manifest_name, + full_manifest_key, + manifest, ) self.upload_bytes_to_s3( diff --git a/test/unit/deadline_client/cli/test_cli_asset.py b/test/unit/deadline_client/cli/test_cli_asset.py index e3f2e6d91..84f1c889c 100644 --- a/test/unit/deadline_client/cli/test_cli_asset.py +++ b/test/unit/deadline_client/cli/test_cli_asset.py @@ -1,9 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +import os import pytest -from unittest.mock import patch, Mock, MagicMock +from unittest.mock import patch, Mock, MagicMock, PropertyMock from click.testing import CliRunner -import os +import concurrent.futures from deadline.client.cli import main from deadline.client.cli._groups import asset_group @@ -20,9 +21,10 @@ from deadline.job_attachments.asset_manifests.v2023_03_03 import AssetManifest from deadline.job_attachments.asset_manifests.hash_algorithms import HashAlgorithm -from ..api.test_job_bundle_submission import ( +from ..shared_constants import ( MOCK_FARM_ID, MOCK_QUEUE_ID, + MOCK_JOB_ID, ) @@ -150,9 +152,27 @@ def _mock_read_local_manifest(manifest): MOCK_ROOT_DIR = "/path/to/root" MOCK_MANIFEST_DIR = "/path/to/manifest" +MOCK_MANIFEST_OUT_DIR = "path/to/out/dir" MOCK_MANIFEST_FILE = os.path.join(MOCK_MANIFEST_DIR, "manifest_input") MOCK_INVALID_DIR = "/nopath/" MOCK_UPLOAD_ATTACHMENTS_RESPONSE = {"manifests": [{"inputManifestPath": "s3://mock/manifest.json"}]} +MOCK_JOB_ATTACHMENTS = { + "manifests": [ + { + "inputManifestHash": "mock_input_manifest_hash", + "inputManifestPath": "mock_input_manifest_path", + "outputRelativeDirectories": ["mock_output_dir"], + "rootPath": "mock_root_path", + "rootPathFormat": "mock_root_path_format", + } + ] +} +MOCK_QUEUE = { + "queueId": "queue-0123456789abcdef0123456789abcdef", + "displayName": "mock_queue", + "description": "mock_description", + "jobAttachmentSettings": {"s3BucketName": "mock_bucket", "rootPrefix": "mock_deadline"}, +} class TestSnapshot: @@ -583,3 +603,118 @@ def test_asset_diff_invalid_manifest_dir(self, tmp_path): assert ( f"Specified manifest directory {invalid_manifest_dir} does not exist. " in result.output ) + + +class TestAssetDownload: + + MOCK_FUTURE = Mock(spec=concurrent.futures.Future) + + def test_asset_download_valid(self, fresh_deadline_config): + """ + Test the asset download command with valid inputs. + """ + + # Mock the API calls + with patch.object(api, "get_boto3_client") as mock_get_boto3_client, patch.object( + api, "get_queue_user_boto3_session" + ), patch.object(os.path, "isdir", side_effect=[True, True]), patch.object( + os.path, "isfile", side_effect=[True, True] + ), patch.object( + mock_get_boto3_client.return_value, + "get_job", + return_value={"attachments": MOCK_JOB_ATTACHMENTS}, + ), patch.object( + mock_get_boto3_client.return_value, + "get_queue", + return_value=MOCK_QUEUE, + ), patch.object( + asset_group, + "download_file_with_s3_key", + return_value=self.MOCK_FUTURE, + ) as mock_download_file: + + mock_meta = Mock() + mock_call_args = Mock() + mock_call_args.fileobj = "mocked_transfer_path" + mock_meta.call_args = mock_call_args + mock_meta.size = 1024 + + type(self.MOCK_FUTURE).meta = PropertyMock(return_value=mock_meta) + + runner = CliRunner() + result = runner.invoke( + main, + [ + "asset", + "download", + "--farm-id", + MOCK_FARM_ID, + "--queue-id", + MOCK_QUEUE_ID, + "--job-id", + MOCK_JOB_ID, + "--manifest-out", + MOCK_MANIFEST_OUT_DIR, + ], + ) + + print(result.output) + mock_download_file.assert_called_once() + assert result.exit_code == 0 + assert ( + "\nDownloaded file to 'mocked_transfer_path' (1024 bytes)\nWith S3 key: 'mock_input_manifest_path'. \n" + in result.output + ) + + def test_download_invalid_job_id(self, fresh_deadline_config): + """ + Test the asset download command when the required --job-id option is missing and doesn't exist in default config. + """ + runner = CliRunner() + result = runner.invoke( + main, + [ + "asset", + "download", + "--farm-id", + MOCK_FARM_ID, + "--queue-id", + MOCK_QUEUE_ID, + "--manifest-out", + MOCK_MANIFEST_OUT_DIR, + ], + ) + + assert result.exit_code in [1, 2] + assert ( + "Usage: main asset download [OPTIONS]\nTry 'main asset download -h' for help.\n\nError: Missing '--job-id' or default Job ID configuration\n" + in result.output + ) + + def test_download_invalid_manifest_out(self, fresh_deadline_config): + """ + Test the asset download command when the required --manifest-out option is missing. + """ + runner = CliRunner() + result = runner.invoke( + main, + [ + "asset", + "download", + "--farm-id", + MOCK_FARM_ID, + "--queue-id", + MOCK_QUEUE_ID, + "--job-id", + MOCK_JOB_ID, + "--manifest-out", + MOCK_INVALID_DIR, + ], + ) + + print(result.output) + + assert result.exit_code == 1 + assert ( + f"Specified destination directory {MOCK_INVALID_DIR} does not exist. " in result.output + )