diff --git a/src/deadline_test_fixtures/deadline/resources.py b/src/deadline_test_fixtures/deadline/resources.py index 6210709..2978d4f 100644 --- a/src/deadline_test_fixtures/deadline/resources.py +++ b/src/deadline_test_fixtures/deadline/resources.py @@ -6,19 +6,20 @@ import logging import re import time +from collections.abc import Generator from dataclasses import asdict, dataclass, fields from datetime import timedelta from enum import Enum -from typing import Any, Callable, Generator, Literal, TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union from botocore.client import BaseClient -from .client import DeadlineClient from ..models import JobAttachmentSettings, JobRunAsUser from ..util import call_api, clean_kwargs, wait_for +from .client import DeadlineClient if TYPE_CHECKING: - from botocore.paginate import Paginator, PageIterator + from botocore.paginate import PageIterator, Paginator LOG = logging.getLogger(__name__) @@ -707,6 +708,27 @@ def list_steps( ), ) + def get_only_task( + self, + *, + deadline_client: DeadlineClient, + ) -> Task: + """ + Asserts that the job has a single step and a single task, and returns the task. + + Args: + deadline_client (deadline_test_fixtures.client.DeadlineClient): Deadline boto client + Return: + task: The single task of the job + """ + # Assert there is a single step and task + steps = list(self.list_steps(deadline_client=deadline_client)) + assert len(steps) == 1, "Job contains multiple steps" + step = steps[0] + tasks = list(step.list_tasks(deadline_client=deadline_client)) + assert len(tasks) == 1, "Job contains multiple tasks" + return tasks[0] + def assert_single_task_log_contains( self, *, @@ -754,13 +776,7 @@ def assert_single_task_log_contains( if isinstance(expected_pattern, str): expected_pattern = re.compile(expected_pattern) - # Assert there is a single step and task - steps = list(self.list_steps(deadline_client=deadline_client)) - assert len(steps) == 1, "Job contains multiple steps" - step = steps[0] - tasks = list(step.list_tasks(deadline_client=deadline_client)) - assert len(tasks) == 1, "Job contains multiple tasks" - task = tasks[0] + task = self.get_only_task(deadline_client=deadline_client) session = task.get_last_session(deadline_client=deadline_client) session.assert_log_contains( @@ -771,6 +787,54 @@ def assert_single_task_log_contains( retries=retries, ) + def assert_single_task_log_does_not_contain( + self, + *, + deadline_client: DeadlineClient, + logs_client: BaseClient, + expected_pattern: re.Pattern | str, + assert_fail_msg: str = "Expected message found in session log", + consistency_wait_time: timedelta = timedelta(seconds=3), + ) -> None: + """ + Asserts that the expected regular expression pattern doesn't exist in the job's session log. + + This method is intended for jobs with a single step and task. It checks the logs of the + last run session for the single task. + + The method accounts for the eventual-consistency of CloudWatch log delivery and availability + through CloudWatch APIs by a configurable wait time. The method does an initial immediate + check, then waits for the configured consistency wait time before checking again, if the + wait time is greater than zero. If neither check (or one check if wait time is zero) matches + the expected pattern then the log is assumed to not contain the given line. + + Args: + deadline_client (deadline_test_fixtures.client.DeadlineClient): Deadline boto client + logs_client (botocore.clients.BaseClient): CloudWatch logs boto client + expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a + pre-compiled regular expression Pattern object. This is pattern is searched against + each of the job's session logs, contatenated as a multi-line string joined by + a single newline character (\\n). + assert_fail_msg (str): The assertion message to raise if the pattern is found + The CloudWatch log group name is appended to the end of this message to assist + with diagnosis. The default is "Expected message found in session log". + consistency_wait_time (datetime.timedelta): Wait time between first and second check. + Default is 3s, wait times opperates in second increments. + """ + # Coerce Regex str patterns to a re.Pattern + if isinstance(expected_pattern, str): + expected_pattern = re.compile(expected_pattern) + + task = self.get_only_task(deadline_client=deadline_client) + + session = task.get_last_session(deadline_client=deadline_client) + session.assert_log_does_not_contain( + logs_client=logs_client, + expected_pattern=expected_pattern, + assert_fail_msg=assert_fail_msg, + consistency_wait_time=consistency_wait_time, + ) + @property def complete(self) -> bool: # pragma: no cover return self.task_run_status in COMPLETE_TASK_STATUSES @@ -1167,6 +1231,61 @@ def assert_log_contains( else: return + def assert_log_does_not_contain( + self, + *, + logs_client: BaseClient, + expected_pattern: re.Pattern | str, + assert_fail_msg: str = "Expected message found in session log", + consistency_wait_time: timedelta = timedelta(seconds=4.5), + ) -> None: + """ + Asserts that the expected regular expression pattern does not exist in the job's session log. + + This method accounts for the eventual-consistency of CloudWatch log delivery and availability + through CloudWatch APIs by a configurable wait time. The method does an initial immediate + check, then waits for the configured consistency wait time before checking again, if the wait + time is greater than zero. If neither check (or one check if wait time is zero) matches the + expected pattern then the log is assumed to not contain the given line. + + Args: + logs_client (botocore.clients.BaseClient): CloudWatch logs boto client + expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a + pre-compiled regular expression Pattern object. This is pattern is searched against + each of the job's session logs, contatenated as a multi-line string joined by + a single newline character (\\n). + assert_fail_msg (str): The assertion message to raise if the pattern is found + The CloudWatch log group name is appended to the end of this message to assist + with diagnosis. The default is "Expected message found in session log". + consistency_wait_time (datetime.timedelta): Wait time between first and second check. + Default is 4.5 seconds. + """ + # Coerce Regex str patterns to a re.Pattern + if isinstance(expected_pattern, str): + expected_pattern = re.compile(expected_pattern) + + if not (session_log_config_options := self.logs.options): + raise ValueError('No "options" key in session "log" API response') + if not (log_group_name := session_log_config_options.get("logGroupName", None)): + raise ValueError('No "logGroupName" key in session "log" -> "options" API response') + + session_log = self.get_session_log(logs_client=logs_client) + session_log.assert_pattern_not_in_log( + expected_pattern=expected_pattern, + failure_msg=f"{assert_fail_msg}. Logs are in CloudWatch log group: {log_group_name}", + ) + if consistency_wait_time.total_seconds() > 0: + time.sleep(consistency_wait_time.total_seconds()) + session_log = self.get_session_log(logs_client=logs_client) + session_log.assert_pattern_not_in_log( + expected_pattern=expected_pattern, + failure_msg=f"{assert_fail_msg}. Logs are in CloudWatch log group: {log_group_name}", + ) + else: + LOG.warning( + "Expected pattern only checked for once. To check twice use non-zero consistency_wait_time" + ) + @dataclass class SessionLog: @@ -1202,6 +1321,36 @@ def assert_pattern_in_log( full_session_log = "\n".join(le.message for le in self.logs) assert expected_pattern.search(full_session_log), failure_msg + def assert_pattern_not_in_log( + self, + *, + expected_pattern: re.Pattern | str, + failure_msg: str, + ) -> None: + """ + Asserts that a pattern is not found in the session log + + Args: + expected_pattern (re.Pattern | str): Either a regular expression pattern string, or a + pre-compiled regular expression Pattern object. This is pattern is searched against + each of the job's session logs, contatenated as a multi-line string joined by + a single newline character (\\n). + failure_msg (str): A message to be raised in an AssertionError if the expected pattern + is found. + + Raises: + AssertionError + Raised when the expected pattern is found in the session log. The argument to + the AssertionError is the value of the failure_msg argument + """ + # Coerce Regex str patterns to a re.Pattern + if isinstance(expected_pattern, str): + expected_pattern = re.compile(expected_pattern) + + full_session_log = "\n".join(le.message for le in self.logs) + + assert True if expected_pattern.search(full_session_log) is None else False, failure_msg + @dataclass class CloudWatchLogEvent: diff --git a/test/unit/deadline/test_resources.py b/test/unit/deadline/test_resources.py index 297d177..289dbd9 100644 --- a/test/unit/deadline/test_resources.py +++ b/test/unit/deadline/test_resources.py @@ -5,8 +5,10 @@ import datetime import json import re +from collections.abc import Generator from dataclasses import asdict, replace -from typing import Any, Generator, cast +from datetime import timedelta +from typing import Any, cast from unittest.mock import MagicMock, call, patch import pytest @@ -14,11 +16,11 @@ from deadline_test_fixtures import ( CloudWatchLogEvent, Farm, - Queue, Fleet, - QueueFleetAssociation, Job, JobAttachmentSettings, + Queue, + QueueFleetAssociation, Session, Step, Task, @@ -31,6 +33,7 @@ @pytest.fixture(autouse=True) def wait_for_shim() -> Generator[None, None, None]: import sys + from deadline_test_fixtures.util import wait_for # Force the wait_for to have a short interval for unit tests @@ -358,6 +361,26 @@ def session( ) +@pytest.fixture(scope="function") +def log_client(): + def configure_log_client(session: Session, log_messages: list[str]): + logs_client = MagicMock() + logs = mod.SessionLog( + session_id=session.id, + logs=[ + mod.CloudWatchLogEvent( + ingestion_time=i, + message=message, + timestamp=i, + ) + for i, message in enumerate(log_messages) + ], + ) + return logs_client, logs + + return configure_log_client + + class TestFarm: def test_create(self) -> None: # GIVEN @@ -1025,6 +1048,51 @@ def test_get_logs(self, job: Job) -> None: CloudWatchLogEvent.from_api_response(le) for le in log_events[1]["events"] ] + def test_get_only_task_fail_on_multi_task(self, job: Job) -> None: + # GIVEN + deadline_client = MagicMock() + step = MagicMock() + task = MagicMock() + step.list_tasks.return_value = [task, task] + task.get_last_session.return_value = session + + with (patch.object(job, "list_steps", return_value=[step]) as mock_list_steps,): + + # WHEN + def when(): + job.get_only_task(deadline_client=deadline_client) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + + print(raise_ctx.value) + + assert raise_ctx.match("Job contains multiple tasks") + mock_list_steps.assert_called_once_with(deadline_client=deadline_client) + step.list_tasks.assert_called_once_with(deadline_client=deadline_client) + task.get_last_session.assert_not_called() + + def test_get_only_task_fail_on_multi_step(self, job: Job) -> None: + # GIVEN + deadline_client = MagicMock() + step = MagicMock() + + with (patch.object(job, "list_steps", return_value=[step, step]) as mock_list_steps,): + # WHEN + def when(): + job.get_only_task(deadline_client=deadline_client) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + + print(raise_ctx.value) + + assert raise_ctx.match("Job contains multiple steps") + mock_list_steps.assert_called_once_with(deadline_client=deadline_client) + step.list_tasks.assert_not_called() + def test_assert_single_task_log_contains_success(self, job: Job, session: Session) -> None: # GIVEN deadline_client = MagicMock() @@ -1061,63 +1129,44 @@ def test_assert_single_task_log_contains_success(self, job: Job, session: Sessio step.list_tasks.assert_called_once_with(deadline_client=deadline_client) task.get_last_session.assert_called_once_with(deadline_client=deadline_client) - def test_assert_single_task_log_contains_multi_step(self, job: Job) -> None: - # GIVEN - deadline_client = MagicMock() - logs_client = MagicMock() - step = MagicMock() - expected_pattern = re.compile(r"a message") - - with (patch.object(job, "list_steps", return_value=[step, step]) as mock_list_steps,): - - # WHEN - def when(): - job.assert_single_task_log_contains( - deadline_client=deadline_client, - logs_client=logs_client, - expected_pattern=expected_pattern, - ) - - # THEN - with pytest.raises(AssertionError) as raise_ctx: - when() - - print(raise_ctx.value) - - assert raise_ctx.match("Job contains multiple steps") - mock_list_steps.assert_called_once_with(deadline_client=deadline_client) - step.list_tasks.assert_not_called() - - def test_assert_single_task_log_contains_multi_task(self, job: Job, session: Session) -> None: + def test_assert_single_task_log_does_not_contain_success( + self, job: Job, session: Session + ) -> None: # GIVEN deadline_client = MagicMock() logs_client = MagicMock() step = MagicMock() task = MagicMock() - step.list_tasks.return_value = [task, task] + step.list_tasks.return_value = [task] task.get_last_session.return_value = session expected_pattern = re.compile(r"a message") - with (patch.object(job, "list_steps", return_value=[step]) as mock_list_steps,): + with ( + patch.object(job, "list_steps", return_value=[step]) as mock_list_steps, + patch.object( + session, "assert_log_does_not_contain" + ) as mock_session_assert_log_does_not_contain, + ): # WHEN - def when(): - job.assert_single_task_log_contains( - deadline_client=deadline_client, - logs_client=logs_client, - expected_pattern=expected_pattern, - ) - - # THEN - with pytest.raises(AssertionError) as raise_ctx: - when() - - print(raise_ctx.value) + job.assert_single_task_log_does_not_contain( + deadline_client=deadline_client, + logs_client=logs_client, + expected_pattern=expected_pattern, + ) - assert raise_ctx.match("Job contains multiple tasks") + # THEN + # This test is only to confirm that no assertion is raised, since the expected message + # is in the logs + mock_session_assert_log_does_not_contain.assert_called_once_with( + logs_client=logs_client, + expected_pattern=expected_pattern, + assert_fail_msg="Expected message found in session log", + consistency_wait_time=timedelta(seconds=3), + ) mock_list_steps.assert_called_once_with(deadline_client=deadline_client) step.list_tasks.assert_called_once_with(deadline_client=deadline_client) - task.get_last_session.assert_not_called() + task.get_last_session.assert_called_once_with(deadline_client=deadline_client) def test_list_steps( self, @@ -1376,13 +1425,14 @@ def test_get_last_session( class TestSession: - @pytest.mark.parametrize( - argnames=("expected_pattern", "log_messages"), - argvalues=( + + base_assertion_args = ( + ("expected_pattern", "log_messages"), + ( pytest.param("PATTERN", ["PATTERN"], id="exact-match"), pytest.param("PATTERN", ["PATTERN at beginning"], id="match-beginning"), pytest.param("PATTERN", ["ends with PATTERN"], id="match-end"), - pytest.param("PATTERN", ["multiline with", "the PATTERN"], id="match-end"), + pytest.param("PATTERN", ["multiline with", "the PATTERN"], id="multi-line"), pytest.param( re.compile(r"This is\na multiline pattern", re.MULTILINE), ["extra lines", "This is", "a multiline pattern", "embedded"], @@ -1395,25 +1445,17 @@ class TestSession: ), ), ) + + @pytest.mark.parametrize(*base_assertion_args) def test_assert_logs_success( self, session: Session, expected_pattern: str | re.Pattern, log_messages: list[str], + log_client, ) -> None: # GIVEN - logs_client = MagicMock() - logs = mod.SessionLog( - session_id=session.id, - logs=[ - mod.CloudWatchLogEvent( - ingestion_time=i, - message=message, - timestamp=i, - ) - for i, message in enumerate(log_messages) - ], - ) + logs_client, logs = log_client(session, log_messages) with ( patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log, @@ -1432,6 +1474,113 @@ def test_assert_logs_success( mock_get_session_log.assert_called_once_with(logs_client=logs_client) mock_time_sleep.assert_not_called() + @pytest.mark.parametrize(*base_assertion_args) + def test_assert_logs_does_not_contain_fail( + self, + session: Session, + expected_pattern: str | re.Pattern, + log_messages: list[str], + log_client, + ) -> None: + # GIVEN + logs_client, logs = log_client(session, log_messages) + expected_assertion_msg = ( + "Expected message found in session log." + " Logs are in CloudWatch log group: sessionLogGroup" + ) + with ( + patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log, + # Speed up tests + patch.object(mod.time, "sleep") as mock_time_sleep, + ): + + # WHEN + def when(): + session.assert_log_does_not_contain( + logs_client=logs_client, + expected_pattern=expected_pattern, + ) + + # THEN + with pytest.raises(AssertionError) as raise_ctx: + when() + assert raise_ctx.value.args[0] == expected_assertion_msg + mock_get_session_log.assert_called_once_with(logs_client=logs_client) + mock_time_sleep.assert_not_called() + + @pytest.mark.parametrize( + argnames=("expected_pattern", "log_messages"), + argvalues=( + pytest.param("DOES NOT MATCH", ["PATTERN"], id="no-match"), + pytest.param( + re.compile(r"There is\nno match", re.MULTILINE), + ["extra lines", "This is", "a multiline pattern", "embedded"], + id="multi-line-no-match", + ), + ), + ) + def test_assert_logs_does_not_contain_success( + self, + session: Session, + expected_pattern: str | re.Pattern, + log_messages: list[str], + log_client, + ) -> None: + # GIVEN + logs_client, logs = log_client(session, log_messages) + + with ( + patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log, + # Speed up tests + patch.object(mod.time, "sleep") as mock_time_sleep, + ): + + # WHEN + session.assert_log_does_not_contain( + logs_client=logs_client, + expected_pattern=expected_pattern, + ) + + # THEN + # (no exception is raised) + mock_get_session_log.assert_has_calls([call(logs_client=logs_client)] * 2) + mock_time_sleep.assert_called_once_with(4.5) # 4.5 is default sleep duration + + @pytest.mark.parametrize( + argnames=("sleep_duration"), + argvalues=( + pytest.param(timedelta(seconds=9), id="custom-duration"), + pytest.param(timedelta(seconds=0), id="no-sleep"), + ), + ) + def test_assert_logs_does_not_contain_sleeps( + self, session: Session, sleep_duration: timedelta, log_client + ) -> None: + # GIVEN + logs_client, logs = log_client(session, ["PATTERN"]) + + with ( + patch.object(session, "get_session_log", return_value=logs) as mock_get_session_log, + # Speed up tests + patch.object(mod.time, "sleep") as mock_time_sleep, + ): + + # WHEN + session.assert_log_does_not_contain( + logs_client=logs_client, + expected_pattern="DOES NOT MATCH", + consistency_wait_time=sleep_duration, + ) + + # THEN + # (no exception is raised) + if sleep_duration.total_seconds() > 0: + mock_time_sleep.assert_called_once_with(sleep_duration.total_seconds()) + mock_get_session_log.assert_has_calls([call(logs_client=logs_client)] * 2) + else: + mock_get_session_log.assert_called_once_with(logs_client=logs_client) + mock_time_sleep.assert_not_called() + @pytest.mark.parametrize( argnames="assert_fail_msg", argvalues=(