
Commit

paginate apis
jericht committed Oct 12, 2023
1 parent e57b537 commit 7706c1f
Showing 2 changed files with 75 additions and 26 deletions.
44 changes: 34 additions & 10 deletions src/deadline_test_fixtures/deadline/resources.py
@@ -6,14 +6,17 @@
 import logging
 from dataclasses import dataclass, fields
 from enum import Enum
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, TYPE_CHECKING

 from botocore.client import BaseClient

 from .client import DeadlineClient
 from ..models import JobAttachmentSettings
 from ..util import call_api, clean_kwargs, wait_for

+if TYPE_CHECKING:
+    from botocore.paginate import Paginator, PageIterator
+
 LOG = logging.getLogger(__name__)
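
A note on the typing change in this hunk: the TYPE_CHECKING guard makes botocore.paginate a typing-only import, since Paginator and PageIterator are used solely in annotations and never at runtime. A minimal sketch of the idiom under that assumption (the get_filter_log_events_paginator helper below is illustrative, not part of this repository):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; no runtime dependency on botocore.paginate.
    from botocore.paginate import Paginator


def get_filter_log_events_paginator(logs_client) -> "Paginator":
    # The string annotation keeps the guarded name from being evaluated at runtime.
    return logs_client.get_paginator("filter_log_events")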


@@ -470,23 +473,44 @@ def get_logs(
         Returns:
             JobLogs: The job logs
         """
-        list_sessions_response = deadline_client.list_sessions(
-            farmId=self.farm.id,
-            queueId=self.queue.id,
-            jobId=self.id,
+
+        def paginate_list_sessions():
+            response = deadline_client.list_sessions(
+                farmId=self.farm.id,
+                queueId=self.queue.id,
+                jobId=self.id,
+            )
+            yield response
+            while response.get("nextToken"):
+                response = deadline_client.list_sessions(
+                    farmId=self.farm.id,
+                    queueId=self.queue.id,
+                    jobId=self.id,
+                    nextToken=response["nextToken"],
+                )
+                yield response
+
+        list_sessions_pages = call_api(
+            description=f"Listing sessions for job {self.id}",
+            fn=paginate_list_sessions,
         )
-        sessions = list_sessions_response["sessions"]
+        sessions = [s for p in list_sessions_pages for s in p["sessions"]]

         log_group_name = f"/aws/deadline/{self.farm.id}/{self.queue.id}"
+        filter_log_events_paginator: Paginator = logs_client.get_paginator("filter_log_events")
         session_log_map: dict[str, list[CloudWatchLogEvent]] = {}
         for session in sessions:
             session_id = session["sessionId"]
-            get_log_events_response = logs_client.get_log_events(
-                logGroupName=log_group_name,
-                logStreamName=session_id,
+            filter_log_events_pages: PageIterator = call_api(
+                description=f"Fetching log events for session {session_id} in log group {log_group_name}",
+                fn=lambda: filter_log_events_paginator.paginate(
+                    logGroupName=log_group_name,
+                    logStreamNames=[session_id],
+                ),
             )
+            log_events = filter_log_events_pages.build_full_result()
             session_log_map[session_id] = [
-                CloudWatchLogEvent.from_api_response(le) for le in get_log_events_response["events"]
+                CloudWatchLogEvent.from_api_response(e) for e in log_events["events"]
             ]

         return JobLogs(
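The resources.py change swaps two single-shot calls for paginated ones: list_sessions is paginated by hand on nextToken inside a generator that is driven through the repository's call_api helper, and CloudWatch Logs events are read via botocore's filter_log_events paginator with build_full_result() instead of a single get_log_events call. For reference, a minimal standalone sketch of the nextToken pattern the new generator follows (the paginate helper and list_fn parameter below are hypothetical names, not part of this commit):

from typing import Any, Callable, Iterator


def paginate(list_fn: Callable[..., dict], **kwargs: Any) -> Iterator[dict]:
    """Yield every page from an AWS-style list API that returns an optional nextToken."""
    response = list_fn(**kwargs)
    yield response
    while response.get("nextToken"):
        # Feed the previous page's token back in to request the next page.
        response = list_fn(**kwargs, nextToken=response["nextToken"])
        yield response


# Usage mirroring the new code above: flatten "sessions" across all pages.
# sessions = [
#     s
#     for page in paginate(client.list_sessions, farmId=farm_id, queueId=queue_id, jobId=job_id)
#     for s in page["sessions"]
# ]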
57 changes: 41 additions & 16 deletions test/unit/deadline/test_resources.py
@@ -645,14 +645,21 @@ def test_wait_until_complete(self, job: Job) -> None:
     def test_get_logs(self, job: Job) -> None:
         # GIVEN
         mock_deadline_client = MagicMock()
-        mock_deadline_client.list_sessions.return_value = {
-            "sessions": [
-                {"sessionId": "session-1"},
-                {"sessionId": "session-2"},
-            ],
-        }
+        mock_deadline_client.list_sessions.side_effect = [
+            {
+                "sessions": [
+                    {"sessionId": "session-1"},
+                ],
+                "nextToken": "1",
+            },
+            {
+                "sessions": [
+                    {"sessionId": "session-2"},
+                ],
+            },
+        ]
         mock_logs_client = MagicMock()
-        log_events = [
+        log_events: list = [
             {
                 "events": [
                     {
@@ -661,6 +668,7 @@ def test_get_logs(self, job: Job) -> None:
                         "message": "test",
                     }
                 ],
+                "nextToken": "a",
             },
             {
                 "events": [
@@ -672,7 +680,8 @@
                 ],
             },
         ]
-        mock_logs_client.get_log_events.side_effect = log_events
+        mock_logs_paginator = mock_logs_client.get_paginator.return_value
+        mock_logs_paginator.paginate.return_value.build_full_result.side_effect = log_events

         # WHEN
         job_logs = job.get_logs(
@@ -681,18 +690,34 @@
         )

         # THEN
-        mock_deadline_client.list_sessions.assert_called_once_with(
-            farmId=job.farm.id,
-            queueId=job.queue.id,
-            jobId=job.id,
+        mock_deadline_client.list_sessions.assert_has_calls(
+            [
+                call(
+                    farmId=job.farm.id,
+                    queueId=job.queue.id,
+                    jobId=job.id,
+                ),
+                call(
+                    farmId=job.farm.id,
+                    queueId=job.queue.id,
+                    jobId=job.id,
+                    nextToken="1",
+                ),
+            ]
         )
-        mock_logs_client.get_log_events.assert_has_calls(
+        mock_logs_client.get_paginator.assert_called_once_with("filter_log_events")
+        mock_logs_paginator.paginate.assert_has_calls(
             [
                 call(
                     logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}",
-                    logStreamName=session_id,
-                )
-                for session_id in ["session-1", "session-2"]
+                    logStreamNames=["session-1"],
+                ),
+                call().build_full_result(),
+                call(
+                    logGroupName=f"/aws/deadline/{job.farm.id}/{job.queue.id}",
+                    logStreamNames=["session-2"],
+                ),
+                call().build_full_result(),
             ]
         )
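
The updated test never calls AWS; it leans on MagicMock attribute chaining so that get_paginator(...).paginate(...).build_full_result() resolves to the canned pages defined in log_events. A small standalone sketch of that mocking pattern, assuming the code under test uses the same call chain as the diff above (client and stream names here are illustrative):

from unittest.mock import MagicMock

# Assumed call chain in the code under test:
#   paginator = logs_client.get_paginator("filter_log_events")
#   events = paginator.paginate(...).build_full_result()["events"]
mock_logs_client = MagicMock()
mock_paginator = mock_logs_client.get_paginator.return_value
mock_paginator.paginate.return_value.build_full_result.return_value = {
    "events": [{"timestamp": 0, "ingestionTime": 1, "message": "hello"}],
}

# Exercise the chain exactly as the production code would.
paginator = mock_logs_client.get_paginator("filter_log_events")
result = paginator.paginate(
    logGroupName="/aws/example", logStreamNames=["stream-1"]
).build_full_result()

assert result["events"][0]["message"] == "hello"
mock_logs_client.get_paginator.assert_called_once_with("filter_log_events")
# assert_has_calls on mock_paginator.paginate can additionally verify the chained
# call().build_full_result() invocations, which is what the updated test checks.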

