Skip to content

Commit

Permalink
ECSOperator returns last logs when ECS task fails (#17209)
Browse files Browse the repository at this point in the history
closes: apache/airflow#17038

This PR changes the message of the AirflowException raised when a container of the ECS task launched by ECSOperator has stopped with a non-zero exit code.

**Before:**
The message when it failed was:
`This task is not in success state {<huge JSON from AWS containing all the ECS task details>}`

**Now:**
The message is:
```
This task is not in success state - last logs from Cloudwatch:
<last_logs_from_cloudwatch>
```
This makes the alert much more useful, because you can see what failed in the underlying code directly from it.

The number of log lines included in the exception can be customized with the `number_logs_exception` parameter (default: 10).

GitOrigin-RevId: e6cb2f7beb4c6ea4ad4a965f9c0f2b8f6978129c
  • Loading branch information
pmalafosse authored and Cloud Composer Team committed Sep 12, 2024
1 parent d2d9c97 commit e0b6561
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
22 changes: 18 additions & 4 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,18 @@ class ECSOperator(BaseOperator):
Only required if you want logs to be shown in the Airflow UI after your job has
finished.
:type awslogs_stream_prefix: str
:param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
transient errors.
:type quota_retry: dict
:param reattach: If set to True, will check if the task previously launched by the task_instance
is already running. If so, the operator will attach to it instead of starting a new task.
This is to avoid relaunching a new task when the connection drops between Airflow and ECS while
the task is running (when the Airflow worker is restarted for example).
:type reattach: bool
:param quota_retry: Config if and how to retry _start_task() for transient errors.
:type quota_retry: dict
:param number_logs_exception: Number of lines from the last Cloudwatch logs to return in the
AirflowException if an ECS task is stopped (to receive Airflow alerts with the logs of what
failed in the code running in ECS).
:type number_logs_exception: int
"""

ui_color = '#f0ede4'
Expand Down Expand Up @@ -178,6 +183,7 @@ def __init__(
propagate_tags: Optional[str] = None,
quota_retry: Optional[dict] = None,
reattach: bool = False,
number_logs_exception: int = 10,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -201,6 +207,7 @@ def __init__(
self.awslogs_region = awslogs_region
self.propagate_tags = propagate_tags
self.reattach = reattach
self.number_logs_exception = number_logs_exception

if self.awslogs_region is None:
self.awslogs_region = region_name
Expand Down Expand Up @@ -342,9 +349,12 @@ def _cloudwatch_log_events(self) -> Generator:
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

def _last_log_messages(self, number_messages):
return [log["message"] for log in deque(self._cloudwatch_log_events(), maxlen=number_messages)]

def _last_log_message(self):
try:
return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"]
return self._last_log_messages(1)[0]
except IndexError:
return None

Expand Down Expand Up @@ -377,7 +387,11 @@ def _check_success_task(self) -> None:
containers = task['containers']
for container in containers:
if container.get('lastStatus') == 'STOPPED' and container['exitCode'] != 0:
raise AirflowException(f'This task is not in success state {task}')
last_logs = "\n".join(self._last_log_messages(self.number_logs_exception))
raise AirflowException(
f"This task is not in success state - last {self.number_logs_exception} "
f"logs from Cloudwatch:\n{last_logs}"
)
elif container.get('lastStatus') == 'PENDING':
raise AirflowException(f'This task is still pending {task}')
elif 'error' in container.get('reason', '').lower():
Expand Down
54 changes: 43 additions & 11 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,26 @@ def test_wait_end_tasks(self):
client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts

def test_check_success_tasks_raises(self):
@mock.patch.object(
ECSOperator, '_cloudwatch_log_events', return_value=({"message": str(i)} for i in range(10))
)
def test_last_log_messages(self, mock_cloudwatch_log_events):
client_mock = mock.Mock()
self.ecs.arn = 'arn'
self.ecs.client = client_mock

assert self.ecs._last_log_messages(5) == ["5", "6", "7", "8", "9"]

@mock.patch.object(ECSOperator, '_cloudwatch_log_events', return_value=())
def test_last_log_messages_empty(self, mock_cloudwatch_log_events):
client_mock = mock.Mock()
self.ecs.arn = 'arn'
self.ecs.client = client_mock

assert self.ecs._last_log_messages(10) == []

@mock.patch.object(ECSOperator, '_last_log_messages', return_value=["1", "2", "3", "4", "5"])
def test_check_success_tasks_raises_cloudwatch_logs(self, mock_last_log_messages):
client_mock = mock.Mock()
self.ecs.arn = 'arn'
self.ecs.client = client_mock
Expand All @@ -255,11 +274,24 @@ def test_check_success_tasks_raises(self):
with pytest.raises(Exception) as ctx:
self.ecs._check_success_task()

# Ordering of str(dict) is not guaranteed.
assert "This task is not in success state " in str(ctx.value)
assert "'name': 'foo'" in str(ctx.value)
assert "'lastStatus': 'STOPPED'" in str(ctx.value)
assert "'exitCode': 1" in str(ctx.value)
assert str(ctx.value) == (
"This task is not in success state - last 10 logs from Cloudwatch:\n1\n2\n3\n4\n5"
)
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

@mock.patch.object(ECSOperator, '_last_log_messages', return_value=[])
def test_check_success_tasks_raises_cloudwatch_logs_empty(self, mock_last_log_messages):
client_mock = mock.Mock()
self.ecs.arn = 'arn'
self.ecs.client = client_mock

client_mock.describe_tasks.return_value = {
'tasks': [{'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]}]
}
with pytest.raises(Exception) as ctx:
self.ecs._check_success_task()

assert str(ctx.value) == "This task is not in success state - last 10 logs from Cloudwatch:\n"
client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

def test_check_success_tasks_raises_pending(self):
Expand Down Expand Up @@ -442,17 +474,17 @@ def test_reattach_save_task_arn_xcom(
xcom_del_mock.assert_called_once()
assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'

@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
@mock.patch.object(ECSOperator, '_last_log_messages', return_value=["Log output"])
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value
assert self.ecs.execute(None) == "Log output"

@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
@mock.patch.object(ECSOperator, '_last_log_messages', return_value=[])
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value
assert self.ecs.execute(None) is None

@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
@mock.patch.object(ECSOperator, '_last_log_messages', return_value=["Log output"])
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
assert self.ecs.execute(None) is None
Expand Down

0 comments on commit e0b6561

Please sign in to comment.