diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py index bbe6c0d..d9d7628 100644 --- a/src/deadline_test_fixtures/deadline/resources.py +++ b/src/deadline_test_fixtures/deadline/resources.py @@ -6,7 +6,7 @@ import logging from dataclasses import dataclass, fields from enum import Enum -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, TYPE_CHECKING from botocore.client import BaseClient @@ -14,6 +14,9 @@ from ..models import JobAttachmentSettings from ..util import call_api, clean_kwargs, wait_for +if TYPE_CHECKING: + from botocore.paginate import Paginator, PageIterator + LOG = logging.getLogger(__name__) @@ -470,23 +473,44 @@ def get_logs( Returns: JobLogs: The job logs """ - list_sessions_response = deadline_client.list_sessions( - farmId=self.farm.id, - queueId=self.queue.id, - jobId=self.id, + + def paginate_list_sessions(): + response = deadline_client.list_sessions( + farmId=self.farm.id, + queueId=self.queue.id, + jobId=self.id, + ) + yield response + while response.get("nextToken"): + response = deadline_client.list_sessions( + farmId=self.farm.id, + queueId=self.queue.id, + jobId=self.id, + nextToken=response["nextToken"], + ) + yield response + + list_sessions_pages = call_api( + description=f"Listing sessions for job {self.id}", + fn=paginate_list_sessions, ) - sessions = list_sessions_response["sessions"] + sessions = [s for p in list_sessions_pages for s in p["sessions"]] log_group_name = f"/aws/deadline/{self.farm.id}/{self.queue.id}" + filter_log_events_paginator: Paginator = logs_client.get_paginator("filter_log_events") session_log_map: dict[str, list[CloudWatchLogEvent]] = {} for session in sessions: session_id = session["sessionId"] - get_log_events_response = logs_client.get_log_events( - logGroupName=log_group_name, - logStreamName=session_id, + filter_log_events_pages: PageIterator = call_api( + description=f"Fetching log events for session {session_id} in log group {log_group_name}", + fn=lambda: filter_log_events_paginator.paginate( + logGroupName=log_group_name, + logStreamNames=[session_id], + ), ) + log_events = filter_log_events_pages.build_full_result() session_log_map[session_id] = [ - CloudWatchLogEvent.from_api_response(le) for le in get_log_events_response["events"] + CloudWatchLogEvent.from_api_response(e) for e in log_events["events"] ] return JobLogs( diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py index a260d6a..7bc7cef 100644 --- a/test/unit/deadline/test_resources.py +++ b/test/unit/deadline/test_resources.py @@ -645,14 +645,21 @@ def test_wait_until_complete(self, job: Job) -> None: def test_get_logs(self, job: Job) -> None: # GIVEN mock_deadline_client = MagicMock() - mock_deadline_client.list_sessions.return_value = { - "sessions": [ - {"sessionId": "session-1"}, - {"sessionId": "session-2"}, - ], - } + mock_deadline_client.list_sessions.side_effect = [ + { + "sessions": [ + {"sessionId": "session-1"}, + ], + "nextToken": "1", + }, + { + "sessions": [ + {"sessionId": "session-2"}, + ], + }, + ] mock_logs_client = MagicMock() - log_events = [ + log_events: list = [ { "events": [ { @@ -661,6 +668,7 @@ def test_get_logs(self, job: Job) -> None: "message": "test", } ], + "nextToken": "a", }, { "events": [ @@ -672,7 +680,8 @@ def test_get_logs(self, job: Job) -> None: ], }, ] - mock_logs_client.get_log_events.side_effect = log_events + mock_logs_paginator = mock_logs_client.get_paginator.return_value + mock_logs_paginator.paginate.return_value.build_full_result.side_effect = log_events # WHEN job_logs = job.get_logs( @@ -681,18 +690,34 @@ def test_get_logs(self, job: Job) -> None: ) # THEN - mock_deadline_client.list_sessions.assert_called_once_with( - farmId=job.farm.id, - queueId=job.queue.id, - jobId=job.id, + mock_deadline_client.list_sessions.assert_has_calls( + [ + call( + farmId=job.farm.id, + queueId=job.queue.id, + jobId=job.id, + ), + call( + farmId=job.farm.id, + queueId=job.queue.id, + jobId=job.id, + nextToken="1", + ), + ] ) - mock_logs_client.get_log_events.assert_has_calls( + mock_logs_client.get_paginator.assert_called_once_with("filter_log_events") + mock_logs_paginator.paginate.assert_has_calls( [ call( logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}", - logStreamName=session_id, - ) - for session_id in ["session-1", "session-2"] + logStreamNames=["session-1"], + ), + call().build_full_result(), + call( + logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}", + logStreamNames=["session-2"], + ), + call().build_full_result(), ] )