From 03d63459cf90f42fa0cec32a26d5a348f45eefd1 Mon Sep 17 00:00:00 2001 From: Godot Bian <13778003+godobyte@users.noreply.github.com> Date: Thu, 14 Nov 2024 02:48:27 +0000 Subject: [PATCH 1/2] feat: integrate with job attachment download cli as an openjd action run Signed-off-by: Godot Bian <13778003+godobyte@users.noreply.github.com> --- src/deadline_worker_agent/api_models.py | 11 +- src/deadline_worker_agent/boto/shim.py | 32 +- src/deadline_worker_agent/feature_flag.py | 5 + .../scheduler/scheduler.py | 6 +- .../scheduler/session_queue.py | 203 ++++++++--- .../sessions/actions/__init__.py | 2 + .../actions/run_attachment_download.py | 338 ++++++++++++++++++ src/deadline_worker_agent/sessions/session.py | 19 +- .../actions/test_run_attachment_download.py | 177 +++++++++ 9 files changed, 729 insertions(+), 64 deletions(-) create mode 100644 src/deadline_worker_agent/feature_flag.py create mode 100644 src/deadline_worker_agent/sessions/actions/run_attachment_download.py create mode 100644 test/unit/sessions/actions/test_run_attachment_download.py diff --git a/src/deadline_worker_agent/api_models.py b/src/deadline_worker_agent/api_models.py index 8d169d76..baa11ec7 100644 --- a/src/deadline_worker_agent/api_models.py +++ b/src/deadline_worker_agent/api_models.py @@ -46,6 +46,7 @@ EnvironmentActionType = Literal["ENV_ENTER", "ENV_EXIT"] StepActionType = Literal["TASK_RUN"] # noqa SyncInputJobAttachmentsActionType = Literal["SYNC_INPUT_JOB_ATTACHMENTS"] # noqa +AttachmentDownloadActionType = Literal["SYNC_INPUT_JOB_ATTACHMENTS"] # noqa CompletedActionStatus = Literal["SUCCEEDED", "FAILED", "INTERRUPTED", "CANCELED", "NEVER_ATTEMPTED"] @@ -87,6 +88,12 @@ class SyncInputJobAttachmentsAction(TypedDict): stepId: NotRequired[str] +class AttachmentDownloadAction(TypedDict): + sessionActionId: str + actionType: AttachmentDownloadActionType + stepId: NotRequired[str] + + class LogConfiguration(TypedDict): error: NotRequired[str] logDriver: str @@ -97,7 +104,9 @@ class 
LogConfiguration(TypedDict): class AssignedSession(TypedDict): queueId: str jobId: str - sessionActions: list[EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction] + sessionActions: list[ + EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction | AttachmentDownloadAction + ] logConfiguration: NotRequired[LogConfiguration] diff --git a/src/deadline_worker_agent/boto/shim.py b/src/deadline_worker_agent/boto/shim.py index cd45b864..2439ea4e 100644 --- a/src/deadline_worker_agent/boto/shim.py +++ b/src/deadline_worker_agent/boto/shim.py @@ -10,6 +10,7 @@ from boto3 import Session as _Session +from ..feature_flag import ASSET_SYNC_JOB_USER_FEATURE from ..api_models import ( AssignedSession, AssumeFleetRoleForWorkerResponse, @@ -19,6 +20,7 @@ EnvironmentAction, HostProperties, SyncInputJobAttachmentsAction, + AttachmentDownloadAction, TaskRunAction, UpdatedSessionActionInfo, UpdateWorkerResponse, @@ -172,9 +174,26 @@ def parse_sync_input_job_attachments_action( mapped_action["stepId"] = step_id return mapped_action + def parse_attachment_download_action( + action: dict, action_id: str + ) -> AttachmentDownloadAction: + mapped_action = AttachmentDownloadAction( + sessionActionId=action_id, + actionType="SYNC_INPUT_JOB_ATTACHMENTS", + ) + if step_id := action.get("stepId", None): + mapped_action["stepId"] = step_id + return mapped_action + SESSION_ACTION_MAP: dict[ str, - Callable[[Any, str], EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction], + Callable[ + [Any, str], + EnvironmentAction + | TaskRunAction + | SyncInputJobAttachmentsAction + | AttachmentDownloadAction, + ], ] = { "envEnter": lambda action, action_id: EnvironmentAction( sessionActionId=action_id, @@ -187,14 +206,21 @@ def parse_sync_input_job_attachments_action( environmentId=action["environmentId"], ), "taskRun": parse_task_run_action, - "syncInputJobAttachments": parse_sync_input_job_attachments_action, + "syncInputJobAttachments": ( + 
parse_sync_input_job_attachments_action + if not ASSET_SYNC_JOB_USER_FEATURE + else parse_attachment_download_action + ), } # Map the new session action structure to our internal model mapped_sessions: dict[str, AssignedSession] = {} for session_id, session in response["assignedSessions"].items(): mapped_actions: list[ - EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction + EnvironmentAction + | TaskRunAction + | SyncInputJobAttachmentsAction + | AttachmentDownloadAction ] = [] for session_action in session["sessionActions"]: assert len(session_action["definition"].items()) == 1 diff --git a/src/deadline_worker_agent/feature_flag.py b/src/deadline_worker_agent/feature_flag.py new file mode 100644 index 00000000..00abfa7b --- /dev/null +++ b/src/deadline_worker_agent/feature_flag.py @@ -0,0 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import os + +ASSET_SYNC_JOB_USER_FEATURE = os.environ.get("ASSET_SYNC_JOB_USER_FEATURE") diff --git a/src/deadline_worker_agent/scheduler/scheduler.py b/src/deadline_worker_agent/scheduler/scheduler.py index 063ce069..b814324d 100644 --- a/src/deadline_worker_agent/scheduler/scheduler.py +++ b/src/deadline_worker_agent/scheduler/scheduler.py @@ -46,6 +46,7 @@ EnvironmentAction, TaskRunAction, SyncInputJobAttachmentsAction, + AttachmentDownloadAction, ) from ..aws.deadline import ( DeadlineRequestConditionallyRecoverableError, @@ -1362,7 +1363,10 @@ def _return_sessionactions_from_stopped_session( self, *, assigned_session_actions: list[ - EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction + EnvironmentAction + | TaskRunAction + | SyncInputJobAttachmentsAction + | AttachmentDownloadAction ], failure_message: str, ) -> None: diff --git a/src/deadline_worker_agent/scheduler/session_queue.py b/src/deadline_worker_agent/scheduler/session_queue.py index 0106d498..7b55d32c 100644 --- a/src/deadline_worker_agent/scheduler/session_queue.py +++ 
b/src/deadline_worker_agent/scheduler/session_queue.py @@ -11,9 +11,11 @@ from openjd.model import UnsupportedSchema from openjd.sessions import ActionState, ActionStatus +from ..feature_flag import ASSET_SYNC_JOB_USER_FEATURE from ..api_models import ( EnvironmentAction as EnvironmentActionApiModel, SyncInputJobAttachmentsAction as SyncInputJobAttachmentsActionApiModel, + AttachmentDownloadAction as AttachmentDownloadActionApiModel, TaskRunAction as TaskRunActionApiModel, EntityIdentifier, EnvironmentDetailsIdentifier, @@ -29,6 +31,7 @@ RunStepTaskAction, SessionActionDefinition, SyncInputJobAttachmentsAction, + AttachmentDownloadAction, ) from .session_action_status import SessionActionStatus from ..sessions.errors import ( @@ -44,7 +47,11 @@ from ..sessions.job_entities import JobEntities D = TypeVar( - "D", EnvironmentActionApiModel, TaskRunActionApiModel, SyncInputJobAttachmentsActionApiModel + "D", + EnvironmentActionApiModel, + TaskRunActionApiModel, + SyncInputJobAttachmentsActionApiModel, + AttachmentDownloadActionApiModel, ) else: D = TypeVar("D") @@ -71,6 +78,10 @@ class SessionActionQueueEntry(Generic[D]): SyncInputJobAttachmentsStepDependenciesQueueEntry = SessionActionQueueEntry[ SyncInputJobAttachmentsActionApiModel ] +AttachmentDownloadActioQueueEntry = SessionActionQueueEntry[AttachmentDownloadActionApiModel] +AttachmentDownloadActioStepDependenciesQueueEntry = SessionActionQueueEntry[ + AttachmentDownloadActionApiModel +] CancelOutcome = Literal["FAILED", "NEVER_ATTEMPTED"] @@ -91,13 +102,17 @@ class SessionActionQueue: | TaskRunQueueEntry | SyncInputJobAttachmentsQueueEntry | SyncInputJobAttachmentsStepDependenciesQueueEntry + | AttachmentDownloadActioQueueEntry + | AttachmentDownloadActioStepDependenciesQueueEntry ] _actions_by_id: dict[ str, EnvironmentQueueEntry | TaskRunQueueEntry | SyncInputJobAttachmentsQueueEntry - | SyncInputJobAttachmentsStepDependenciesQueueEntry, + | SyncInputJobAttachmentsStepDependenciesQueueEntry + | 
AttachmentDownloadActioQueueEntry + | AttachmentDownloadActioStepDependenciesQueueEntry, ] _action_update_callback: Callable[[SessionActionStatus], None] _job_entities: JobEntities @@ -156,7 +171,13 @@ def list_all_action_identifiers(self) -> list[EntityIdentifier]: ), ) elif action_type == "SYNC_INPUT_JOB_ATTACHMENTS": - action_definition = cast(SyncInputJobAttachmentsActionApiModel, action_definition) + if ASSET_SYNC_JOB_USER_FEATURE: + action_definition = cast(AttachmentDownloadActionApiModel, action_definition) + else: + action_definition = cast( + SyncInputJobAttachmentsActionApiModel, action_definition + ) + if "stepId" in action_definition: identifier = StepDetailsIdentifier( stepDetails=StepDetailsIdentifierFields( @@ -273,6 +294,7 @@ def replace( EnvironmentActionApiModel | TaskRunActionApiModel | SyncInputJobAttachmentsActionApiModel + | AttachmentDownloadActionApiModel ], ) -> None: """Update the queue's actions""" @@ -281,6 +303,8 @@ def replace( | EnvironmentQueueEntry | SyncInputJobAttachmentsQueueEntry | SyncInputJobAttachmentsStepDependenciesQueueEntry + | AttachmentDownloadActioQueueEntry + | AttachmentDownloadActioStepDependenciesQueueEntry ] = [] action_ids_added = list[str]() @@ -305,17 +329,31 @@ def replace( definition=action, ) elif action_type == "SYNC_INPUT_JOB_ATTACHMENTS": - action = cast(SyncInputJobAttachmentsActionApiModel, action) - if "stepId" not in action: - queue_entry = SyncInputJobAttachmentsQueueEntry( - cancel=cancel_event, - definition=action, - ) + action = cast(AttachmentDownloadActionApiModel, action) + if ASSET_SYNC_JOB_USER_FEATURE: + action = cast(AttachmentDownloadActionApiModel, action) + if "stepId" not in action: + queue_entry = AttachmentDownloadActioQueueEntry( + cancel=cancel_event, + definition=action, + ) + else: + queue_entry = AttachmentDownloadActioStepDependenciesQueueEntry( + cancel=cancel_event, + definition=action, + ) else: - queue_entry = SyncInputJobAttachmentsStepDependenciesQueueEntry( - 
cancel=cancel_event, - definition=action, - ) + action = cast(SyncInputJobAttachmentsActionApiModel, action) + if "stepId" not in action: + queue_entry = SyncInputJobAttachmentsQueueEntry( + cancel=cancel_event, + definition=action, + ) + else: + queue_entry = SyncInputJobAttachmentsStepDependenciesQueueEntry( + cancel=cancel_event, + definition=action, + ) else: raise NotImplementedError(f"Unknown action type '{action_type}'") self._actions_by_id[action_id] = queue_entry @@ -441,52 +479,107 @@ def dequeue(self) -> SessionActionDefinition | None: ) elif action_type == "SYNC_INPUT_JOB_ATTACHMENTS": action_definition = action_queue_entry.definition - action_definition = cast(SyncInputJobAttachmentsActionApiModel, action_definition) - if "stepId" not in action_definition: - action_queue_entry = cast(SyncInputJobAttachmentsQueueEntry, action_queue_entry) - try: - job_attachment_details = self._job_entities.job_attachment_details() - except UnsupportedSchema as e: - raise JobEntityUnsupportedSchemaError( - action_id, SessionActionLogKind.JA_SYNC, e._version - ) from e - except ValueError as e: - raise JobAttachmentDetailsError( - action_id, SessionActionLogKind.JA_SYNC, str(e) - ) from e - next_action = SyncInputJobAttachmentsAction( - id=action_id, - session_id=self._session_id, - job_attachment_details=job_attachment_details, - ) + if ASSET_SYNC_JOB_USER_FEATURE: + action_definition = cast(AttachmentDownloadActionApiModel, action_definition) + if "stepId" not in action_definition: + action_queue_entry = cast( + AttachmentDownloadActioQueueEntry, action_queue_entry + ) + try: + job_attachment_details = self._job_entities.job_attachment_details() + except UnsupportedSchema as e: + raise JobEntityUnsupportedSchemaError( + action_id, SessionActionLogKind.JA_SYNC, e._version + ) from e + except ValueError as e: + raise JobAttachmentDetailsError( + action_id, SessionActionLogKind.JA_SYNC, str(e) + ) from e + next_action = AttachmentDownloadAction( + id=action_id, + 
session_id=self._session_id, + job_attachment_details=job_attachment_details, + ) + else: + action_queue_entry = cast( + AttachmentDownloadActioStepDependenciesQueueEntry, action_queue_entry + ) + + try: + step_details = self._job_entities.step_details( + step_id=action_definition["stepId"], + ) + except UnsupportedSchema as e: + raise JobEntityUnsupportedSchemaError( + action_id, + SessionActionLogKind.JA_DEP_SYNC, + e._version, + step_id=action_definition["stepId"], + ) from e + except ValueError as e: + raise StepDetailsError( + action_id, + SessionActionLogKind.JA_DEP_SYNC, + str(e), + step_id=action_definition["stepId"], + ) from e + next_action = AttachmentDownloadAction( + id=action_id, + session_id=self._session_id, + step_details=step_details, + ) + else: - action_queue_entry = cast( - SyncInputJobAttachmentsStepDependenciesQueueEntry, action_queue_entry + action_definition = cast( + SyncInputJobAttachmentsActionApiModel, action_definition ) + if "stepId" not in action_definition: + action_queue_entry = cast( + SyncInputJobAttachmentsQueueEntry, action_queue_entry + ) + try: + job_attachment_details = self._job_entities.job_attachment_details() + except UnsupportedSchema as e: + raise JobEntityUnsupportedSchemaError( + action_id, SessionActionLogKind.JA_SYNC, e._version + ) from e + except ValueError as e: + raise JobAttachmentDetailsError( + action_id, SessionActionLogKind.JA_SYNC, str(e) + ) from e + next_action = SyncInputJobAttachmentsAction( + id=action_id, + session_id=self._session_id, + job_attachment_details=job_attachment_details, + ) + else: + action_queue_entry = cast( + SyncInputJobAttachmentsStepDependenciesQueueEntry, action_queue_entry + ) - try: - step_details = self._job_entities.step_details( - step_id=action_definition["stepId"], + try: + step_details = self._job_entities.step_details( + step_id=action_definition["stepId"], + ) + except UnsupportedSchema as e: + raise JobEntityUnsupportedSchemaError( + action_id, + 
SessionActionLogKind.JA_DEP_SYNC, + e._version, + step_id=action_definition["stepId"], + ) from e + except ValueError as e: + raise StepDetailsError( + action_id, + SessionActionLogKind.JA_DEP_SYNC, + str(e), + step_id=action_definition["stepId"], + ) from e + next_action = SyncInputJobAttachmentsAction( + id=action_id, + session_id=self._session_id, + step_details=step_details, ) - except UnsupportedSchema as e: - raise JobEntityUnsupportedSchemaError( - action_id, - SessionActionLogKind.JA_DEP_SYNC, - e._version, - step_id=action_definition["stepId"], - ) from e - except ValueError as e: - raise StepDetailsError( - action_id, - SessionActionLogKind.JA_DEP_SYNC, - str(e), - step_id=action_definition["stepId"], - ) from e - next_action = SyncInputJobAttachmentsAction( - id=action_id, - session_id=self._session_id, - step_details=step_details, - ) else: raise ValueError( f'Unknown action type "{action_type}". Complete action = {action_definition}' diff --git a/src/deadline_worker_agent/sessions/actions/__init__.py b/src/deadline_worker_agent/sessions/actions/__init__.py index 75291030..474d09eb 100644 --- a/src/deadline_worker_agent/sessions/actions/__init__.py +++ b/src/deadline_worker_agent/sessions/actions/__init__.py @@ -6,6 +6,7 @@ from .openjd_action import OpenjdAction from .run_step_task import RunStepTaskAction from .sync_input_job_attachments import SyncInputJobAttachmentsAction +from .run_attachment_download import AttachmentDownloadAction __all__ = [ "EnterEnvironmentAction", @@ -14,4 +15,5 @@ "RunStepTaskAction", "SessionActionDefinition", "SyncInputJobAttachmentsAction", + "AttachmentDownloadAction", ] diff --git a/src/deadline_worker_agent/sessions/actions/run_attachment_download.py b/src/deadline_worker_agent/sessions/actions/run_attachment_download.py new file mode 100644 index 00000000..093ee714 --- /dev/null +++ b/src/deadline_worker_agent/sessions/actions/run_attachment_download.py @@ -0,0 +1,338 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. + +from __future__ import annotations +from concurrent.futures import ( + Executor, +) +import os +from pathlib import Path +import sys +import json +from shlex import quote +from logging import getLogger, LoggerAdapter +import sysconfig +from typing import Any, TYPE_CHECKING, Optional +from dataclasses import asdict + +from deadline.job_attachments.asset_manifests import BaseAssetManifest +from deadline.job_attachments.models import ( + Attachments, + PathFormat, + JobAttachmentS3Settings, + ManifestProperties, + PathMappingRule, + JobAttachmentsFileSystem, +) +from deadline.job_attachments.os_file_permission import ( + FileSystemPermissionSettings, + PosixFileSystemPermissionSettings, + WindowsFileSystemPermissionSettings, + WindowsPermissionEnum, +) + +from openjd.sessions import ( + LOG as OPENJD_LOG, + PathMappingRule as OpenjdPathMapping, + PosixSessionUser, + WindowsSessionUser, +) +from openjd.model.v2023_09 import ( + EmbeddedFileTypes as EmbeddedFileTypes_2023_09, + EmbeddedFileText as EmbeddedFileText_2023_09, + Action as Action_2023_09, + StepScript as StepScript_2023_09, + StepActions as StepActions_2023_09, +) +from openjd.model import ParameterValue + +from ..session import Session +from ...log_messages import SessionActionLogKind + + +from .openjd_action import OpenjdAction + +if TYPE_CHECKING: + from concurrent.futures import Future + from ..session import Session + from ..job_entities import JobAttachmentDetails, StepDetails + + +logger = getLogger(__name__) + + +class SyncCanceled(Exception): + """Exception indicating the synchronization was canceled""" + + pass + + +class AttachmentDownloadAction(OpenjdAction): + """Action to synchronize input job attachments for a AWS Deadline Cloud job + + Parameters + ---------- + id : str + The unique action identifier + """ + + _future: Future[None] + _job_attachment_details: Optional[JobAttachmentDetails] + _step_details: Optional[StepDetails] + _step_script: 
Optional[StepScript_2023_09] + + def __init__( + self, + *, + id: str, + session_id: str, + job_attachment_details: Optional[JobAttachmentDetails] = None, + step_details: Optional[StepDetails] = None, + ) -> None: + super(AttachmentDownloadAction, self).__init__( + id=id, + action_log_kind=( + SessionActionLogKind.JA_SYNC + if step_details is None + else SessionActionLogKind.JA_DEP_SYNC + ), + step_id=step_details.step_id if step_details is not None else None, + ) + self._job_attachment_details = job_attachment_details + self._step_details = step_details + self._logger = LoggerAdapter(OPENJD_LOG, extra={"session_id": session_id}) + + def set_step_script(self, manifests, path_mapping, s3_settings) -> None: + profile = os.environ.get("AWS_PROFILE") + deadline_path = os.path.join(Path(sysconfig.get_path("scripts")), "deadline") + + script = "#!/usr/bin/env bash\n\n{} attachment download -m {} --path-mapping-rules {} --s3-root-uri {} --profile {}".format( + deadline_path, + " -m ".join(quote(v) for v in manifests), + quote(path_mapping), + s3_settings.to_s3_root_uri(), + profile, + ) + + self._step_script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09(command="{{ Task.File.AttachmentDownload }}") + ), + embeddedFiles=[ + EmbeddedFileText_2023_09( + name="AttachmentDownload", + type=EmbeddedFileTypes_2023_09.TEXT, + runnable=True, + data=script, + ) + ], + ) + + def __eq__(self, other: Any) -> bool: + return ( + type(self) is type(other) + and self._id == other._id + and self._job_attachment_details == other._job_attachment_details + and self._step_details == other._step_details + ) + + def start( + self, + *, + session: Session, + executor: Executor, + ) -> None: + """Initiates the synchronization of the input job attachments + + Parameters + ---------- + session : Session + The Session that is the target of the action + executor : Executor + An executor for running futures + """ + + self._logger.info(f"Syncing inputs using session 
{session}") + + if self._step_details: + section_title = "Job Attachments Download for Step" + else: + section_title = "Job Attachments Download for Job" + + # Banner mimicing the one printed by the openjd-sessions runtime + self._logger.info("==============================================") + self._logger.info(f"--------- AttachmentDownloadAction {section_title}") + self._logger.info("==============================================") + + if not (job_attachment_settings := session._job_details.job_attachment_settings): + raise RuntimeError("Job attachment settings were not contained in JOB_DETAILS entity") + + if self._job_attachment_details: + session._job_attachment_details = self._job_attachment_details + + # Validate that job attachment details have been provided before syncing step dependencies. + if session._job_attachment_details is None: + raise RuntimeError( + "Job attachments must be synchronized before downloading Step dependencies." + ) + step_dependencies = self._step_details.dependencies if self._step_details else [] + + assert job_attachment_settings.s3_bucket_name is not None + assert job_attachment_settings.root_prefix is not None + assert session._asset_sync is not None + + s3_settings = JobAttachmentS3Settings( + s3BucketName=job_attachment_settings.s3_bucket_name, + rootPrefix=job_attachment_settings.root_prefix, + ) + + manifest_properties_list: list[ManifestProperties] = [] + if not step_dependencies: + for manifest_properties in session._job_attachment_details.manifests: + manifest_properties_list.append( + ManifestProperties( + rootPath=manifest_properties.root_path, + fileSystemLocationName=manifest_properties.file_system_location_name, + rootPathFormat=PathFormat(manifest_properties.root_path_format), + inputManifestPath=manifest_properties.input_manifest_path, + inputManifestHash=manifest_properties.input_manifest_hash, + outputRelativeDirectories=manifest_properties.output_relative_directories, + ) + ) + + attachments = Attachments( + 
manifests=manifest_properties_list, + fileSystem=session._job_attachment_details.job_attachments_file_system, + ) + + storage_profiles_path_mapping_rules_dict: dict[str, str] = { + str(rule.source_path): str(rule.destination_path) + for rule in session._job_details.path_mapping_rules + } + + # Generate absolute Path Mapping to local session (no storage profile) + # returns root path to PathMappingRule mapping + dynamic_mapping_rules: dict[str, PathMappingRule] = ( + session._asset_sync.generate_dynamic_path_mapping( + session_dir=session._session.working_directory, + attachments=attachments, + ) + ) + + # Aggregate manifests (with step step dependency handling) + merged_manifests_by_root: dict[str, BaseAssetManifest] = ( + session._asset_sync._aggregate_asset_root_manifests( + session_dir=session._session.working_directory, + s3_settings=s3_settings, + queue_id=session._queue_id, + job_id=session._queue._job_id, + attachments=attachments, + step_dependencies=step_dependencies, + dynamic_mapping_rules=dynamic_mapping_rules, + storage_profiles_path_mapping_rules=storage_profiles_path_mapping_rules_dict, + ) + ) + + vfs_handled = self._vfs_handling( + session=session, + attachments=attachments, + merged_manifests_by_root=merged_manifests_by_root, + s3_settings=s3_settings, + ) + + if not vfs_handled: + job_attachment_path_mappings = list([asdict(r) for r in dynamic_mapping_rules.values()]) + + # Open Job Description session implementation details -- path mappings are sorted. + # bisect.insort only supports the 'key' arg in 3.10 or later, so + # we first extend the list and sort it afterwards. 
+ if session._session._path_mapping_rules: + session._session._path_mapping_rules.extend( + OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings + ) + else: + session._session._path_mapping_rules = [ + OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings + ] + + # Open Job Description Sessions sort the path mapping rules based on length of the parts make + # rules that are subsets of each other behave in a predictable manner. We must + # sort here since we're modifying that internal list appending to the list. + session._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts)) + + # =========================== TO BE DELETED =========================== + path_mapping_file_path: str = os.path.join( + session._session.working_directory, "path_mapping" + ) + for rule in job_attachment_path_mappings: + rule["source_path"] = rule["destination_path"] + + with open(path_mapping_file_path, "w", encoding="utf8") as f: + f.write(json.dumps([rule for rule in job_attachment_path_mappings])) + # =========================== TO BE DELETED =========================== + + manifest_paths = session._asset_sync._check_and_write_local_manifests( + merged_manifests_by_root=merged_manifests_by_root, + manifest_write_dir=str(session._session.working_directory), + ) + + self.set_step_script( + manifests=manifest_paths, + path_mapping=path_mapping_file_path, + s3_settings=s3_settings, + ) + assert self._step_script is not None + session.run_task( + step_script=self._step_script, + task_parameter_values=dict[str, ParameterValue](), + ) + + def _vfs_handling( + self, + session: Session, + attachments: Attachments, + merged_manifests_by_root: dict[str, BaseAssetManifest], + s3_settings: JobAttachmentS3Settings, + ) -> bool: + fs_permission_settings: Optional[FileSystemPermissionSettings] = None + if session._os_user is not None: + if os.name == "posix": + if not isinstance(session._os_user, PosixSessionUser): + raise ValueError(f"The user 
must be a posix-user. Got {type(session._os_user)}") + fs_permission_settings = PosixFileSystemPermissionSettings( + os_user=session._os_user.user, + os_group=session._os_user.group, + dir_mode=0o20, + file_mode=0o20, + ) + else: + if not isinstance(session._os_user, WindowsSessionUser): + raise ValueError( + f"The user must be a windows-user. Got {type(session._os_user)}" + ) + if session._os_user is not None: + fs_permission_settings = WindowsFileSystemPermissionSettings( + os_user=session._os_user.user, + dir_mode=WindowsPermissionEnum.WRITE, + file_mode=WindowsPermissionEnum.WRITE, + ) + + if ( + attachments.fileSystem == JobAttachmentsFileSystem.VIRTUAL.value + and sys.platform != "win32" + and fs_permission_settings is not None + and os.environ is not None + and "AWS_PROFILE" in os.environ + and isinstance(fs_permission_settings, PosixFileSystemPermissionSettings) + ): + assert session._asset_sync is not None + session._asset_sync._launch_vfs( + s3_settings=s3_settings, + session_dir=session._session.working_directory, + fs_permission_settings=fs_permission_settings, + merged_manifests_by_root=merged_manifests_by_root, + os_env_vars=dict(os.environ), + ) + return True + + else: + return False diff --git a/src/deadline_worker_agent/sessions/session.py b/src/deadline_worker_agent/sessions/session.py index 79eb7302..68be6b95 100644 --- a/src/deadline_worker_agent/sessions/session.py +++ b/src/deadline_worker_agent/sessions/session.py @@ -29,6 +29,7 @@ from deadline_worker_agent.api_models import ( EntityIdentifier, SyncInputJobAttachmentsAction, + AttachmentDownloadAction, ) if TYPE_CHECKING: @@ -321,7 +322,7 @@ def wait(self, timeout: timedelta | None = None) -> None: self._stopped_running.wait(timeout=timeout.seconds if timeout else None) def _run(self) -> None: - """The contains the main run loop for processing session actions. + """The function contains the main run loop for processing session actions. 
This code will loop until Session.stop() is called from another thread. """ @@ -447,7 +448,12 @@ def _cleanup(self) -> None: def replace_assigned_actions( self, *, - actions: Iterable[EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction], + actions: Iterable[ + EnvironmentAction + | TaskRunAction + | SyncInputJobAttachmentsAction + | AttachmentDownloadAction + ], ) -> None: """Replaces the assigned actions @@ -461,7 +467,7 @@ def replace_assigned_actions( Parameters ---------- - actions : Iterable[EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction] + actions : Iterable[EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction | AttachmentDownloadAction] The new sequence of actions to be assigned to the session. The order of the actions provided is used as the processing order """ @@ -471,7 +477,12 @@ def replace_assigned_actions( def _replace_assigned_actions_impl( self, *, - actions: Iterable[EnvironmentAction | TaskRunAction | SyncInputJobAttachmentsAction], + actions: Iterable[ + EnvironmentAction + | TaskRunAction + | SyncInputJobAttachmentsAction + | AttachmentDownloadAction + ], ) -> None: """Replaces the assigned actions diff --git a/test/unit/sessions/actions/test_run_attachment_download.py b/test/unit/sessions/actions/test_run_attachment_download.py new file mode 100644 index 00000000..c378eab0 --- /dev/null +++ b/test/unit/sessions/actions/test_run_attachment_download.py @@ -0,0 +1,177 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations +from pathlib import Path +import os +import sysconfig +import tempfile +from typing import TYPE_CHECKING, Generator +from unittest.mock import MagicMock, Mock, patch, ANY + +import pytest + +from deadline_worker_agent.sessions.actions import AttachmentDownloadAction +from deadline_worker_agent.sessions.job_entities.job_details import JobDetails +from openjd.sessions import SessionUser +from openjd.model import ParameterValue +from openjd.model.v2023_09 import ( + EmbeddedFileTypes as EmbeddedFileTypes_2023_09, + EmbeddedFileText as EmbeddedFileText_2023_09, + Action as Action_2023_09, + StepScript as StepScript_2023_09, + StepActions as StepActions_2023_09, +) + +import deadline_worker_agent.sessions.session as session_mod +from deadline.job_attachments.models import JobAttachmentS3Settings + +if TYPE_CHECKING: + from deadline_worker_agent.sessions.job_entities import JobAttachmentDetails + + +@pytest.fixture +def executor() -> Mock: + return Mock() + + +@pytest.fixture +def session_id() -> str: + return "session_id" + + +@pytest.fixture +def asset_sync() -> MagicMock: + """A fixture returning a Mock to be passed in place of a deadline.job_attachments. 
AssetSync + instance when creating the Worker Agent Session instance""" + return MagicMock() + + +@pytest.fixture +def session_dir(session_id: str): + with tempfile.TemporaryDirectory() as tmpdir_path: + session_dir: str = os.path.join(tmpdir_path, session_id) + os.makedirs(session_dir) + yield session_dir + + +@pytest.fixture +def mock_openjd_session_cls(session_dir: str) -> Generator[MagicMock, None, None]: + """Mocks the Worker Agent Session module's import of the Open Job Description Session class""" + with patch.object(session_mod, "OPENJDSession") as mock_openjd_session: + mock_openjd_session.working_directory = session_dir + yield mock_openjd_session + + +@pytest.fixture +def action_id() -> str: + return "sessionaction-abc123" + + +@pytest.fixture +def action( + action_id: str, + job_attachment_details: JobAttachmentDetails, +) -> AttachmentDownloadAction: + return AttachmentDownloadAction( + id=action_id, + session_id="session-1234", + job_attachment_details=job_attachment_details, + ) + + +class TestStart: + """Tests for AttachmentDownloadAction.start()""" + + QUEUE_ID = "queue-test" + JOB_ID = "job-test" + + @pytest.fixture + def session( + self, + session_id: str, + job_details: JobDetails, + job_user: SessionUser, + job_attachment_details: JobAttachmentDetails, + mock_openjd_session_cls: Mock, + ) -> Mock: + session = Mock() + session.id = session_id + session._job_details = job_details + session._job_attachment_details = job_attachment_details + session._os_user = job_user + session._session = mock_openjd_session_cls + session._queue_id = TestStart.QUEUE_ID + session._queue._job_id = TestStart.JOB_ID + return session + + @pytest.fixture(autouse=True) + def mock_asset_sync(self, session: Mock) -> Generator[MagicMock, None, None]: + with patch.object(session, "_asset_sync") as mock_asset_sync: + yield mock_asset_sync + + def test_attachment_download_action_start( + self, + executor: Mock, + session: Mock, + action: AttachmentDownloadAction, + 
session_dir: str, + mock_asset_sync: MagicMock, + job_details: JobDetails, + ) -> None: + """ + Tests that AttachmentDownloadAction.start() calls AssetSync functions to prepare input + for constructing step script to run openjd action + """ + # WHEN + action.start(session=session, executor=executor) + + # THEN + assert job_details.job_attachment_settings is not None + assert job_details.job_attachment_settings.s3_bucket_name is not None + assert job_details.job_attachment_settings.root_prefix is not None + + mock_asset_sync._aggregate_asset_root_manifests.assert_called_once_with( + session_dir=session_dir, + s3_settings=JobAttachmentS3Settings( + s3BucketName=job_details.job_attachment_settings.s3_bucket_name, + rootPrefix=job_details.job_attachment_settings.root_prefix, + ), + queue_id=TestStart.QUEUE_ID, + job_id=TestStart.JOB_ID, + attachments=ANY, + step_dependencies=[], + dynamic_mapping_rules=ANY, + storage_profiles_path_mapping_rules={}, + ) + mock_asset_sync.generate_dynamic_path_mapping.assert_called_once_with( + session_dir=session_dir, + attachments=ANY, + ) + mock_asset_sync._check_and_write_local_manifests.assert_called_once_with( + merged_manifests_by_root=ANY, + manifest_write_dir=session_dir, + ) + + assert action._step_script == StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09(command="{{ Task.File.AttachmentDownload }}") + ), + embeddedFiles=[ + EmbeddedFileText_2023_09( + name="AttachmentDownload", + type=EmbeddedFileTypes_2023_09.TEXT, + runnable=True, + data="#!/usr/bin/env bash\n\n{} attachment download -m {} --path-mapping-rules {} --s3-root-uri {} --profile {}".format( + os.path.join(Path(sysconfig.get_path("scripts")), "deadline"), + " -m ".join([]), + os.path.join(session_dir, "path_mapping"), + "s3://job_attachments_bucket/job_attachments", + None, + ), + ) + ], + ) + session.run_task.assert_called_once_with( + step_script=action._step_script, + task_parameter_values=dict[str, ParameterValue](), + ) From 
3ea7c45a300257a82f732153a4f162cf0645cb06 Mon Sep 17 00:00:00 2001 From: Godot Bian <13778003+godobyte@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:45:45 +0000 Subject: [PATCH 2/2] address comments Signed-off-by: Godot Bian <13778003+godobyte@users.noreply.github.com> --- src/deadline_worker_agent/feature_flag.py | 6 +- .../scheduler/session_queue.py | 24 ++-- .../actions/run_attachment_download.py | 124 +++++++++--------- .../actions/test_run_attachment_download.py | 13 +- 4 files changed, 88 insertions(+), 79 deletions(-) diff --git a/src/deadline_worker_agent/feature_flag.py b/src/deadline_worker_agent/feature_flag.py index 00abfa7b..5e192761 100644 --- a/src/deadline_worker_agent/feature_flag.py +++ b/src/deadline_worker_agent/feature_flag.py @@ -2,4 +2,8 @@ import os -ASSET_SYNC_JOB_USER_FEATURE = os.environ.get("ASSET_SYNC_JOB_USER_FEATURE") +# This feature is still a work-in-progress and untested on Windows + +ASSET_SYNC_JOB_USER_FEATURE = ( + os.environ.get("ASSET_SYNC_JOB_USER_FEATURE", "false").lower() == "true" +) diff --git a/src/deadline_worker_agent/scheduler/session_queue.py b/src/deadline_worker_agent/scheduler/session_queue.py index 7b55d32c..9132d0fd 100644 --- a/src/deadline_worker_agent/scheduler/session_queue.py +++ b/src/deadline_worker_agent/scheduler/session_queue.py @@ -78,8 +78,8 @@ class SessionActionQueueEntry(Generic[D]): SyncInputJobAttachmentsStepDependenciesQueueEntry = SessionActionQueueEntry[ SyncInputJobAttachmentsActionApiModel ] -AttachmentDownloadActioQueueEntry = SessionActionQueueEntry[AttachmentDownloadActionApiModel] -AttachmentDownloadActioStepDependenciesQueueEntry = SessionActionQueueEntry[ +AttachmentDownloadActionQueueEntry = SessionActionQueueEntry[AttachmentDownloadActionApiModel] +AttachmentDownloadActionStepDependenciesQueueEntry = SessionActionQueueEntry[ AttachmentDownloadActionApiModel ] CancelOutcome = Literal["FAILED", "NEVER_ATTEMPTED"] @@ -102,8 +102,8 @@ class SessionActionQueue: | TaskRunQueueEntry | 
SyncInputJobAttachmentsQueueEntry | SyncInputJobAttachmentsStepDependenciesQueueEntry - | AttachmentDownloadActioQueueEntry - | AttachmentDownloadActioStepDependenciesQueueEntry + | AttachmentDownloadActionQueueEntry + | AttachmentDownloadActionStepDependenciesQueueEntry ] _actions_by_id: dict[ str, @@ -111,8 +111,8 @@ class SessionActionQueue: | TaskRunQueueEntry | SyncInputJobAttachmentsQueueEntry | SyncInputJobAttachmentsStepDependenciesQueueEntry - | AttachmentDownloadActioQueueEntry - | AttachmentDownloadActioStepDependenciesQueueEntry, + | AttachmentDownloadActionQueueEntry + | AttachmentDownloadActionStepDependenciesQueueEntry, ] _action_update_callback: Callable[[SessionActionStatus], None] _job_entities: JobEntities @@ -303,8 +303,8 @@ def replace( | EnvironmentQueueEntry | SyncInputJobAttachmentsQueueEntry | SyncInputJobAttachmentsStepDependenciesQueueEntry - | AttachmentDownloadActioQueueEntry - | AttachmentDownloadActioStepDependenciesQueueEntry + | AttachmentDownloadActionQueueEntry + | AttachmentDownloadActionStepDependenciesQueueEntry ] = [] action_ids_added = list[str]() @@ -333,12 +333,12 @@ def replace( if ASSET_SYNC_JOB_USER_FEATURE: action = cast(AttachmentDownloadActionApiModel, action) if "stepId" not in action: - queue_entry = AttachmentDownloadActioQueueEntry( + queue_entry = AttachmentDownloadActionQueueEntry( cancel=cancel_event, definition=action, ) else: - queue_entry = AttachmentDownloadActioStepDependenciesQueueEntry( + queue_entry = AttachmentDownloadActionStepDependenciesQueueEntry( cancel=cancel_event, definition=action, ) @@ -483,7 +483,7 @@ def dequeue(self) -> SessionActionDefinition | None: action_definition = cast(AttachmentDownloadActionApiModel, action_definition) if "stepId" not in action_definition: action_queue_entry = cast( - AttachmentDownloadActioQueueEntry, action_queue_entry + AttachmentDownloadActionQueueEntry, action_queue_entry ) try: job_attachment_details = self._job_entities.job_attachment_details() @@ -502,7 
+502,7 @@ def dequeue(self) -> SessionActionDefinition | None: ) else: action_queue_entry = cast( - AttachmentDownloadActioStepDependenciesQueueEntry, action_queue_entry + AttachmentDownloadActionStepDependenciesQueueEntry, action_queue_entry ) try: diff --git a/src/deadline_worker_agent/sessions/actions/run_attachment_download.py b/src/deadline_worker_agent/sessions/actions/run_attachment_download.py index 093ee714..e81d602f 100644 --- a/src/deadline_worker_agent/sessions/actions/run_attachment_download.py +++ b/src/deadline_worker_agent/sessions/actions/run_attachment_download.py @@ -9,7 +9,7 @@ import sys import json from shlex import quote -from logging import getLogger, LoggerAdapter +from logging import LoggerAdapter import sysconfig from typing import Any, TYPE_CHECKING, Optional from dataclasses import asdict @@ -32,6 +32,7 @@ from openjd.sessions import ( LOG as OPENJD_LOG, + LogContent, PathMappingRule as OpenjdPathMapping, PosixSessionUser, WindowsSessionUser, @@ -52,22 +53,12 @@ from .openjd_action import OpenjdAction if TYPE_CHECKING: - from concurrent.futures import Future from ..session import Session from ..job_entities import JobAttachmentDetails, StepDetails -logger = getLogger(__name__) - - -class SyncCanceled(Exception): - """Exception indicating the synchronization was canceled""" - - pass - - class AttachmentDownloadAction(OpenjdAction): - """Action to synchronize input job attachments for a AWS Deadline Cloud job + """Action to synchronize input job attachments for a AWS Deadline Cloud Session Parameters ---------- @@ -75,7 +66,6 @@ class AttachmentDownloadAction(OpenjdAction): The unique action identifier """ - _future: Future[None] _job_attachment_details: Optional[JobAttachmentDetails] _step_details: Optional[StepDetails] _step_script: Optional[StepScript_2023_09] @@ -102,6 +92,7 @@ def __init__( self._logger = LoggerAdapter(OPENJD_LOG, extra={"session_id": session_id}) def set_step_script(self, manifests, path_mapping, s3_settings) -> 
None: + # TODO - update to run python as embedded file profile = os.environ.get("AWS_PROFILE") deadline_path = os.path.join(Path(sysconfig.get_path("scripts")), "deadline") @@ -151,17 +142,25 @@ def start( An executor for running futures """ - self._logger.info(f"Syncing inputs using session {session}") - if self._step_details: section_title = "Job Attachments Download for Step" else: section_title = "Job Attachments Download for Job" # Banner mimicing the one printed by the openjd-sessions runtime - self._logger.info("==============================================") - self._logger.info(f"--------- AttachmentDownloadAction {section_title}") - self._logger.info("==============================================") + # TODO - Consider a better approach to manage the banner title + self._logger.info( + "==============================================", + extra={"openjd_log_content": LogContent.BANNER}, + ) + self._logger.info( + f"--------- AttachmentDownloadAction {section_title}", + extra={"openjd_log_content": LogContent.BANNER}, + ) + self._logger.info( + "==============================================", + extra={"openjd_log_content": LogContent.BANNER}, + ) if not (job_attachment_settings := session._job_details.job_attachment_settings): raise RuntimeError("Job attachment settings were not contained in JOB_DETAILS entity") @@ -232,61 +231,62 @@ def start( ) ) - vfs_handled = self._vfs_handling( + if self._start_vfs( session=session, attachments=attachments, merged_manifests_by_root=merged_manifests_by_root, s3_settings=s3_settings, - ) + ): + # successfully launched VFS + return - if not vfs_handled: - job_attachment_path_mappings = list([asdict(r) for r in dynamic_mapping_rules.values()]) + job_attachment_path_mappings = list([asdict(r) for r in dynamic_mapping_rules.values()]) - # Open Job Description session implementation details -- path mappings are sorted. 
- # bisect.insort only supports the 'key' arg in 3.10 or later, so - # we first extend the list and sort it afterwards. - if session._session._path_mapping_rules: - session._session._path_mapping_rules.extend( - OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings - ) - else: - session._session._path_mapping_rules = [ - OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings - ] - - # Open Job Description Sessions sort the path mapping rules based on length of the parts make - # rules that are subsets of each other behave in a predictable manner. We must - # sort here since we're modifying that internal list appending to the list. - session._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts)) - - # =========================== TO BE DELETED =========================== - path_mapping_file_path: str = os.path.join( - session._session.working_directory, "path_mapping" + # Open Job Description session implementation details -- path mappings are sorted. + # bisect.insort only supports the 'key' arg in 3.10 or later, so + # we first extend the list and sort it afterwards. + if session._session._path_mapping_rules: + session._session._path_mapping_rules.extend( + OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings ) - for rule in job_attachment_path_mappings: - rule["source_path"] = rule["destination_path"] + else: + session._session._path_mapping_rules = [ + OpenjdPathMapping.from_dict(r) for r in job_attachment_path_mappings + ] + + # Open Job Description Sessions sort the path mapping rules based on length of the parts make + # rules that are subsets of each other behave in a predictable manner. We must + # sort here since we're modifying that internal list appending to the list. 
+ session._session._path_mapping_rules.sort(key=lambda rule: -len(rule.source_path.parts)) + + # =========================== TO BE DELETED =========================== + path_mapping_file_path: str = os.path.join( + session._session.working_directory, "path_mapping" + ) + for rule in job_attachment_path_mappings: + rule["source_path"] = rule["destination_path"] - with open(path_mapping_file_path, "w", encoding="utf8") as f: - f.write(json.dumps([rule for rule in job_attachment_path_mappings])) - # =========================== TO BE DELETED =========================== + with open(path_mapping_file_path, "w", encoding="utf8") as f: + f.write(json.dumps([rule for rule in job_attachment_path_mappings])) + # =========================== TO BE DELETED =========================== - manifest_paths = session._asset_sync._check_and_write_local_manifests( - merged_manifests_by_root=merged_manifests_by_root, - manifest_write_dir=str(session._session.working_directory), - ) + manifest_paths = session._asset_sync._check_and_write_local_manifests( + merged_manifests_by_root=merged_manifests_by_root, + manifest_write_dir=str(session._session.working_directory), + ) - self.set_step_script( - manifests=manifest_paths, - path_mapping=path_mapping_file_path, - s3_settings=s3_settings, - ) - assert self._step_script is not None - session.run_task( - step_script=self._step_script, - task_parameter_values=dict[str, ParameterValue](), - ) + self.set_step_script( + manifests=manifest_paths, + path_mapping=path_mapping_file_path, + s3_settings=s3_settings, + ) + assert self._step_script is not None + session.run_task( + step_script=self._step_script, + task_parameter_values=dict[str, ParameterValue](), + ) - def _vfs_handling( + def _start_vfs( self, session: Session, attachments: Attachments, diff --git a/test/unit/sessions/actions/test_run_attachment_download.py b/test/unit/sessions/actions/test_run_attachment_download.py index c378eab0..b6b4f9ae 100644 --- 
a/test/unit/sessions/actions/test_run_attachment_download.py +++ b/test/unit/sessions/actions/test_run_attachment_download.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path import os +import sys import sysconfig import tempfile from typing import TYPE_CHECKING, Generator @@ -109,6 +110,10 @@ def mock_asset_sync(self, session: Mock) -> Generator[MagicMock, None, None]: with patch.object(session, "_asset_sync") as mock_asset_sync: yield mock_asset_sync + @pytest.mark.skipif( + sys.platform == "win32", + reason="Failed in windows due to embeddedFiles.data quotation mark, which will be replaced by python embedded file soon.", + ) def test_attachment_download_action_start( self, executor: Mock, @@ -122,14 +127,14 @@ def test_attachment_download_action_start( Tests that AttachmentDownloadAction.start() calls AssetSync functions to prepare input for constructing step script to run openjd action """ - # WHEN - action.start(session=session, executor=executor) - - # THEN + # GIVEN assert job_details.job_attachment_settings is not None assert job_details.job_attachment_settings.s3_bucket_name is not None assert job_details.job_attachment_settings.root_prefix is not None + # WHEN + action.start(session=session, executor=executor) + mock_asset_sync._aggregate_asset_root_manifests.assert_called_once_with( session_dir=session_dir, s3_settings=JobAttachmentS3Settings(