feat(asset-cli): asset download subcommand
Signed-off-by: Tang <[email protected]>
stangch committed Jul 30, 2024
1 parent aa1787d commit 7e75eee
Showing 4 changed files with 298 additions and 46 deletions.
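For context, the new subcommand can be exercised with Click's test runner. The following is a minimal sketch, not part of the commit: it assumes the cli_asset group defined in asset_group.py (see the decorator in the diff below) is importable as shown, and the farm, queue, and job IDs and output directory are placeholders.

from click.testing import CliRunner

from deadline.client.cli._groups.asset_group import cli_asset

# Invoke the new "download" subcommand with placeholder IDs.
runner = CliRunner()
result = runner.invoke(
    cli_asset,
    [
        "download",
        "--farm-id", "farm-0123456789abcdef",
        "--queue-id", "queue-0123456789abcdef",
        "--job-id", "job-0123456789abcdef",
        "--manifest-out", "/tmp/manifests",
    ],
)
print(result.exit_code, result.output)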
85 changes: 77 additions & 8 deletions src/deadline/client/cli/_groups/asset_group.py
@@ -20,13 +20,15 @@

from deadline.client import api
from deadline.job_attachments.upload import FileStatus, S3AssetManager, S3AssetUploader
from deadline.job_attachments.models import (
JobAttachmentS3Settings,
AssetRootManifest,
)
from deadline.job_attachments.download import download_file_with_s3_key
from deadline.job_attachments.models import JobAttachmentS3Settings, AssetRootManifest
from deadline.job_attachments.asset_manifests.decode import decode_manifest
from deadline.job_attachments.asset_manifests.base_manifest import BaseAssetManifest
from deadline.job_attachments.caches import HashCache
from deadline.job_attachments._aws.aws_clients import (
get_s3_client,
get_s3_transfer_manager,
)

from .._common import _apply_cli_options_to_config, _handle_error, _ProgressBarCallbackManager
from ...exceptions import NonValidInputError, ManifestOutdatedError
@@ -300,15 +302,82 @@ def asset_diff(root_dir: str, manifest_dir: str, raw: bool, **args):


@cli_asset.command(name="download")
@click.option("--farm-id", help="The AWS Deadline Cloud Farm to use.")
@click.option("--queue-id", help="The AWS Deadline Cloud Queue to use.")
@click.option("--farm-id", help="The AWS Deadline Cloud Farm to use. ")
@click.option("--queue-id", help="The AWS Deadline Cloud Queue to use. ")
@click.option("--job-id", help="The AWS Deadline Cloud Job to get. ")
@click.option(
"--manifest-out",
required=True,
help="Destination path to directory where manifest is created. ",
)
@_handle_error
def asset_download(**args):
def asset_download(manifest_out: str, **args):
"""
Downloads the input manifests of a previously submitted job.
"""
click.echo("download complete")
if not os.path.isdir(manifest_out):
raise NonValidInputError(f"Specified destination directory {manifest_out} does not exist. ")

# setup config
config = _apply_cli_options_to_config(
required_options={"farm_id", "queue_id", "job_id"}, **args
)
deadline: BaseClient = api.get_boto3_client("deadline", config=config)
queue_id: str = get_setting("defaults.queue_id", config=config)
farm_id: str = get_setting("defaults.farm_id", config=config)
job_id: str = get_setting("defaults.job_id", config=config)

queue: dict = deadline.get_queue(
farmId=farm_id,
queueId=queue_id,
)

# assume queue role - session permissions
queue_role_session: boto3.Session = api.get_queue_user_boto3_session(
deadline=deadline,
config=config,
farm_id=farm_id,
queue_id=queue_id,
)

# get input_manifest_paths from Deadline GetJob API
job: dict = deadline.get_job(farmId=farm_id, queueId=queue_id, jobId=job_id)
attachments: dict = job["attachments"]
input_manifest_paths: list[str] = [
manifest["inputManifestPath"] for manifest in attachments["manifests"]
]

# get s3BucketName from Deadline GetQueue API
bucket_name: str = queue["jobAttachmentSettings"]["s3BucketName"]

# get S3 prefix for job attachment manifests
s3_prefix = queue["jobAttachmentSettings"]["rootPrefix"] + "/Manifests/"

s3_client: BaseClient = get_s3_client(session=queue_role_session)
transfer_manager = get_s3_transfer_manager(s3_client=s3_client)

# download each input manifest; index the local file name so manifests from multiple asset roots do not overwrite each other
for index, input_manifest_path in enumerate(input_manifest_paths):
local_file_name = Path(manifest_out, f"{job_id}_manifest_{index}")

result = download_file_with_s3_key(
s3_bucket=bucket_name,
s3_key=s3_prefix + input_manifest_path,
local_file_name=local_file_name,
session=queue_role_session,
transfer_manager=transfer_manager,
)

if result is not None:
transfer_path = result.meta.call_args.fileobj # type: ignore[attr-defined]
file_size = result.meta.size # type: ignore[attr-defined]
click.echo(
f"\nDownloaded file to '{transfer_path}' ({file_size} bytes)\nWith S3 key: '{input_manifest_path}'. "
)
else:
click.echo(
f"\nFailed to download file with S3 key '{input_manifest_path}' from bucket '{bucket_name}'"
)


def read_local_manifest(manifest: str) -> BaseAssetManifest:
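Once a manifest has been downloaded, it can be decoded with the helpers this file already imports (decode_manifest, BaseAssetManifest). This is a hedged sketch: the local path is a placeholder, and the paths/path/hash attribute names follow the manifest interface used elsewhere in this changeset.

from deadline.job_attachments.asset_manifests.base_manifest import BaseAssetManifest
from deadline.job_attachments.asset_manifests.decode import decode_manifest

# Placeholder path to a manifest file written by "asset download".
manifest_file = "/tmp/manifests/job-0123456789abcdef_manifest_0"

with open(manifest_file) as f:
    manifest: BaseAssetManifest = decode_manifest(f.read())

# List the files the job declared as inputs.
for entry in manifest.paths:
    print(entry.path, entry.hash)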
113 changes: 79 additions & 34 deletions src/deadline/job_attachments/download.py
Expand Up @@ -372,48 +372,25 @@ def download_files_in_directory(
)


def download_file(
file: RelativeFilePath,
hash_algorithm: HashAlgorithm,
local_download_dir: str,
def download_file_with_s3_key(
s3_bucket: str,
cas_prefix: Optional[str],
s3_client: Optional[BaseClient] = None,
s3_key: str,
transfer_manager,
local_file_name: Path,
file_bytes: Optional[int] = 0,
session: Optional[boto3.Session] = None,
modified_time_override: Optional[float] = None,
progress_tracker: Optional[ProgressTracker] = None,
file_conflict_resolution: Optional[FileConflictResolution] = FileConflictResolution.CREATE_COPY,
) -> Tuple[int, Optional[Path]]:
) -> concurrent.futures.Future | None:
"""
Downloads a file from the S3 bucket to the local directory. `modified_time_override` is ignored if the manifest
version used supports timestamps.
Returns a tuple of (size in bytes, filename) of the downloaded file.
- The file size of 0 means that this file comes from a manifest version that does not provide file sizes.
- The filename of None indicates that this file has been skipped or has not been downloaded.
Helper to download a file from the S3 bucket with a specified S3 key to the local directory.
Returns the Future for the asynchronous download, or None if the download was skipped.
"""
if not s3_client:
s3_client = get_s3_client(session=session)

transfer_manager = get_s3_transfer_manager(s3_client=s3_client)

# The modified time in the manifest is in microseconds, but utime requires the time be expressed in seconds.
modified_time_override = file.mtime / 1000000 # type: ignore[attr-defined]

file_bytes = file.size

# Python will handle the path separator '/' correctly on every platform.
local_file_name = Path(local_download_dir).joinpath(file.path)

s3_key = (
f"{cas_prefix}/{file.hash}.{hash_algorithm.value}"
if cas_prefix
else f"{file.hash}.{hash_algorithm.value}"
)

# If the file name already exists, resolve the conflict based on the file_conflict_resolution
if local_file_name.is_file():
if file_conflict_resolution == FileConflictResolution.SKIP:
return (file_bytes, None)
return None
elif file_conflict_resolution == FileConflictResolution.OVERWRITE:
pass
elif file_conflict_resolution == FileConflictResolution.CREATE_COPY:
@@ -432,6 +409,7 @@

future: concurrent.futures.Future

# provides progress callback for asynchronous download from S3
def handler(bytes_downloaded):
nonlocal progress_tracker
nonlocal future
@@ -450,9 +428,63 @@ def handler(bytes_downloaded):
extra_args={"ExpectedBucketOwner": get_account_id(session=session)},
subscribers=subscribers,
)
future.result()
return future


def download_file(
file: RelativeFilePath,
hash_algorithm: HashAlgorithm,
local_download_dir: str,
s3_bucket: str,
cas_prefix: Optional[str],
s3_client: Optional[BaseClient] = None,
session: Optional[boto3.Session] = None,
modified_time_override: Optional[float] = None,
progress_tracker: Optional[ProgressTracker] = None,
file_conflict_resolution: Optional[FileConflictResolution] = FileConflictResolution.CREATE_COPY,
) -> Tuple[int, Optional[Path]]:
"""
Downloads a file from the S3 bucket to the local directory. `modified_time_override` is ignored if the manifest
version used supports timestamps.
Returns a tuple of (size in bytes, filename) of the downloaded file.
- The file size of 0 means that this file comes from a manifest version that does not provide file sizes.
- The filename of None indicates that this file has been skipped or has not been downloaded.
"""
if not s3_client:
s3_client = get_s3_client(session=session)

transfer_manager = get_s3_transfer_manager(s3_client=s3_client)

file_bytes = file.size

# The modified time in the manifest is in microseconds, but utime requires the time be expressed in seconds.
modified_time_override = file.mtime / 1000000 # type: ignore[attr-defined]

# Python will handle the path separator '/' correctly on every platform.
local_file_name = Path(local_download_dir).joinpath(file.path)

s3_key = (
f"{cas_prefix}/{file.hash}.{hash_algorithm.value}"
if cas_prefix
else f"{file.hash}.{hash_algorithm.value}"
)

try:
future.result()
future = download_file_with_s3_key(
s3_bucket=s3_bucket,
s3_key=s3_key,
file_bytes=file_bytes,
transfer_manager=transfer_manager,
local_file_name=local_file_name,
session=session,
progress_tracker=progress_tracker,
file_conflict_resolution=file_conflict_resolution,
)

if future is None:
return (file_bytes, None)

except concurrent.futures.CancelledError as ce:
if progress_tracker and progress_tracker.continue_reporting is False:
raise AssetSyncCancelledError("File download cancelled.")
@@ -490,6 +522,18 @@ def process_client_error(exc: ClientError, status_code: int):
# TODO: Temporary to prevent breaking backwards-compatibility; if file not found, try again without hash alg postfix
status_code = int(exc.response["ResponseMetadata"]["HTTPStatusCode"])
if status_code == 404:

def handler(bytes_downloaded):
nonlocal progress_tracker
nonlocal future

if progress_tracker:
should_continue = progress_tracker.track_progress_callback(bytes_downloaded)
if not should_continue and future is not None:
future.cancel()

subscribers = [ProgressCallbackInvoker(handler)]

s3_key = s3_key.rsplit(".", 1)[0]
future = transfer_manager.download(
bucket=s3_bucket,
Expand All @@ -499,7 +543,8 @@ def process_client_error(exc: ClientError, status_code: int):
subscribers=subscribers,
)
try:
future.result()
if future is not None:
future.result()
except concurrent.futures.CancelledError as ce:
if progress_tracker and progress_tracker.continue_reporting is False:
raise AssetSyncCancelledError("File download cancelled.")
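For callers outside the CLI, the new helper can also be driven directly. A minimal sketch under stated assumptions: the bucket, key, and destination are placeholders, and credentials come from the default boto3 session (the CLI itself uses the queue-role session shown above).

from pathlib import Path

import boto3

from deadline.job_attachments._aws.aws_clients import get_s3_client, get_s3_transfer_manager
from deadline.job_attachments.download import download_file_with_s3_key

session = boto3.Session()
s3_client = get_s3_client(session=session)
transfer_manager = get_s3_transfer_manager(s3_client=s3_client)

# Placeholder bucket, key, and destination; the queue's rootPrefix + "/Manifests/" layout is assumed.
future = download_file_with_s3_key(
    s3_bucket="my-job-attachments-bucket",
    s3_key="DeadlineCloud/Manifests/example_input_manifest",
    transfer_manager=transfer_manager,
    local_file_name=Path("/tmp/manifests/example_input_manifest"),
    session=session,
)
if future is not None:
    future.result()  # block until the transfer completes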
5 changes: 4 additions & 1 deletion src/deadline/job_attachments/upload.py
Expand Up @@ -191,7 +191,10 @@ def upload_assets(

if manifest_write_dir:
self._write_local_manifest(
manifest_write_dir, manifest_name, full_manifest_key, manifest
manifest_write_dir,
manifest_name,
full_manifest_key,
manifest,
)

self.upload_bytes_to_s3(