Skip to content

Commit

Permalink
feat: add Worker pre-install commands, --start, and Job.get_logs (aws…
Browse files Browse the repository at this point in the history
…-deadline#25)

* feat: add Worker pre-install commands, --start, and Job.get_logs

Signed-off-by: Jericho Tolentino <[email protected]>
  • Loading branch information
jericht authored Oct 12, 2023
1 parent 74e9b1c commit ddfb51c
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 9 deletions.
4 changes: 4 additions & 0 deletions src/deadline_test_fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
from .deadline import (
CloudWatchLogEvent,
CommandResult,
DeadlineClient,
DeadlineWorker,
Expand All @@ -10,6 +11,7 @@
Farm,
Fleet,
PipInstall,
PosixUser,
Queue,
QueueFleetAssociation,
TaskStatus,
Expand All @@ -34,6 +36,7 @@

__all__ = [
"BootstrapResources",
"CloudWatchLogEvent",
"CodeArtifactRepositoryInfo",
"CommandResult",
"DeadlineResources",
Expand All @@ -51,6 +54,7 @@
"JobAttachmentSettings",
"JobAttachmentManager",
"PipInstall",
"PosixUser",
"S3Object",
"ServiceModel",
"StubDeadlineClient",
Expand Down
4 changes: 4 additions & 0 deletions src/deadline_test_fixtures/deadline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from .resources import (
CloudWatchLogEvent,
Farm,
Fleet,
Job,
Expand All @@ -16,9 +17,11 @@
DockerContainerWorker,
EC2InstanceWorker,
PipInstall,
PosixUser,
)

__all__ = [
"CloudWatchLogEvent",
"CommandResult",
"DeadlineClient",
"DeadlineWorker",
Expand All @@ -29,6 +32,7 @@
"Fleet",
"Job",
"PipInstall",
"PosixUser",
"Queue",
"QueueFleetAssociation",
"TaskStatus",
Expand Down
118 changes: 112 additions & 6 deletions src/deadline_test_fixtures/deadline/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
import logging
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Callable, Literal
from typing import Any, Callable, Literal, TYPE_CHECKING

from botocore.client import BaseClient

from .client import DeadlineClient
from ..models import JobAttachmentSettings
from ..util import call_api, clean_kwargs, wait_for

if TYPE_CHECKING:
from botocore.paginate import Paginator, PageIterator

LOG = logging.getLogger(__name__)


Expand All @@ -24,21 +29,23 @@ def create(
*,
client: DeadlineClient,
display_name: str,
raw_kwargs: dict | None = None,
) -> Farm:
response = call_api(
description=f"Create farm {display_name}",
fn=lambda: client.create_farm(
displayName=display_name,
**(raw_kwargs or {}),
),
)
farm_id = response["farmId"]
LOG.info(f"Created farm: {farm_id}")
return Farm(id=farm_id)

def delete(self, *, client: DeadlineClient) -> None:
def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
call_api(
description=f"Delete farm {self.id}",
fn=lambda: client.delete_farm(farmId=self.id),
fn=lambda: client.delete_farm(farmId=self.id, **(raw_kwargs or {})),
)


Expand All @@ -55,6 +62,7 @@ def create(
farm: Farm,
role_arn: str | None = None,
job_attachments: JobAttachmentSettings | None = None,
raw_kwargs: dict | None = None,
) -> Queue:
kwargs = clean_kwargs(
{
Expand All @@ -64,6 +72,7 @@ def create(
"jobAttachmentSettings": (
job_attachments.as_queue_settings() if job_attachments else None
),
**(raw_kwargs or {}),
}
)

Expand All @@ -79,10 +88,12 @@ def create(
farm=farm,
)

def delete(self, *, client: DeadlineClient) -> None:
def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
call_api(
description=f"Delete queue {self.id}",
fn=lambda: client.delete_queue(queueId=self.id, farmId=self.farm.id),
fn=lambda: client.delete_queue(
queueId=self.id, farmId=self.farm.id, **(raw_kwargs or {})
),
)


Expand All @@ -99,13 +110,15 @@ def create(
farm: Farm,
configuration: dict,
role_arn: str | None = None,
raw_kwargs: dict | None = None,
) -> Fleet:
kwargs = clean_kwargs(
{
"farmId": farm.id,
"displayName": display_name,
"roleArn": role_arn,
"configuration": configuration,
**(raw_kwargs or {}),
}
)
response = call_api(
Expand All @@ -127,12 +140,13 @@ def create(

return fleet

def delete(self, *, client: DeadlineClient) -> None:
def delete(self, *, client: DeadlineClient, raw_kwargs: dict | None = None) -> None:
call_api(
description=f"Delete fleet {self.id}",
fn=lambda: client.delete_fleet(
farmId=self.farm.id,
fleetId=self.id,
**(raw_kwargs or {}),
),
)

Expand Down Expand Up @@ -184,13 +198,15 @@ def create(
farm: Farm,
queue: Queue,
fleet: Fleet,
raw_kwargs: dict | None = None,
) -> QueueFleetAssociation:
call_api(
description=f"Create queue-fleet association for queue {queue.id} and fleet {fleet.id} in farm {farm.id}",
fn=lambda: client.create_queue_fleet_association(
farmId=farm.id,
queueId=queue.id,
fleetId=fleet.id,
**(raw_kwargs or {}),
),
)
return QueueFleetAssociation(
Expand All @@ -206,6 +222,7 @@ def delete(
stop_mode: Literal[
"STOP_SCHEDULING_AND_CANCEL_TASKS", "STOP_SCHEDULING_AND_FINISH_TASKS"
] = "STOP_SCHEDULING_AND_CANCEL_TASKS",
raw_kwargs: dict | None = None,
) -> None:
self.stop(client=client, stop_mode=stop_mode)
call_api(
Expand All @@ -214,6 +231,7 @@ def delete(
farmId=self.farm.id,
queueId=self.queue.id,
fleetId=self.fleet.id,
**(raw_kwargs or {}),
),
)

Expand Down Expand Up @@ -336,6 +354,7 @@ def submit(
target_task_run_status: str | None = None,
max_failed_tasks_count: int | None = None,
max_retries_per_task: int | None = None,
raw_kwargs: dict | None = None,
) -> Job:
kwargs = clean_kwargs(
{
Expand All @@ -349,6 +368,7 @@ def submit(
"targetTaskRunStatus": target_task_run_status,
"maxFailedTasksCount": max_failed_tasks_count,
"maxRetriesPerTask": max_retries_per_task,
**(raw_kwargs or {}),
}
)
create_job_response = call_api(
Expand Down Expand Up @@ -379,6 +399,7 @@ def get_job_details(
farm: Farm,
queue: Queue,
job_id: str,
raw_kwargs: dict | None = None,
) -> dict[str, Any]:
"""
Calls GetJob API and returns the parsed response, which can be used as
Expand All @@ -390,6 +411,7 @@ def get_job_details(
farmId=farm.id,
queueId=queue.id,
jobId=job_id,
**(raw_kwargs or {}),
),
)

Expand Down Expand Up @@ -435,6 +457,67 @@ def get_optional_field(
"description": get_optional_field("description"),
}

def get_logs(
    self,
    *,
    deadline_client: DeadlineClient,
    logs_client: BaseClient,
) -> JobLogs:
    """
    Retrieve the CloudWatch logs for every session of this Job.

    Args:
        deadline_client (DeadlineClient): Client used to list this job's sessions
        logs_client (BaseClient): boto3 CloudWatch Logs client used to fetch log events

    Returns:
        JobLogs: The log group name and a mapping of session ID to its log events
    """

    def _iter_session_pages():
        # Manual pagination over the ListSessions API for this job.
        page = deadline_client.list_sessions(
            farmId=self.farm.id,
            queueId=self.queue.id,
            jobId=self.id,
        )
        yield page
        while page.get("nextToken"):
            page = deadline_client.list_sessions(
                farmId=self.farm.id,
                queueId=self.queue.id,
                jobId=self.id,
                nextToken=page["nextToken"],
            )
            yield page

    pages = call_api(
        description=f"Listing sessions for job {self.id}",
        fn=_iter_session_pages,
    )
    all_sessions = []
    for page in pages:
        all_sessions.extend(page["sessions"])

    # Session logs live in one log group per farm/queue, one stream per session.
    log_group_name = f"/aws/deadline/{self.farm.id}/{self.queue.id}"
    paginator: Paginator = logs_client.get_paginator("filter_log_events")
    logs_by_session: dict[str, list[CloudWatchLogEvent]] = {}
    for sess in all_sessions:
        session_id = sess["sessionId"]
        # The lambda is invoked immediately by call_api within this iteration,
        # so closing over the loop variable is safe here.
        page_iterator: PageIterator = call_api(
            description=f"Fetching log events for session {session_id} in log group {log_group_name}",
            fn=lambda: paginator.paginate(
                logGroupName=log_group_name,
                logStreamNames=[session_id],
            ),
        )
        events = page_iterator.build_full_result()
        logs_by_session[session_id] = [
            CloudWatchLogEvent.from_api_response(evt) for evt in events["events"]
        ]

    return JobLogs(
        log_group_name=log_group_name,
        logs=logs_by_session,
    )

def refresh_job_info(self, *, client: DeadlineClient) -> None:
"""
Calls GetJob API to refresh job information. The result is used to update the fields
Expand All @@ -459,13 +542,15 @@ def update(
target_task_run_status: str | None = None,
max_failed_tasks_count: int | None = None,
max_retries_per_task: int | None = None,
raw_kwargs: dict | None = None,
) -> None:
kwargs = clean_kwargs(
{
"priority": priority,
"targetTaskRunStatus": target_task_run_status,
"maxFailedTasksCount": max_failed_tasks_count,
"maxRetriesPerTask": max_retries_per_task,
**(raw_kwargs or {}),
}
)
call_api(
Expand Down Expand Up @@ -554,3 +639,24 @@ def __str__(self) -> str: # pragma: no cover
f"ended_at: {self.ended_at}",
]
)


@dataclass
class JobLogs:
    """Container for a Job's CloudWatch logs, grouped by session."""

    # Name of the CloudWatch log group holding this job's session log streams
    log_group_name: str
    # Maps each session ID to the log events fetched from that session's stream
    logs: dict[str, list[CloudWatchLogEvent]]


@dataclass
class CloudWatchLogEvent:
    """A single log event as returned by the CloudWatch Logs FilterLogEvents API."""

    # Time the event was ingested by CloudWatch (epoch milliseconds)
    ingestion_time: int
    # The raw log message text
    message: str
    # Time the event occurred (epoch milliseconds)
    timestamp: int

    @staticmethod
    def from_api_response(response: dict) -> CloudWatchLogEvent:
        """Parse one FilterLogEvents event dict into a CloudWatchLogEvent."""
        ingested = response["ingestionTime"]
        text = response["message"]
        stamp = response["timestamp"]
        return CloudWatchLogEvent(
            ingestion_time=ingested,
            message=text,
            timestamp=stamp,
        )
Loading

0 comments on commit ddfb51c

Please sign in to comment.