diff --git a/src/deadline_test_fixtures/__init__.py b/src/deadline_test_fixtures/__init__.py
index 83ecc60..28a0222 100644
--- a/src/deadline_test_fixtures/__init__.py
+++ b/src/deadline_test_fixtures/__init__.py
@@ -1,5 +1,6 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 from .deadline import (
+    CloudWatchLogEvent,
     CommandResult,
     DeadlineClient,
     DeadlineWorker,
@@ -10,6 +11,7 @@
     Farm,
     Fleet,
     PipInstall,
+    PosixUser,
     Queue,
     QueueFleetAssociation,
     TaskStatus,
@@ -34,6 +36,7 @@
 __all__ = [
     "BootstrapResources",
+    "CloudWatchLogEvent",
     "CodeArtifactRepositoryInfo",
     "CommandResult",
     "DeadlineResources",
@@ -51,6 +54,7 @@
     "JobAttachmentSettings",
     "JobAttachmentManager",
     "PipInstall",
+    "PosixUser",
     "S3Object",
     "ServiceModel",
     "StubDeadlineClient",
diff --git a/src/deadline_test_fixtures/deadline/__init__.py b/src/deadline_test_fixtures/deadline/__init__.py
index a2422a2..2179b4e 100644
--- a/src/deadline_test_fixtures/deadline/__init__.py
+++ b/src/deadline_test_fixtures/deadline/__init__.py
@@ -1,6 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 from .resources import (
+    CloudWatchLogEvent,
     Farm,
     Fleet,
     Job,
@@ -16,9 +17,11 @@
     DockerContainerWorker,
     EC2InstanceWorker,
     PipInstall,
+    PosixUser,
 )

 __all__ = [
+    "CloudWatchLogEvent",
     "CommandResult",
     "DeadlineClient",
     "DeadlineWorker",
@@ -29,6 +32,7 @@
     "Fleet",
     "Job",
     "PipInstall",
+    "PosixUser",
     "Queue",
     "QueueFleetAssociation",
     "TaskStatus",
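For context: a minimal sketch of how a consumer might pick up the two newly exported names. This is illustrative only, not part of the change set, and the literal values are hypothetical.

from deadline_test_fixtures import CloudWatchLogEvent, PosixUser

# PosixUser pairs a POSIX user name with its group name (frozen dataclass).
job_user = PosixUser(user="jobuser", group="jobuser")

# CloudWatchLogEvent mirrors one entry of a CloudWatch GetLogEvents "events" list.
log_event = CloudWatchLogEvent.from_api_response(
    {"ingestionTime": 123, "timestamp": 321, "message": "hello"}
)
assert log_event.message == "hello"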
diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py
index 4e97973..9684da6 100644
--- a/src/deadline_test_fixtures/deadline/resources.py
+++ b/src/deadline_test_fixtures/deadline/resources.py
@@ -8,6 +8,8 @@
 from enum import Enum
 from typing import Any, Callable, Literal

+from botocore.client import BaseClient
+
 from .client import DeadlineClient
 from ..models import JobAttachmentSettings
 from ..util import call_api, clean_kwargs, wait_for
@@ -24,21 +26,23 @@ def create(
         *,
         client: DeadlineClient,
         display_name: str,
+        raw_kwargs: dict | None = None,
     ) -> Farm:
         response = call_api(
             description=f"Create farm {display_name}",
             fn=lambda: client.create_farm(
                 displayName=display_name,
+                **(raw_kwargs or {}),
             ),
         )
         farm_id = response["farmId"]
         LOG.info(f"Created farm: {farm_id}")
         return Farm(id=farm_id)

-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete farm {self.id}",
-            fn=lambda: client.delete_farm(farmId=self.id),
+            fn=lambda: client.delete_farm(farmId=self.id, **(raw_kwargs or {})),
         )
@@ -55,6 +59,7 @@ def create(
         farm: Farm,
         role_arn: str | None = None,
         job_attachments: JobAttachmentSettings | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Queue:
         kwargs = clean_kwargs(
             {
@@ -64,6 +69,7 @@
                 "jobAttachmentSettings": (
                     job_attachments.as_queue_settings() if job_attachments else None
                 ),
+                **(raw_kwargs or {}),
             }
         )
@@ -79,10 +85,12 @@ def create(
             farm=farm,
         )

-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete queue {self.id}",
-            fn=lambda: client.delete_queue(queueId=self.id, farmId=self.farm.id),
+            fn=lambda: client.delete_queue(
+                queueId=self.id, farmId=self.farm.id, **(raw_kwargs or {})
+            ),
         )
@@ -99,6 +107,7 @@ def create(
         farm: Farm,
         configuration: dict,
         role_arn: str | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Fleet:
         kwargs = clean_kwargs(
             {
@@ -106,6 +115,7 @@
                 "displayName": display_name,
                 "roleArn": role_arn,
                 "configuration": configuration,
+                **(raw_kwargs or {}),
             }
         )
         response = call_api(
@@ -127,12 +137,13 @@ def create(
         return fleet

-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete fleet {self.id}",
             fn=lambda: client.delete_fleet(
                 farmId=self.farm.id,
                 fleetId=self.id,
+                **(raw_kwargs or {}),
             ),
         )
@@ -184,6 +195,7 @@ def create(
         farm: Farm,
         queue: Queue,
         fleet: Fleet,
+        raw_kwargs: dict | None = None,
     ) -> QueueFleetAssociation:
         call_api(
             description=f"Create queue-fleet association for queue {queue.id} and fleet {fleet.id} in farm {farm.id}",
             fn=lambda: client.create_queue_fleet_association(
                 farmId=farm.id,
                 queueId=queue.id,
                 fleetId=fleet.id,
+                **(raw_kwargs or {}),
             ),
         )
         return QueueFleetAssociation(
@@ -206,6 +219,7 @@ def delete(
         stop_mode: Literal[
             "STOP_SCHEDULING_AND_CANCEL_TASKS", "STOP_SCHEDULING_AND_FINISH_TASKS"
         ] = "STOP_SCHEDULING_AND_CANCEL_TASKS",
+        raw_kwargs: dict | None = None,
     ) -> None:
         self.stop(client=client, stop_mode=stop_mode)
         call_api(
@@ -214,6 +228,7 @@
                 farmId=self.farm.id,
                 queueId=self.queue.id,
                 fleetId=self.fleet.id,
+                **(raw_kwargs or {}),
             ),
         )
@@ -274,6 +289,7 @@ class StrEnum(str, Enum):
 class TaskStatus(StrEnum):
     UNKNOWN = "UNKNOWN"
     PENDING = "PENDING"
+    STARTING = "STARTING"
     READY = "READY"
     RUNNING = "RUNNING"
     ASSIGNED = "ASSIGNED"
@@ -335,6 +351,7 @@ def submit(
         target_task_run_status: str | None = None,
         max_failed_tasks_count: int | None = None,
         max_retries_per_task: int | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Job:
         kwargs = clean_kwargs(
             {
@@ -348,6 +365,7 @@
                 "targetTaskRunStatus": target_task_run_status,
                 "maxFailedTasksCount": max_failed_tasks_count,
                 "maxRetriesPerTask": max_retries_per_task,
+                **(raw_kwargs or {}),
             }
         )
         create_job_response = call_api(
@@ -378,6 +396,7 @@ def get_job_details(
         farm: Farm,
         queue: Queue,
         job_id: str,
+        raw_kwargs: dict | None = None,
     ) -> dict[str, Any]:
         """
         Calls GetJob API and returns the parsed response, which can be used as
@@ -389,6 +408,7 @@
             fn=lambda: client.get_job(
                 farmId=farm.id,
                 queueId=queue.id,
                 jobId=job_id,
+                **(raw_kwargs or {}),
             ),
         )
@@ -434,6 +454,42 @@ def get_optional_field(
             "description": get_optional_field("description"),
         }

+    def get_logs(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+        logs_client: BaseClient,
+    ) -> dict[str, list[CloudWatchLogEvent]]:
+        """
+        Gets the logs for this Job.
+
+        Args:
+            deadline_client (DeadlineClient): The DeadlineClient to use
+            logs_client (BaseClient): The CloudWatch logs boto client to use
+
+        Returns:
+            dict[str, list[CloudWatchLogEvent]]: A mapping of session IDs to their log events
+        """
+        list_sessions_response = deadline_client.list_sessions(
+            farmId=self.farm.id,
+            queueId=self.queue.id,
+            jobId=self.id,
+        )
+        sessions = list_sessions_response["sessions"]
+
+        session_log_map: dict[str, list[CloudWatchLogEvent]] = {}
+        for session in sessions:
+            session_id = session["sessionId"]
+            get_log_events_response = logs_client.get_log_events(
+                logGroupName=f"/aws/deadline/{self.farm.id}/{self.queue.id}",
+                logStreamName=session_id,
+            )
+            session_log_map[session_id] = [
+                CloudWatchLogEvent.from_api_response(le)
+                for le in get_log_events_response["events"]
+            ]
+
+        return session_log_map
+
     def refresh_job_info(self, *, client: DeadlineClient) -> None:
         """
         Calls GetJob API to refresh job information. The result is used to update the fields
@@ -458,6 +514,7 @@ def update(
         target_task_run_status: str | None = None,
         max_failed_tasks_count: int | None = None,
         max_retries_per_task: int | None = None,
+        raw_kwargs: dict | None = None,
     ) -> None:
         kwargs = clean_kwargs(
             {
@@ -465,6 +522,7 @@
                 "targetTaskRunStatus": target_task_run_status,
                 "maxFailedTasksCount": max_failed_tasks_count,
                 "maxRetriesPerTask": max_retries_per_task,
+                **(raw_kwargs or {}),
             }
         )
         call_api(
@@ -553,3 +611,18 @@ def __str__(self) -> str:  # pragma: no cover
             f"ended_at: {self.ended_at}",
         ]
     )
+
+
+@dataclass
+class CloudWatchLogEvent:
+    ingestion_time: int
+    message: str
+    timestamp: int
+
+    @staticmethod
+    def from_api_response(response: dict) -> CloudWatchLogEvent:
+        return CloudWatchLogEvent(
+            ingestion_time=response["ingestionTime"],
+            message=response["message"],
+            timestamp=response["timestamp"],
+        )
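For context: a hedged usage sketch of the new raw_kwargs pass-through and Job.get_logs. The deadline_client and job objects are assumed to already exist in a test harness, and the "tags" field is an illustrative assumption about the CreateFarm API, not something this diff adds.

import boto3

from deadline_test_fixtures import Farm

# raw_kwargs forwards extra request fields straight to the underlying boto
# call, covering API fields the fixture wrappers do not model explicitly.
farm = Farm.create(
    client=deadline_client,  # assumed: a DeadlineClient from the harness
    display_name="logs-demo-farm",
    raw_kwargs={"tags": {"purpose": "integ-test"}},  # hypothetical extra field
)

# get_logs lists the job's sessions, then reads each session's log stream
# from the /aws/deadline/<farmId>/<queueId> log group.
session_log_map = job.get_logs(  # assumed: a previously submitted Job
    deadline_client=deadline_client,
    logs_client=boto3.client("logs"),
)
for session_id, events in session_log_map.items():
    for event in events:
        print(session_id, event.timestamp, event.message)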
diff --git a/src/deadline_test_fixtures/deadline/worker.py b/src/deadline_test_fixtures/deadline/worker.py
index aa67b2e..929278e 100644
--- a/src/deadline_test_fixtures/deadline/worker.py
+++ b/src/deadline_test_fixtures/deadline/worker.py
@@ -33,6 +33,7 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str:  #
     """Get the command to configure the Worker. This must be run as root."""
     cmds = [
         config.worker_agent_install.install_command,
+        *(config.pre_install_commands or []),
         # fmt: off
         (
             "install-deadline-worker "
@@ -44,6 +45,7 @@
             + f"--group {config.group} "
             + f"{'--allow-shutdown ' if config.allow_shutdown else ''}"
             + f"{'--no-install-service ' if config.no_install_service else ''}"
+            + f"{'--start ' if config.start_service else ''}"
         ),
         # fmt: on
     ]
@@ -108,6 +110,12 @@ def __str__(self) -> str:
         )


+@dataclass(frozen=True)
+class PosixUser:
+    user: str
+    group: str
+
+
 @dataclass(frozen=True)
 class DeadlineWorkerConfiguration:
     farm_id: str
@@ -117,10 +125,14 @@
     group: str
     allow_shutdown: bool
     worker_agent_install: PipInstall
+    job_users: list[PosixUser] = field(default_factory=lambda: [PosixUser("jobuser", "jobuser")])
+    start_service: bool = False
     no_install_service: bool = False
     service_model: ServiceModel | None = None
     file_mappings: list[tuple[str, str]] | None = None
     """Mapping of files to copy from host environment to worker environment"""
+    pre_install_commands: list[str] | None = None
+    """Commands to run before installing the Worker agent"""


 @dataclass
@@ -272,6 +284,26 @@ def _launch_instance(self, *, s3_files: list[tuple[str, str]] | None = None) ->
             ]
         )

+        job_users_cmds = []
+        for job_user in self.configuration.job_users:
+            job_users_cmds.append(f"groupadd {job_user.group}")
+            job_users_cmds.append(
+                f"useradd --create-home --system --shell=/bin/bash --groups={self.configuration.group} -g {job_user.group} {job_user.user}"
+            )
+            job_users_cmds.append(f"usermod -a -G {job_user.group} {self.configuration.user}")
+
+        sudoer_rule_users = ",".join(
+            [
+                self.configuration.user,
+                *[job_user.user for job_user in self.configuration.job_users],
+            ]
+        )
+        job_users_cmds.append(
+            f'echo "{self.configuration.user} ALL=({sudoer_rule_users}) NOPASSWD: ALL" > /etc/sudoers.d/{self.configuration.user}'
+        )
+
+        configure_job_users = "\n".join(job_users_cmds)
+
         LOG.info("Launching EC2 instance")
         run_instance_response = self.ec2_client.run_instances(
             MinCount=1,
@@ -297,8 +329,8 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 set -x
 groupadd --system {self.configuration.group}
-useradd --create-home --system --shell=/bin/bash --groups={self.configuration.group} jobuser
 useradd --create-home --system --shell=/bin/bash --groups={self.configuration.group} {self.configuration.user}
+{configure_job_users}
 {copy_s3_command}
 runuser --login {self.configuration.user} --command 'python3 -m venv $HOME/.venv && echo ". $HOME/.venv/bin/activate" >> $HOME/.bashrc'
@@ -386,12 +418,18 @@ def __post_init__(self) -> None:

     def start(self) -> None:
         self._tmpdir = pathlib.Path(tempfile.mkdtemp())

+        # TODO: Support multiple job users on Docker
+        assert (
+            len(self.configuration.job_users) == 1
+        ), f"Multiple job users not supported on Docker worker: {self.configuration.job_users}"
+
         # Environment variables for "run_container.sh"
         run_container_env = {
             **os.environ,
+            "FARM_ID": self.configuration.farm_id,
+            "FLEET_ID": self.configuration.fleet_id,
             "AGENT_USER": self.configuration.user,
             "SHARED_GROUP": self.configuration.group,
-            "JOB_USER": "jobuser",
+            "JOB_USER": self.configuration.job_users[0].user,
             "CONFIGURE_WORKER_AGENT_CMD": configure_worker_command(
                 config=self.configuration,
             ),
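To make the user-data change concrete, here is a standalone sketch of the shell commands the job_users loop above generates; it mirrors the logic in _launch_instance, and the user, group, and host names below are hypothetical placeholders.

# Mirrors the job_users_cmds construction in EC2InstanceWorker._launch_instance.
agent_user = "deadline-worker"  # hypothetical agent user
shared_group = "shared-group"   # hypothetical shared group
job_users = [("jobuser1", "jobgrp1"), ("jobuser2", "jobgrp2")]

cmds = []
for user, group in job_users:
    cmds.append(f"groupadd {group}")
    # Job user: primary group is its own group, secondary is the shared group.
    cmds.append(
        f"useradd --create-home --system --shell=/bin/bash "
        f"--groups={shared_group} -g {group} {user}"
    )
    # The agent user joins each job user's group so it can access job files.
    cmds.append(f"usermod -a -G {group} {agent_user}")

# A single sudoers rule lets the agent user impersonate every job user.
targets = ",".join([agent_user, *[u for u, _ in job_users]])
cmds.append(
    f'echo "{agent_user} ALL=({targets}) NOPASSWD: ALL" > /etc/sudoers.d/{agent_user}'
)
print("\n".join(cmds))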
diff --git a/src/deadline_test_fixtures/fixtures.py b/src/deadline_test_fixtures/fixtures.py
index a08b6d7..c7e0bc4 100644
--- a/src/deadline_test_fixtures/fixtures.py
+++ b/src/deadline_test_fixtures/fixtures.py
@@ -427,7 +427,7 @@ def worker(
     """
     worker: DeadlineWorker

-    if os.environ.get("USE_DOCKER_WORKER", False):
+    if os.environ.get("USE_DOCKER_WORKER", "").lower() == "true":
         LOG.info("Creating Docker worker")
         worker = DockerContainerWorker(
             configuration=worker_config,
diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py
index 4e32a3a..a372bdf 100644
--- a/test/unit/deadline/test_resources.py
+++ b/test/unit/deadline/test_resources.py
@@ -9,6 +9,7 @@
 import pytest

 from deadline_test_fixtures import (
+    CloudWatchLogEvent,
     Farm,
     Queue,
     Fleet,
@@ -640,3 +641,66 @@ def test_wait_until_complete(self, job: Job) -> None:
         # THEN
         assert mock_client.get_job.call_count == 2
         assert job.task_run_status == "FAILED"
+
+    def test_get_logs(self, job: Job) -> None:
+        # GIVEN
+        mock_deadline_client = MagicMock()
+        mock_deadline_client.list_sessions.return_value = {
+            "sessions": [
+                {"sessionId": "session-1"},
+                {"sessionId": "session-2"},
+            ],
+        }
+        mock_logs_client = MagicMock()
+        log_events = [
+            {
+                "events": [
+                    {
+                        "ingestionTime": 123,
+                        "timestamp": 321,
+                        "message": "test",
+                    }
+                ],
+            },
+            {
+                "events": [
+                    {
+                        "ingestionTime": 123123,
+                        "timestamp": 321321,
+                        "message": "testtest",
+                    }
+                ],
+            },
+        ]
+        mock_logs_client.get_log_events.side_effect = log_events
+
+        # WHEN
+        session_log_map = job.get_logs(
+            deadline_client=mock_deadline_client,
+            logs_client=mock_logs_client,
+        )
+
+        # THEN
+        mock_deadline_client.list_sessions.assert_called_once_with(
+            farmId=job.farm.id,
+            queueId=job.queue.id,
+            jobId=job.id,
+        )
+        mock_logs_client.get_log_events.assert_has_calls(
+            [
+                call(
+                    logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}",
+                    logStreamName=session_id,
+                )
+                for session_id in ["session-1", "session-2"]
+            ]
+        )
+
+        assert "session-1" in session_log_map
+        assert session_log_map["session-1"] == [
+            CloudWatchLogEvent.from_api_response(le) for le in log_events[0]["events"]
+        ]
+        assert "session-2" in session_log_map
+        assert session_log_map["session-2"] == [
+            CloudWatchLogEvent.from_api_response(le) for le in log_events[1]["events"]
+        ]
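The fixtures.py change above fixes a truthy-string bug: os.environ.get returns strings, so any non-empty value (including "false") previously selected the Docker worker. A self-contained sketch of the difference:

import os

os.environ["USE_DOCKER_WORKER"] = "false"

# Old check: any non-empty string is truthy, so "false" still selected Docker.
old_result = bool(os.environ.get("USE_DOCKER_WORKER", False))  # True (wrong)

# New check: only an explicit, case-insensitive "true" selects Docker.
new_result = os.environ.get("USE_DOCKER_WORKER", "").lower() == "true"  # False

assert old_result is True and new_result is False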