diff --git a/src/deadline/client/api/__init__.py b/src/deadline/client/api/__init__.py index 58383c43..877c55a4 100644 --- a/src/deadline/client/api/__init__.py +++ b/src/deadline/client/api/__init__.py @@ -4,6 +4,7 @@ "login", "logout", "create_job_from_job_bundle", + "hash_attachments", "wait_for_create_job_to_complete", "get_boto3_session", "get_boto3_client", @@ -52,7 +53,11 @@ list_storage_profiles_for_queue, ) from ._queue_parameters import get_queue_parameter_definitions -from ._submit_job_bundle import create_job_from_job_bundle, wait_for_create_job_to_complete +from ._submit_job_bundle import ( + create_job_from_job_bundle, + wait_for_create_job_to_complete, + hash_attachments, +) from ._telemetry import ( get_telemetry_client, get_deadline_cloud_library_telemetry_client, diff --git a/src/deadline/client/api/_submit_job_bundle.py b/src/deadline/client/api/_submit_job_bundle.py index 31f08d04..eadcbc4b 100644 --- a/src/deadline/client/api/_submit_job_bundle.py +++ b/src/deadline/client/api/_submit_job_bundle.py @@ -276,7 +276,7 @@ def create_job_from_job_bundle( print_function_callback("Job submission canceled.") return None - _, asset_manifests = _hash_attachments( + _, asset_manifests = hash_attachments( asset_manager=asset_manager, asset_groups=upload_group.asset_groups, total_input_files=upload_group.total_input_files, @@ -396,7 +396,7 @@ def wait_for_create_job_to_complete( ) -def _hash_attachments( +def hash_attachments( asset_manager: S3AssetManager, asset_groups: list[AssetRootGroup], total_input_files: int, diff --git a/src/deadline/client/cli/_common.py b/src/deadline/client/cli/_common.py index 1f443063..7e218cb9 100644 --- a/src/deadline/client/cli/_common.py +++ b/src/deadline/client/cli/_common.py @@ -9,6 +9,7 @@ "_handle_error", "_apply_cli_options_to_config", "_cli_object_repr", + "_ProgressBarCallbackManager", ] import sys @@ -16,13 +17,19 @@ from typing import Any, Callable, Optional, Set import click +from contextlib import ExitStack 
+from deadline.job_attachments.progress_tracker import ProgressReportMetadata from ..config import config_file from ..exceptions import DeadlineOperationError from ..job_bundle import deadline_yaml_dump +from ._groups._sigint_handler import SigIntHandler _PROMPT_WHEN_COMPLETE = "PROMPT_WHEN_COMPLETE" +# Set up the signal handler for handling Ctrl + C interruptions. +sigint_handler = SigIntHandler() + def _prompt_at_completion(ctx: click.Context): """ @@ -176,3 +183,42 @@ def _cli_object_repr(obj: Any): # strings to end with "\n". obj = _fix_multiline_strings(obj) return deadline_yaml_dump(obj) + + +class _ProgressBarCallbackManager: + """ + Manages creation, update, and deletion of a progress bar. On first call of the callback, the progress bar is created. The progress bar is closed + on the final call (100% completion) + """ + + BAR_NOT_CREATED = 0 + BAR_CREATED = 1 + BAR_CLOSED = 2 + + def __init__(self, length: int, label: str): + self._length = length + self._label = label + self._bar_status = self.BAR_NOT_CREATED + self._exit_stack = ExitStack() + + def callback(self, upload_metadata: ProgressReportMetadata) -> bool: + if self._bar_status == self.BAR_CLOSED: + # from multithreaded execution this can be called after completion sometimes. + return sigint_handler.continue_operation + elif self._bar_status == self.BAR_NOT_CREATED: + # Note: click doesn't export the return type of progressbar(), so we suppress mypy warnings for + # not annotating the type of hashing_progress. 
+ self._upload_progress = click.progressbar(length=self._length, label=self._label) # type: ignore[var-annotated] + self._exit_stack.enter_context(self._upload_progress) + self._bar_status = self.BAR_CREATED + + total_progress = int(upload_metadata.progress) + new_progress = total_progress - self._upload_progress.pos + if new_progress > 0: + self._upload_progress.update(new_progress) + + if total_progress == self._length or not sigint_handler.continue_operation: + self._bar_status = self.BAR_CLOSED + self._exit_stack.close() + + return sigint_handler.continue_operation diff --git a/src/deadline/client/cli/_groups/asset_group.py b/src/deadline/client/cli/_groups/asset_group.py index e7eadcb3..6f1bce15 100644 --- a/src/deadline/client/cli/_groups/asset_group.py +++ b/src/deadline/client/cli/_groups/asset_group.py @@ -7,10 +7,17 @@ * diff * download """ +import os +from pathlib import Path import click -from .._common import _handle_error +from deadline.client import api +from deadline.job_attachments.upload import S3AssetManager, S3AssetUploader +from deadline.job_attachments.models import JobAttachmentS3Settings + +from .._common import _handle_error, _ProgressBarCallbackManager +from ...exceptions import NonValidInputError @click.group(name="asset") @@ -22,8 +29,10 @@ def cli_asset(): @cli_asset.command(name="snapshot") -@click.option("--root-dir", help="The root directory to snapshot. ") -@click.option("--manifest-out", help="Destination path to directory where manifest is created. ") +@click.option("--root-dir", required=True, help="The root directory to snapshot. ") +@click.option( + "--manifest-out", default=None, help="Destination path to directory where manifest is created. " +) @click.option( "--recursive", "-r", @@ -33,11 +42,64 @@ def cli_asset(): default=False, ) @_handle_error -def asset_snapshot(**args): +def asset_snapshot(root_dir: str, manifest_out: str, recursive: bool, **args): """ Creates manifest of files specified root directory. 
""" - click.echo("snapshot taken") + if not os.path.isdir(root_dir): + raise NonValidInputError(f"Specified root directory {root_dir} does not exist. ") + + if manifest_out and not os.path.isdir(manifest_out): + raise NonValidInputError(f"Specified destination directory {manifest_out} does not exist. ") + elif manifest_out is None: + manifest_out = root_dir + click.echo(f"Manifest creation path defaulted to {root_dir} \n") + + inputs = [] + for root, dirs, files in os.walk(root_dir): + inputs.extend([str(os.path.join(root, file)) for file in files]) + if not recursive: + break + + # Placeholder Asset Manager + asset_manager = S3AssetManager( + farm_id=" ", queue_id=" ", job_attachment_settings=JobAttachmentS3Settings(" ", " ") + ) + asset_uploader = S3AssetUploader() + hash_callback_manager = _ProgressBarCallbackManager(length=100, label="Hashing Attachments") + + upload_group = asset_manager.prepare_paths_for_upload( + input_paths=inputs, output_paths=[root_dir], referenced_paths=[] + ) + if upload_group.asset_groups: + _, manifests = api.hash_attachments( + asset_manager=asset_manager, + asset_groups=upload_group.asset_groups, + total_input_files=upload_group.total_input_files, + total_input_bytes=upload_group.total_input_bytes, + print_function_callback=click.echo, + hashing_progress_callback=hash_callback_manager.callback, + ) + + # Write created manifest into local file, at the specified location at manifest_out + for asset_root_manifests in manifests: + if asset_root_manifests.asset_manifest is None: + continue + source_root = Path(asset_root_manifests.root_path) + file_system_location_name = asset_root_manifests.file_system_location_name + (_, _, manifest_name) = asset_uploader._gather_upload_metadata( + manifest=asset_root_manifests.asset_manifest, + source_root=source_root, + file_system_location_name=file_system_location_name, + ) + asset_uploader._write_local_input_manifest( + manifest_write_dir=manifest_out, + manifest_name=manifest_name, + 
manifest=asset_root_manifests.asset_manifest, + root_dir_name=os.path.basename(root_dir), + ) + + click.echo(f"Manifest created at {manifest_out}\n") @cli_asset.command(name="upload") diff --git a/src/deadline/client/cli/_groups/bundle_group.py b/src/deadline/client/cli/_groups/bundle_group.py index 13648bec..c15b6be8 100644 --- a/src/deadline/client/cli/_groups/bundle_group.py +++ b/src/deadline/client/cli/_groups/bundle_group.py @@ -9,7 +9,6 @@ import re import click -from contextlib import ExitStack from botocore.exceptions import ClientError from deadline.client import api @@ -20,11 +19,10 @@ MisconfiguredInputsError, ) from deadline.job_attachments.models import AssetUploadGroup, JobAttachmentsFileSystem -from deadline.job_attachments.progress_tracker import ProgressReportMetadata from deadline.job_attachments._utils import _human_readable_file_size from ...exceptions import DeadlineOperationError, CreateJobWaiterCanceled -from .._common import _apply_cli_options_to_config, _handle_error +from .._common import _apply_cli_options_to_config, _handle_error, _ProgressBarCallbackManager from ._sigint_handler import SigIntHandler logger = logging.getLogger(__name__) @@ -282,42 +280,3 @@ def bundle_gui_submit(job_bundle_dir, browse, **args): click.echo(f"Job ID: {response['jobId']}") else: click.echo("Job submission canceled.") - - -class _ProgressBarCallbackManager: - """ - Manages creation, update, and deletion of a progress bar. On first call of the callback, the progress bar is created. 
The progress bar is closed - on the final call (100% completion) - """ - - BAR_NOT_CREATED = 0 - BAR_CREATED = 1 - BAR_CLOSED = 2 - - def __init__(self, length: int, label: str): - self._length = length - self._label = label - self._bar_status = self.BAR_NOT_CREATED - self._exit_stack = ExitStack() - - def callback(self, upload_metadata: ProgressReportMetadata) -> bool: - if self._bar_status == self.BAR_CLOSED: - # from multithreaded execution this can be called after completion somtimes. - return sigint_handler.continue_operation - elif self._bar_status == self.BAR_NOT_CREATED: - # Note: click doesn't export the return type of progressbar(), so we suppress mypy warnings for - # not annotating the type of hashing_progress. - self._upload_progress = click.progressbar(length=self._length, label=self._label) # type: ignore[var-annotated] - self._exit_stack.enter_context(self._upload_progress) - self._bar_status = self.BAR_CREATED - - total_progress = int(upload_metadata.progress) - new_progress = total_progress - self._upload_progress.pos - if new_progress > 0: - self._upload_progress.update(new_progress) - - if total_progress == self._length or not sigint_handler.continue_operation: - self._bar_status = self.BAR_CLOSED - self._exit_stack.close() - - return sigint_handler.continue_operation diff --git a/src/deadline/job_attachments/upload.py b/src/deadline/job_attachments/upload.py index 3d520612..7c0fffb6 100644 --- a/src/deadline/job_attachments/upload.py +++ b/src/deadline/job_attachments/upload.py @@ -164,12 +164,9 @@ def upload_assets( """ # Upload asset manifest - hash_alg = manifest.get_default_hash_alg() - manifest_bytes = manifest.encode().encode("utf-8") - manifest_name_prefix = hash_data( - f"{file_system_location_name or ''}{str(source_root)}".encode(), hash_alg + (hash_alg, manifest_bytes, manifest_name) = self._gather_upload_metadata( + manifest, source_root, file_system_location_name ) - manifest_name = f"{manifest_name_prefix}_input" if 
partial_manifest_prefix: partial_manifest_key = _join_s3_paths(partial_manifest_prefix, manifest_name) @@ -203,25 +200,70 @@ def upload_assets( return (partial_manifest_key, hash_data(manifest_bytes, hash_alg)) + def _gather_upload_metadata( + self, + manifest: BaseAssetManifest, + source_root: Path, + file_system_location_name: Optional[str] = None, + ) -> tuple[HashAlgorithm, bytes, str]: + """ + Gathers metadata information of manifest to be used for writing the local manifest + """ + hash_alg = manifest.get_default_hash_alg() + manifest_bytes = manifest.encode().encode("utf-8") + manifest_name_prefix = hash_data( + f"{file_system_location_name or ''}{str(source_root)}".encode(), hash_alg + ) + manifest_name = f"{manifest_name_prefix}_input" + + return (hash_alg, manifest_bytes, manifest_name) + def _write_local_manifest( self, manifest_write_dir: str, manifest_name: str, full_manifest_key: str, manifest: BaseAssetManifest, + root_dir_name: Optional[str] = None, ) -> None: """ Writes a manifest file locally in a 'manifests' sub-directory. Also creates/appends to a file mapping the local manifest name to the full S3 key in the same directory. 
""" - local_manifest_file = Path(manifest_write_dir, "manifests", manifest_name) + self._write_local_input_manifest(manifest_write_dir, manifest_name, manifest, root_dir_name) + + self._write_local_manifest_s3_mapping(manifest_write_dir, manifest_name, full_manifest_key) + + def _write_local_input_manifest( + self, + manifest_write_dir: str, + manifest_name: str, + manifest: BaseAssetManifest, + root_dir_name: Optional[str], + ): + """ + Creates 'manifests' sub-directory and writes a local input manifest file + """ + input_manifest_folder_name = "manifests" + if root_dir_name is not None: + input_manifest_folder_name = root_dir_name + "_" + input_manifest_folder_name + + local_manifest_file = Path(manifest_write_dir, input_manifest_folder_name, manifest_name) logger.info(f"Creating local manifest file: {local_manifest_file}\n") local_manifest_file.parent.mkdir(parents=True, exist_ok=True) with open(local_manifest_file, "w") as file: file.write(manifest.encode()) - # Create or append to an existing mapping file. We use this since path lengths can go beyond the - # file name length limit on Windows if we were to create the full S3 key path locally. + def _write_local_manifest_s3_mapping( + self, + manifest_write_dir: str, + manifest_name: str, + full_manifest_key: str, + ): + """ + Create or append to an existing mapping file. We use this since path lengths can go beyond the + file name length limit on Windows if we were to create the full S3 key path locally. + """ manifest_map_file = Path(manifest_write_dir, "manifests", "manifest_s3_mapping") mapping = {"local_file": manifest_name, "s3_key": full_manifest_key} with open(manifest_map_file, "a") as mapping_file: diff --git a/test/integ/cli/test_cli_asset.py b/test/integ/cli/test_cli_asset.py new file mode 100644 index 00000000..516d5e08 --- /dev/null +++ b/test/integ/cli/test_cli_asset.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +""" +Integ tests for the CLI asset commands. +""" +import os +import json +from click.testing import CliRunner +import pytest +import tempfile +import posixpath +from deadline.job_attachments.asset_manifests.hash_algorithms import hash_file, HashAlgorithm + +from deadline.client.cli import main + + +TEST_FILE_CONTENT = "test file content" +TEST_SUB_DIR_FILE_CONTENT = "subdir file content" +TEST_ROOT_DIR_FILE_CONTENT = "root file content" + +TEST_ROOT_FILE = "root_file.txt" +TEST_SUB_FILE = "subdir_file.txt" + +TEST_ROOT_DIR = "root_dir" +TEST_SUB_DIR_1 = "subdir1" +TEST_SUB_DIR_2 = "subdir2" + + +class TestSnapshot: + + @pytest.fixture + def temp_dir(self): + with tempfile.TemporaryDirectory() as tmpdir_path: + yield tmpdir_path + + def test_root_dir_basic(self, temp_dir): + """ + Snapshot with a valid root directory containing one file, and no other parameters + """ + root_dir = os.path.join(temp_dir, TEST_ROOT_DIR) + os.makedirs(root_dir) + file_path = os.path.join(root_dir, TEST_ROOT_FILE) + with open(file_path, "w") as f: + f.write(TEST_FILE_CONTENT) + + runner = CliRunner() + runner.invoke(main, ["asset", "snapshot", "--root-dir", root_dir]) + + # Check manifest file details to match correct content + # since manifest file name is hashed depending on source location, we have to list out manifest + manifest_folder_path = os.path.join(root_dir, f"{os.path.basename(root_dir)}_manifests") + manifest_files = os.listdir(manifest_folder_path) + assert ( + len(manifest_files) == 1 + ), f"Expected exactly one manifest file, but got {len(manifest_files)}" + + manifest_file_name = manifest_files[0] + manifest_file_path = os.path.join(manifest_folder_path, manifest_file_name) + + with open(manifest_file_path, "r") as f: + manifest_data = json.load(f) + + expected_hash = hash_file(file_path, HashAlgorithm()) # hashed with xxh128 + manifest_data_paths = manifest_data["paths"] + assert ( + len(manifest_data_paths) == 1 + ), f"Expected exactly one path inside manifest, but 
got {len(manifest_data_paths)}" + assert manifest_data_paths[0]["path"] == TEST_ROOT_FILE + assert manifest_data_paths[0]["hash"] == expected_hash + + def test_root_dir_not_recursive(self, temp_dir): + """ + Snapshot with valid root directory with subdirectory and multiple files, but doesn't recursively snapshot. + """ + root_dir = os.path.join(temp_dir, TEST_ROOT_DIR) + + # Create a file in the root directory + root_file_path = os.path.join(root_dir, TEST_ROOT_FILE) + os.makedirs(os.path.dirname(root_file_path), exist_ok=True) + with open(root_file_path, "w") as f: + f.write(TEST_ROOT_DIR_FILE_CONTENT) + + # Create a file in the subdirectory (should not be included) + subdir_file_path = os.path.join(root_dir, TEST_SUB_DIR_1, TEST_SUB_DIR_2, TEST_SUB_FILE) + os.makedirs(os.path.dirname(subdir_file_path), exist_ok=True) + with open(subdir_file_path, "w") as f: + f.write(TEST_SUB_DIR_FILE_CONTENT) + + runner = CliRunner() + runner.invoke(main, ["asset", "snapshot", "--root-dir", root_dir]) + + # Check manifest file details to match correct content + manifest_folder_path = os.path.join(root_dir, f"{os.path.basename(root_dir)}_manifests") + manifest_files = os.listdir(manifest_folder_path) + assert ( + len(manifest_files) == 1 + ), f"Expected exactly one manifest file, but got {len(manifest_files)}" + + manifest_file_name = manifest_files[0] + manifest_file_path = os.path.join(manifest_folder_path, manifest_file_name) + + with open(manifest_file_path, "r") as f: + manifest_data = json.load(f) + + # should ignore subdirectories + expected_hash = hash_file(root_file_path, HashAlgorithm()) # hashed with xxh128 + manifest_data_paths = manifest_data["paths"] + assert ( + len(manifest_data_paths) == 1 + ), f"Expected exactly one path inside manifest, but got {len(manifest_data_paths)}" + assert manifest_data_paths[0]["path"] == TEST_ROOT_FILE + assert manifest_data_paths[0]["hash"] == expected_hash + + def test_root_dir_recursive(self, temp_dir): + """ + Snapshot with valid 
root directory with subdirectory and multiple files, and recursively snapshots files. + """ + root_dir = os.path.join(temp_dir, TEST_ROOT_DIR) + + # Create a file in the root directory + root_file_path = os.path.join(root_dir, TEST_ROOT_DIR) + with open(root_file_path, "w") as f: + f.write(TEST_ROOT_DIR_FILE_CONTENT) + + # Create a file in the subdirectory + subdir_file_path = os.path.join(root_dir, TEST_SUB_DIR_1, TEST_SUB_DIR_2, TEST_SUB_FILE) + os.makedirs(os.path.dirname(subdir_file_path), exist_ok=True) + with open(subdir_file_path, "w") as f: + f.write(TEST_SUB_DIR_FILE_CONTENT) + + runner = CliRunner() + runner.invoke(main, ["asset", "snapshot", "--root-dir", root_dir, "--recursive"]) + + # Check manifest file details to match correct content + # since manifest file name is hashed depending on source location, we have to list out manifest + manifest_folder_path = os.path.join(root_dir, f"{os.path.basename(root_dir)}_manifests") + manifest_files = os.listdir(manifest_folder_path) + assert ( + len(manifest_files) == 1 + ), f"Expected exactly one manifest file, but got {len(manifest_files)}" + + root_manifest_file_name = [file for file in manifest_files][0] + root_manifest_file_path = os.path.join(manifest_folder_path, root_manifest_file_name) + + with open(root_manifest_file_path, "r") as f: + manifest_data = json.load(f) + + root_file_hash = hash_file(root_file_path, HashAlgorithm()) # hashed with xxh128 + subdir_file_hash = hash_file(subdir_file_path, HashAlgorithm()) # hashed with xxh128 + manifest_data_paths = manifest_data["paths"] + assert ( + len(manifest_data_paths) == 2 + ), f"Expected exactly 2 paths inside manifest, but got {len(manifest_data_paths)}" + assert manifest_data_paths[0]["path"] == TEST_ROOT_FILE + assert manifest_data_paths[0]["hash"] == root_file_hash + assert manifest_data_paths[1]["path"] == posixpath.join( + TEST_SUB_DIR_1, TEST_SUB_DIR_2, TEST_SUB_FILE + ) + assert manifest_data_paths[1]["hash"] == subdir_file_hash + + def 
test_specified_manifest_out(self, temp_dir): + """ + Snapshot with valid root directory, checks if manifest is created in the specified --manifest-out location + """ + root_dir = os.path.join(temp_dir, TEST_ROOT_DIR) + os.makedirs(root_dir) + manifest_out_dir = os.path.join(temp_dir, "manifest_out") + os.makedirs(manifest_out_dir) + file_path = os.path.join(root_dir, TEST_ROOT_FILE) + with open(file_path, "w") as f: + f.write(TEST_FILE_CONTENT) + + runner = CliRunner() + runner.invoke( + main, ["asset", "snapshot", "--root-dir", root_dir, "--manifest-out", manifest_out_dir] + ) + + # Check manifest file details to match correct content + # since manifest file name is hashed depending on source location, we have to list out manifest + manifest_folder_path = os.path.join( + manifest_out_dir, f"{os.path.basename(root_dir)}_manifests" + ) + manifest_files = os.listdir(manifest_folder_path) + assert ( + len(manifest_files) == 1 + ), f"Expected exactly one manifest file, but got {len(manifest_files)}" + + manifest_file_name = manifest_files[0] + + manifest_file_path = os.path.join(manifest_folder_path, manifest_file_name) + + with open(manifest_file_path, "r") as f: + manifest_data = json.load(f) + + expected_hash = hash_file(file_path, HashAlgorithm()) # hashed with xxh128 + manifest_data_paths = manifest_data["paths"] + assert ( + len(manifest_data_paths) == 1 + ), f"Expected exactly one path inside manifest, but got {len(manifest_data_paths)}" + assert manifest_data_paths[0]["path"] == TEST_ROOT_FILE + assert manifest_data_paths[0]["hash"] == expected_hash diff --git a/test/unit/deadline_client/api/test_job_bundle_submission.py b/test/unit/deadline_client/api/test_job_bundle_submission.py index 36d77120..de245c90 100644 --- a/test/unit/deadline_client/api/test_job_bundle_submission.py +++ b/test/unit/deadline_client/api/test_job_bundle_submission.py @@ -479,7 +479,7 @@ def test_create_job_from_job_bundle_job_attachments( ) as client_mock, patch.object( 
_submit_job_bundle.api, "get_queue_user_boto3_session" ), patch.object( - api._submit_job_bundle, "_hash_attachments", return_value=(None, None) + api._submit_job_bundle, "hash_attachments", return_value=(None, None) ) as mock_hash_attachments, patch.object( S3AssetManager, "prepare_paths_for_upload", @@ -601,7 +601,7 @@ def test_create_job_from_job_bundle_empty_job_attachments( ) as client_mock, patch.object( _submit_job_bundle.api, "get_queue_user_boto3_session" ), patch.object( - api._submit_job_bundle, "_hash_attachments", return_value=(None, None) + api._submit_job_bundle, "hash_attachments", return_value=(None, None) ) as mock_hash_attachments, patch.object( S3AssetManager, "prepare_paths_for_upload", @@ -923,7 +923,7 @@ def test_create_job_from_job_bundle_with_single_asset_file( ) as client_mock, patch.object( _submit_job_bundle.api, "get_queue_user_boto3_session" ), patch.object( - api._submit_job_bundle, "_hash_attachments", return_value=(None, None) + api._submit_job_bundle, "hash_attachments", return_value=(None, None) ) as mock_hash_attachments, patch.object( S3AssetManager, "prepare_paths_for_upload", diff --git a/test/unit/deadline_client/cli/test_cli_asset.py b/test/unit/deadline_client/cli/test_cli_asset.py new file mode 100644 index 00000000..9f4e0626 --- /dev/null +++ b/test/unit/deadline_client/cli/test_cli_asset.py @@ -0,0 +1,157 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +import pytest +from unittest.mock import patch, Mock +from click.testing import CliRunner + +from deadline.client.cli import main +from deadline.client import api +from deadline.job_attachments.upload import S3AssetManager +from deadline.job_attachments.models import AssetRootGroup + + +@pytest.fixture +def mock_prepare_paths_for_upload(): + with patch.object(S3AssetManager, "prepare_paths_for_upload") as mock: + yield mock + + +@pytest.fixture +def mock_hash_attachments(): + with patch.object(api, "hash_attachments", return_value=(Mock(), [])) as mock: + yield mock + + +@pytest.fixture +def asset_group_mock(tmp_path): + root_dir = str(tmp_path) + return AssetRootGroup( + root_path=root_dir, + inputs=set(), + outputs=set(), + references=set(), + ) + + +@pytest.fixture +def upload_group_mock(asset_group_mock): + return Mock( + asset_groups=[asset_group_mock], + total_input_files=1, + total_input_bytes=100, + ) + + +class TestSnapshot: + + def test_snapshot_root_directory_only( + self, tmp_path, mock_prepare_paths_for_upload, mock_hash_attachments, upload_group_mock + ): + """ + Tests if CLI snapshot command calls correctly with an existing directory path at --root-dir + """ + root_dir = str(tmp_path) + + temp_file = tmp_path / "temp_file.txt" + temp_file.touch() + + mock_prepare_paths_for_upload.return_value = upload_group_mock + + runner = CliRunner() + result = runner.invoke(main, ["asset", "snapshot", "--root-dir", root_dir]) + + assert result.exit_code == 0 + mock_prepare_paths_for_upload.assert_called_once_with( + input_paths=[str(temp_file)], output_paths=[root_dir], referenced_paths=[] + ) + mock_hash_attachments.assert_called_once() + + def test_invalid_root_directory(self, tmp_path): + """ + Tests if CLI snapshot raises error when called with an invalid --root-dir with non-existing directory path + """ + invalid_root_dir = str(tmp_path / "invalid_dir") + + runner = CliRunner() + result = runner.invoke(main, ["asset", "snapshot", "--root-dir", 
invalid_root_dir]) + + assert result.exit_code != 0 + assert f"Specified root directory {invalid_root_dir} does not exist. " in result.output + + def test_valid_manifest_out( + self, tmp_path, mock_prepare_paths_for_upload, mock_hash_attachments, upload_group_mock + ): + """ + Tests if CLI snapshot command correctly calls with --manifest-out argument + """ + root_dir = str(tmp_path) + manifest_out_dir = tmp_path / "manifest_out" + manifest_out_dir.mkdir() + + temp_file = tmp_path / "temp_file.txt" + temp_file.touch() + + mock_prepare_paths_for_upload.return_value = upload_group_mock + + runner = CliRunner() + result = runner.invoke( + main, + [ + "asset", + "snapshot", + "--root-dir", + root_dir, + "--manifest-out", + str(manifest_out_dir), + ], + ) + + assert result.exit_code == 0 + mock_prepare_paths_for_upload.assert_called_once_with( + input_paths=[str(temp_file)], output_paths=[root_dir], referenced_paths=[] + ) + mock_hash_attachments.assert_called_once() + + def test_invalid_manifest_out(self, tmp_path): + """ + Tests if CLI snapshot raises error when called with invalid --manifest-out with non-existing directory path + """ + root_dir = str(tmp_path) + invalid_manifest_out = str(tmp_path / "nonexistent_dir") + + runner = CliRunner() + result = runner.invoke( + main, + ["asset", "snapshot", "--root-dir", root_dir, "--manifest-out", invalid_manifest_out], + ) + + assert result.exit_code != 0 + assert ( + f"Specified destination directory {invalid_manifest_out} does not exist. 
" + in result.output + ) + + def test_asset_snapshot_recursive( + self, tmp_path, mock_prepare_paths_for_upload, mock_hash_attachments, upload_group_mock + ): + """ + Tests if CLI snapshot --recursive flag is called correctly + """ + root_dir = str(tmp_path) + subdir1 = tmp_path / "subdir1" + subdir2 = tmp_path / "subdir2" + subdir1.mkdir() + subdir2.mkdir() + (subdir1 / "file1.txt").touch() + (subdir2 / "file2.txt").touch() + + expected_inputs = {str(subdir2 / "file2.txt"), str(subdir1 / "file1.txt")} + mock_prepare_paths_for_upload.return_value = upload_group_mock + + runner = CliRunner() + result = runner.invoke(main, ["asset", "snapshot", "--root-dir", root_dir, "--recursive"]) + + assert result.exit_code == 0 + actual_inputs = set(mock_prepare_paths_for_upload.call_args[1]["input_paths"]) + assert actual_inputs == expected_inputs + mock_hash_attachments.assert_called_once() diff --git a/test/unit/deadline_client/cli/test_cli_bundle.py b/test/unit/deadline_client/cli/test_cli_bundle.py index 970c4455..b59a14dd 100644 --- a/test/unit/deadline_client/cli/test_cli_bundle.py +++ b/test/unit/deadline_client/cli/test_cli_bundle.py @@ -118,7 +118,7 @@ def test_cli_bundle_submit(fresh_deadline_config, temp_job_bundle_dir): ) as get_boto3_client_mock, patch.object( _queue_parameters, "get_boto3_client" ) as qp_boto3_client_mock, patch.object( - _submit_job_bundle, "_hash_attachments", return_value=[] + _submit_job_bundle, "hash_attachments", return_value=[] ), patch.object( _submit_job_bundle.api, "get_queue_user_boto3_session" ), patch.object( @@ -339,7 +339,7 @@ def test_cli_bundle_asset_load_method(fresh_deadline_config, temp_job_bundle_dir ) as bundle_boto3_client_mock, patch.object( _queue_parameters, "get_boto3_client" ) as qp_boto3_client_mock, patch.object( - _submit_job_bundle, "_hash_attachments", return_value=(attachment_mock, {}) + _submit_job_bundle, "hash_attachments", return_value=(attachment_mock, {}) ), patch.object( _submit_job_bundle, 
"_upload_attachments", return_value={} ), patch.object( @@ -633,7 +633,7 @@ def test_cli_bundle_accept_upload_confirmation(fresh_deadline_config, temp_job_b with patch.object( _submit_job_bundle.api, "get_boto3_client" ) as get_boto3_client_mock, patch.object( - _submit_job_bundle, "_hash_attachments", return_value=[SummaryStatistics(), "test"] + _submit_job_bundle, "hash_attachments", return_value=[SummaryStatistics(), "test"] ), patch.object( _submit_job_bundle, "_upload_attachments" ), patch.object( @@ -711,7 +711,7 @@ def test_cli_bundle_reject_upload_confirmation(fresh_deadline_config, temp_job_b ) as get_boto3_client_mock, patch.object( _queue_parameters, "get_boto3_client" ) as qp_boto3_client_mock, patch.object( - _submit_job_bundle, "_hash_attachments", return_value=[SummaryStatistics(), "test"] + _submit_job_bundle, "hash_attachments", return_value=[SummaryStatistics(), "test"] ), patch.object( _submit_job_bundle, "_upload_attachments" ) as upload_attachments_mock, patch.object(