From 95e3dac1936b848e71b91adad32a43d05cca9c07 Mon Sep 17 00:00:00 2001
From: Jericho Tolentino <68654047+jericht@users.noreply.github.com>
Date: Tue, 10 Oct 2023 22:14:01 +0000
Subject: [PATCH] feat!: add Worker pre-install commands, --start, and
 Job.get_logs

Signed-off-by: Jericho Tolentino <68654047+jericht@users.noreply.github.com>
---
 src/deadline_test_fixtures/__init__.py        |  2 +
 .../deadline/__init__.py                      |  2 +
 .../deadline/resources.py                     | 83 +++++++++++++++++--
 src/deadline_test_fixtures/deadline/worker.py |  7 ++
 src/deadline_test_fixtures/fixtures.py        |  6 +-
 test/unit/deadline/test_resources.py          | 64 ++++++++++++++
 6 files changed, 156 insertions(+), 8 deletions(-)

diff --git a/src/deadline_test_fixtures/__init__.py b/src/deadline_test_fixtures/__init__.py
index 83ecc60..c427276 100644
--- a/src/deadline_test_fixtures/__init__.py
+++ b/src/deadline_test_fixtures/__init__.py
@@ -1,5 +1,6 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 from .deadline import (
+    CloudWatchLogEvent,
     CommandResult,
     DeadlineClient,
     DeadlineWorker,
@@ -34,6 +35,7 @@
 
 __all__ = [
     "BootstrapResources",
+    "CloudWatchLogEvent",
     "CodeArtifactRepositoryInfo",
     "CommandResult",
     "DeadlineResources",
diff --git a/src/deadline_test_fixtures/deadline/__init__.py b/src/deadline_test_fixtures/deadline/__init__.py
index a2422a2..ec1f0ef 100644
--- a/src/deadline_test_fixtures/deadline/__init__.py
+++ b/src/deadline_test_fixtures/deadline/__init__.py
@@ -1,6 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 from .resources import (
+    CloudWatchLogEvent,
     Farm,
     Fleet,
     Job,
@@ -19,6 +20,7 @@
 )
 
 __all__ = [
+    "CloudWatchLogEvent",
     "CommandResult",
     "DeadlineClient",
     "DeadlineWorker",
diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py
index 4e97973..9684da6 100644
--- a/src/deadline_test_fixtures/deadline/resources.py
+++ b/src/deadline_test_fixtures/deadline/resources.py
@@ -8,6 +8,8 @@
 from enum import Enum
 from typing import Any, Callable, Literal
 
+from botocore.client import BaseClient
+
 from .client import DeadlineClient
 from ..models import JobAttachmentSettings
 from ..util import call_api, clean_kwargs, wait_for
@@ -24,21 +26,23 @@ def create(
         *,
         client: DeadlineClient,
         display_name: str,
+        raw_kwargs: dict | None = None,
     ) -> Farm:
         response = call_api(
             description=f"Create farm {display_name}",
             fn=lambda: client.create_farm(
                 displayName=display_name,
+                **(raw_kwargs or {}),
             ),
         )
         farm_id = response["farmId"]
         LOG.info(f"Created farm: {farm_id}")
         return Farm(id=farm_id)
 
-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete farm {self.id}",
-            fn=lambda: client.delete_farm(farmId=self.id),
+            fn=lambda: client.delete_farm(farmId=self.id, **(raw_kwargs or {})),
         )
 
 
@@ -55,6 +59,7 @@ def create(
         farm: Farm,
         role_arn: str | None = None,
         job_attachments: JobAttachmentSettings | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Queue:
         kwargs = clean_kwargs(
             {
@@ -64,6 +69,7 @@ def create(
                 "jobAttachmentSettings": (
                     job_attachments.as_queue_settings() if job_attachments else None
                 ),
+                **(raw_kwargs or {}),
             }
         )
 
@@ -79,10 +85,12 @@ def create(
             farm=farm,
         )
 
-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete queue {self.id}",
-            fn=lambda: client.delete_queue(queueId=self.id, farmId=self.farm.id),
+            fn=lambda: client.delete_queue(
+                queueId=self.id, farmId=self.farm.id, **(raw_kwargs or {})
+            ),
         )
 
 
@@ -99,6 +107,7 @@ def create(
         farm: Farm,
         configuration: dict,
         role_arn: str | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Fleet:
         kwargs = clean_kwargs(
             {
@@ -106,6 +115,7 @@ def create(
                 "displayName": display_name,
                 "roleArn": role_arn,
                 "configuration": configuration,
+                **(raw_kwargs or {}),
             }
         )
         response = call_api(
@@ -127,12 +137,13 @@ def create(
 
         return fleet
 
-    def delete(self, *, client: DeadlineClient) -> None:
+    def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
         call_api(
             description=f"Delete fleet {self.id}",
             fn=lambda: client.delete_fleet(
                 farmId=self.farm.id,
                 fleetId=self.id,
+                **(raw_kwargs or {}),
             ),
         )
 
@@ -184,6 +195,7 @@ def create(
         farm: Farm,
         queue: Queue,
         fleet: Fleet,
+        raw_kwargs: dict | None = None,
    ) -> QueueFleetAssociation:
         call_api(
             description=f"Create queue-fleet association for queue {queue.id} and fleet {fleet.id} in farm {farm.id}",
@@ -191,6 +203,7 @@ def create(
                 farmId=farm.id,
                 queueId=queue.id,
                 fleetId=fleet.id,
+                **(raw_kwargs or {}),
             ),
         )
         return QueueFleetAssociation(
@@ -206,6 +219,7 @@ def delete(
         stop_mode: Literal[
             "STOP_SCHEDULING_AND_CANCEL_TASKS", "STOP_SCHEDULING_AND_FINISH_TASKS"
         ] = "STOP_SCHEDULING_AND_CANCEL_TASKS",
+        raw_kwargs: dict | None = None,
     ) -> None:
         self.stop(client=client, stop_mode=stop_mode)
         call_api(
@@ -214,6 +228,7 @@ def delete(
                 farmId=self.farm.id,
                 queueId=self.queue.id,
                 fleetId=self.fleet.id,
+                **(raw_kwargs or {}),
             ),
         )
 
@@ -274,6 +289,7 @@ class StrEnum(str, Enum):
 class TaskStatus(StrEnum):
     UNKNOWN = "UNKNOWN"
     PENDING = "PENDING"
+    STARTING = "STARTING"
     READY = "READY"
     RUNNING = "RUNNING"
     ASSIGNED = "ASSIGNED"
@@ -335,6 +351,7 @@ def submit(
         target_task_run_status: str | None = None,
         max_failed_tasks_count: int | None = None,
         max_retries_per_task: int | None = None,
+        raw_kwargs: dict | None = None,
     ) -> Job:
         kwargs = clean_kwargs(
             {
@@ -348,6 +365,7 @@ def submit(
                 "targetTaskRunStatus": target_task_run_status,
                 "maxFailedTasksCount": max_failed_tasks_count,
                 "maxRetriesPerTask": max_retries_per_task,
+                **(raw_kwargs or {}),
             }
         )
         create_job_response = call_api(
@@ -378,6 +396,7 @@ def get_job_details(
         farm: Farm,
         queue: Queue,
         job_id: str,
+        raw_kwargs: dict | None = None,
     ) -> dict[str, Any]:
         """
         Calls GetJob API and returns the parsed response, which can be used as
@@ -389,6 +408,7 @@ def get_job_details(
                 farmId=farm.id,
                 queueId=queue.id,
                 jobId=job_id,
+                **(raw_kwargs or {}),
             ),
         )
 
@@ -434,6 +454,42 @@ def get_optional_field(
             "description": get_optional_field("description"),
         }
 
+    def get_logs(
+        self,
+        *,
+        deadline_client: DeadlineClient,
+        logs_client: BaseClient,
+    ) -> dict[str, list[CloudWatchLogEvent]]:
+        """
+        Gets the logs for this Job.
+
+        Args:
+            deadline_client (DeadlineClient): The DeadlineClient to use
+            logs_client (BaseClient): The CloudWatch Logs boto client to use
+
+        Returns:
+            dict[str, list[CloudWatchLogEvent]]: A mapping of session ID to log events
+        """
+        list_sessions_response = deadline_client.list_sessions(
+            farmId=self.farm.id,
+            queueId=self.queue.id,
+            jobId=self.id,
+        )
+        sessions = list_sessions_response["sessions"]
+
+        session_log_map: dict[str, list[CloudWatchLogEvent]] = {}
+        for session in sessions:
+            session_id = session["sessionId"]
+            get_log_events_response = logs_client.get_log_events(
+                logGroupName=f"/aws/deadline/{self.farm.id}/{self.queue.id}",
+                logStreamName=session_id,
+            )
+            session_log_map[session_id] = [
+                CloudWatchLogEvent.from_api_response(le) for le in get_log_events_response["events"]
+            ]
+
+        return session_log_map
+
     def refresh_job_info(self, *, client: DeadlineClient) -> None:
         """
         Calls GetJob API to refresh job information. The result is used to update the fields
@@ -458,6 +514,7 @@ def update(
         target_task_run_status: str | None = None,
         max_failed_tasks_count: int | None = None,
         max_retries_per_task: int | None = None,
+        raw_kwargs: dict | None = None,
     ) -> None:
         kwargs = clean_kwargs(
             {
@@ -465,6 +522,7 @@ def update(
                 "targetTaskRunStatus": target_task_run_status,
                 "maxFailedTasksCount": max_failed_tasks_count,
                 "maxRetriesPerTask": max_retries_per_task,
+                **(raw_kwargs or {}),
             }
         )
         call_api(
@@ -553,3 +611,18 @@ def __str__(self) -> str:  # pragma: no cover
             f"ended_at: {self.ended_at}",
         ]
     )
+
+
+@dataclass
+class CloudWatchLogEvent:
+    ingestion_time: int
+    message: str
+    timestamp: int
+
+    @staticmethod
+    def from_api_response(response: dict) -> CloudWatchLogEvent:
+        return CloudWatchLogEvent(
+            ingestion_time=response["ingestionTime"],
+            message=response["message"],
+            timestamp=response["timestamp"],
+        )
diff --git a/src/deadline_test_fixtures/deadline/worker.py b/src/deadline_test_fixtures/deadline/worker.py
index aa67b2e..518335e 100644
--- a/src/deadline_test_fixtures/deadline/worker.py
+++ b/src/deadline_test_fixtures/deadline/worker.py
@@ -33,6 +33,7 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str:  #
     """Get the command to configure the Worker. This must be run as root."""
     cmds = [
         config.worker_agent_install.install_command,
+        *(config.pre_install_commands or []),
         # fmt: off
         (
             "install-deadline-worker "
@@ -44,6 +45,7 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str:  #
             + f"--group {config.group} "
             + f"{'--allow-shutdown ' if config.allow_shutdown else ''}"
             + f"{'--no-install-service ' if config.no_install_service else ''}"
+            + f"{'--start ' if config.start_service else ''}"
         ),
         # fmt: on
     ]
@@ -117,10 +119,13 @@ class DeadlineWorkerConfiguration:
     group: str
     allow_shutdown: bool
     worker_agent_install: PipInstall
+    start_service: bool = False
     no_install_service: bool = False
     service_model: ServiceModel | None = None
     file_mappings: list[tuple[str, str]] | None = None
     """Mapping of files to copy from host environment to worker environment"""
+    pre_install_commands: list[str] | None = None
+    """Commands to run before installing the Worker agent"""
 
 
 @dataclass
@@ -389,6 +394,8 @@ def start(self) -> None:
         # Environment variables for "run_container.sh"
         run_container_env = {
             **os.environ,
+            "FARM_ID": self.configuration.farm_id,
+            "FLEET_ID": self.configuration.fleet_id,
             "AGENT_USER": self.configuration.user,
             "SHARED_GROUP": self.configuration.group,
             "JOB_USER": "jobuser",
diff --git a/src/deadline_test_fixtures/fixtures.py b/src/deadline_test_fixtures/fixtures.py
index a08b6d7..6ed79df 100644
--- a/src/deadline_test_fixtures/fixtures.py
+++ b/src/deadline_test_fixtures/fixtures.py
@@ -327,7 +327,7 @@ def deadline_resources(
     farm.delete(client=deadline_client)
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="class")
 def worker_config(
     deadline_resources: DeadlineResources,
     codeartifact: CodeArtifactRepositoryInfo,
@@ -402,7 +402,7 @@ def worker_config(
     )
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="class")
 def worker(
     request: pytest.FixtureRequest,
     deadline_client: DeadlineClient,
@@ -427,7 +427,7 @@ def worker(
     """
     worker: DeadlineWorker
 
-    if os.environ.get("USE_DOCKER_WORKER", False):
+    if os.environ.get("USE_DOCKER_WORKER", "").lower() == "true":
         LOG.info("Creating Docker worker")
         worker = DockerContainerWorker(
             configuration=worker_config,
diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py
index 4e32a3a..a372bdf 100644
--- a/test/unit/deadline/test_resources.py
+++ b/test/unit/deadline/test_resources.py
@@ -9,6 +9,7 @@
 import pytest
 
 from deadline_test_fixtures import (
+    CloudWatchLogEvent,
     Farm,
     Queue,
     Fleet,
@@ -640,3 +641,66 @@ def test_wait_until_complete(self, job: Job) -> None:
         # THEN
         assert mock_client.get_job.call_count == 2
         assert job.task_run_status == "FAILED"
+
+    def test_get_logs(self, job: Job) -> None:
+        # GIVEN
+        mock_deadline_client = MagicMock()
+        mock_deadline_client.list_sessions.return_value = {
+            "sessions": [
+                {"sessionId": "session-1"},
+                {"sessionId": "session-2"},
+            ],
+        }
+        mock_logs_client = MagicMock()
+        log_events = [
+            {
+                "events": [
+                    {
+                        "ingestionTime": 123,
+                        "timestamp": 321,
+                        "message": "test",
+                    }
+                ],
+            },
+            {
+                "events": [
+                    {
+                        "ingestionTime": 123123,
+                        "timestamp": 321321,
+                        "message": "testtest",
+                    }
+                ],
+            },
+        ]
+        mock_logs_client.get_log_events.side_effect = log_events
+
+        # WHEN
+        session_log_map = job.get_logs(
+            deadline_client=mock_deadline_client,
+            logs_client=mock_logs_client,
+        )
+
+        # THEN
+        mock_deadline_client.list_sessions.assert_called_once_with(
+            farmId=job.farm.id,
+            queueId=job.queue.id,
+            jobId=job.id,
+        )
+        mock_logs_client.get_log_events.assert_has_calls(
+            [
+                call(
logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}", + logStreamName=session_id, + ) + for session_id in ["session-1", "session-2"] + ] + ) + + assert "session-1" in session_log_map + assert session_log_map["session-1"] == [ + CloudWatchLogEvent.from_api_response(le) for le in log_events[0]["events"] + ] + assert "session-2" in session_log_map + assert session_log_map["session-2"] == [ + CloudWatchLogEvent.from_api_response(le) for le in log_events[1]["events"] + ]