
Commit 9225fbf: fix tests
dstandish committed Jul 18, 2024
1 parent 9bfe20c commit 9225fbf
Showing 2 changed files with 35 additions and 62 deletions.
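Both files make the same two changes: test keys become real TaskInstanceKey objects instead of mock.Mock(spec=tuple), and assertions stop patching send_message_to_task_logs and instead read the executor's _task_event_logs buffer. A minimal sketch of the new assertion style, assuming (as the diff below suggests) that each record in _task_event_logs carries event, task_id, and try_number attributes:

    # Sketch only. The attribute names come from the diff below;
    # `executor` stands in for the mock_executor fixture the tests use.
    def drop_events(executor):
        """Collect (event, task_id, try_number) tuples for every dropped task."""
        return [(x.event, x.task_id, x.try_number) for x in executor._task_event_logs]

    # e.g. after two submissions for task "b" exhaust their retries:
    # assert drop_events(mock_executor) == [("batch job submit failure", "b", 1)] * 2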
tests/providers/amazon/aws/executors/batch/test_batch_executor.py (43 changes: 14 additions, 29 deletions)
@@ -21,7 +21,6 @@
 import logging
 import os
 from unittest import mock
-from unittest.mock import call

 import pytest
 import yaml
@@ -30,6 +29,7 @@
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.providers.amazon.aws.executors.batch import batch_executor, batch_executor_config
 from airflow.providers.amazon.aws.executors.batch.batch_executor import (
     AwsBatchExecutor,
@@ -195,17 +195,16 @@ def test_execute(self, mock_executor):
         mock_executor.batch.submit_job.assert_called_once()
         assert len(mock_executor.active_workers) == 1

-    @mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
     @mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_send_message_to_task_logs, mock_executor):
+    def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor):
         """
         Test how jobs are tried when one job fails, but others pass.

         The expected behaviour is that in one sync() iteration, all the jobs are attempted
         exactly once. Successful jobs are moved from pending_jobs to active_workers, and
         failed jobs are added back to the pending_jobs queue to be run in the next iteration.
         """
-        airflow_key = mock.Mock(spec=tuple)
+        airflow_key = TaskInstanceKey("a", "b", "c", 1, -1)
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
         airflow_commands = [airflow_cmd1, airflow_cmd2]
@@ -260,25 +259,19 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_send_message_to_task
         mock_executor.attempt_submit_jobs()
         submit_job_args["containerOverrides"]["command"] = airflow_commands[0]
         assert mock_executor.batch.submit_job.call_args_list[5].kwargs == submit_job_args
-        mock_send_message_to_task_logs.assert_called_once_with(
-            logging.ERROR,
-            "This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
-            3,
-            "Failure 1",
-            ti=airflow_key,
-        )
+        log_record = mock_executor._task_event_logs[0]
+        assert log_record.event == "batch job submit failure"

-    @mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
     @mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_all_jobs_when_jobs_fail(self, _, mock_send_message_to_task_logs, mock_executor):
+    def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor):
         """
         Test job retry behaviour when jobs fail validation.

         Test that when a job fails with a client-side exception, all the jobs are
         attempted once. If all jobs fail, then the length of pending tasks should not change,
         until all the tasks have been attempted the maximum number of times.
         """
-        airflow_key = mock.Mock(spec=tuple)
+        airflow_key = TaskInstanceKey("a", "b", "c", 1, -1)
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
         commands = [airflow_cmd1, airflow_cmd2]
Expand Down Expand Up @@ -313,18 +306,8 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_send_message_to_task_logs

mock_executor.batch.submit_job.side_effect = failures
mock_executor.attempt_submit_jobs()
calls = []
for i in range(2):
calls.append(
call(
logging.ERROR,
"This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
3,
f"Failure {i + 1}",
ti=airflow_key,
)
)
mock_send_message_to_task_logs.assert_has_calls(calls)
events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs]
assert events == [("batch job submit failure", "b", 1)] * 2

def test_attempt_submit_jobs_failure(self, mock_executor):
mock_executor.batch.submit_job.side_effect = NoCredentialsError()
@@ -465,12 +448,14 @@ def test_sync(self, success_mock, fail_mock, mock_airflow_key, mock_executor):

     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")
-    @mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
     @mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_failed_sync(self, _, _2, success_mock, fail_mock, mock_airflow_key, mock_executor):
+    def test_failed_sync(self, _, success_mock, fail_mock, mock_airflow_key, mock_executor):
         """Test failure states"""
         self._mock_sync(
-            executor=mock_executor, airflow_key=mock_airflow_key(), status="FAILED", attempt_number=2
+            executor=mock_executor,
+            airflow_key=mock_airflow_key(),
+            status="FAILED",
+            attempt_number=2,
         )

         mock_executor.sync()
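Why the key substitution above matters: the event-log records that replace the mocked log calls read real fields off the task key, which a mock.Mock(spec=tuple) cannot provide. A short sketch, assuming TaskInstanceKey is Airflow's named tuple of (dag_id, task_id, run_id, try_number, map_index):

    from airflow.models.taskinstancekey import TaskInstanceKey

    # Positional fields: dag_id, task_id, run_id, try_number, map_index
    key = TaskInstanceKey("a", "b", "c", 1, -1)
    assert key.task_id == "b"    # compared by the new event-log assertions
    assert key.try_number == 1   # recorded on each event row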
tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py (54 changes: 21 additions, 33 deletions)
@@ -25,7 +25,7 @@
 from functools import partial
 from typing import Callable
 from unittest import mock
-from unittest.mock import MagicMock, call
+from unittest.mock import MagicMock

 import pytest
 import yaml
@@ -462,19 +462,19 @@ def test_failed_execute_api(self, mock_executor):
         # Task is not stored in active workers.
         assert len(mock_executor.active_workers) == 0

-    @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs")
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_tasks_fail(
-        self, _, mock_send_message_to_task_logs, mock_executor
-    ):
+    def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor):
         """
         Test case when all tasks fail to run.

         The executor should attempt each task exactly once per sync() iteration.
         It should preserve the order of tasks, and attempt each task up to
         `MAX_RUN_TASK_ATTEMPTS` times before dropping the task.
         """
-        airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
+        airflow_keys = [
+            TaskInstanceKey("a", "task_a", "c", 1, -1),
+            TaskInstanceKey("a", "task_b", "c", 1, -1),
+        ]
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
         commands = [airflow_cmd1, airflow_cmd2]
@@ -515,25 +515,14 @@ def test_attempt_task_runs_attempts_when_tasks_fail(
         assert len(mock_executor.active_workers.get_all_arns()) == 0
         assert len(mock_executor.pending_tasks) == 0

-        calls = []
-        for i in range(2):
-            calls.append(
-                call(
-                    logging.ERROR,
-                    "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
-                    airflow_keys[i],
-                    3,
-                    f"Failure {i + 1}",
-                    ti=airflow_keys[i],
-                )
-            )
-        mock_send_message_to_task_logs.assert_has_calls(calls)
+        events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs]
+        assert events == [
+            ("ecs task submit failure", "task_a", 1),
+            ("ecs task submit failure", "task_b", 1),
+        ]

-    @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs")
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_some_tasks_fal(
-        self, _, mock_send_message_to_task_logs, mock_executor
-    ):
+    def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor):
         """
         Test case when one task fails to run, and a new task gets queued.

@@ -542,7 +531,10 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(
         `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task
         should be moved from pending_jobs into active_workers.
         """
-        airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
+        airflow_keys = [
+            TaskInstanceKey("a", "task_a", "c", 1, -1),
+            TaskInstanceKey("a", "task_b", "c", 1, -1),
+        ]
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
         airflow_commands = [airflow_cmd1, airflow_cmd2]
@@ -604,15 +596,11 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(

         RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0]
         assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS
-
-        mock_send_message_to_task_logs.assert_called_once_with(
-            logging.ERROR,
-            "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
-            airflow_keys[0],
-            3,
-            "Failure 1",
-            ti=airflow_keys[0],
-        )
+        events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs]
+        assert events == [
+            ("ecs task submit failure", "task_a", 1),
+            ("ecs task submit failure", "task_b", 1),
+        ]

     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
     def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog):
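For context, the retry flow these tests pin down (per their docstrings): each attempt pass tries every pending job once; successes move into active_workers, failures are re-queued with a backoff delay (patched to zero here via calculate_next_attempt_delay so retries happen immediately), and a job that exhausts its attempts is dropped with an event-log record. A rough sketch of that loop under those assumptions, not the executors' actual code:

    import datetime as dt
    from collections import deque

    MAX_ATTEMPTS = 3  # the tests above assume three attempts before a drop

    def attempt_submit_jobs(pending: deque, active_workers: list,
                            event_logs: list, submit, now: dt.datetime):
        for _ in range(len(pending)):
            job = pending.popleft()
            if job["next_attempt_time"] > now:  # still backing off; keep it queued
                pending.append(job)
                continue
            job["attempts"] += 1
            try:
                submit(job["command"])
            except Exception:
                if job["attempts"] >= MAX_ATTEMPTS:
                    # Dropped: leave a record like ("batch job submit failure", "b", 1)
                    event_logs.append(("submit failure",
                                       job["key"].task_id, job["key"].try_number))
                else:
                    # Backoff shape is illustrative; the real delay comes from
                    # calculate_next_attempt_delay, patched to zero in these tests.
                    job["next_attempt_time"] = now + dt.timedelta(seconds=2 ** job["attempts"])
                    pending.append(job)  # retried on a later pass
            else:
                active_workers.append(job)  # submitted successfully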
