diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 07f5702815ae9..b1d6d5324b98c 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -136,13 +136,18 @@ class ECSOperator(BaseOperator): Only required if you want logs to be shown in the Airflow UI after your job has finished. :type awslogs_stream_prefix: str + :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle + transient errors. + :type quota_retry: dict :param reattach: If set to True, will check if the task previously launched by the task_instance is already running. If so, the operator will attach to it instead of starting a new task. This is to avoid relaunching a new task when the connection drops between Airflow and ECS while the task is running (when the Airflow worker is restarted for example). :type reattach: bool - :param quota_retry: Config if and how to retry _start_task() for transient errors. - :type quota_retry: dict + :param number_logs_exception: Number of lines from the last Cloudwatch logs to return in the + AirflowException if an ECS task is stopped (to receive Airflow alerts with the logs of what + failed in the code running in ECS). + :type number_logs_exception: int """ ui_color = '#f0ede4' @@ -178,6 +183,7 @@ def __init__( propagate_tags: Optional[str] = None, quota_retry: Optional[dict] = None, reattach: bool = False, + number_logs_exception: int = 10, **kwargs, ): super().__init__(**kwargs) @@ -201,6 +207,7 @@ def __init__( self.awslogs_region = awslogs_region self.propagate_tags = propagate_tags self.reattach = reattach + self.number_logs_exception = number_logs_exception if self.awslogs_region is None: self.awslogs_region = region_name @@ -343,9 +350,12 @@ def _cloudwatch_log_events(self) -> Generator: def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix + def _last_log_messages(self, number_messages): + return [log["message"] for log in deque(self._cloudwatch_log_events(), maxlen=number_messages)] + def _last_log_message(self): try: - return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"] + return self._last_log_messages(1)[0] except IndexError: return None @@ -378,7 +388,11 @@ def _check_success_task(self) -> None: containers = task['containers'] for container in containers: if container.get('lastStatus') == 'STOPPED' and container['exitCode'] != 0: - raise AirflowException(f'This task is not in success state {task}') + last_logs = "\n".join(self._last_log_messages(self.number_logs_exception)) + raise AirflowException( + f"This task is not in success state - last {self.number_logs_exception} " + f"logs from Cloudwatch:\n{last_logs}" + ) elif container.get('lastStatus') == 'PENDING': raise AirflowException(f'This task is still pending {task}') elif 'error' in container.get('reason', '').lower(): diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 4013450ec9c55..12a34f45142f6 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -217,7 +217,26 @@ def test_wait_end_tasks(self): client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn']) assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts - def test_check_success_tasks_raises(self): + @mock.patch.object( + ECSOperator, '_cloudwatch_log_events', return_value=({"message": str(i)} for i in range(10)) + ) + def test_last_log_messages(self, mock_cloudwatch_log_events): + client_mock = mock.Mock() + self.ecs.arn = 'arn' + self.ecs.client = client_mock + + assert self.ecs._last_log_messages(5) == ["5", "6", "7", "8", "9"] + + @mock.patch.object(ECSOperator, '_cloudwatch_log_events', return_value=()) + def test_last_log_messages_empty(self, mock_cloudwatch_log_events): + client_mock = mock.Mock() + self.ecs.arn = 'arn' + self.ecs.client = client_mock + + assert self.ecs._last_log_messages(10) == [] + + @mock.patch.object(ECSOperator, '_last_log_messages', return_value=["1", "2", "3", "4", "5"]) + def test_check_success_tasks_raises_cloudwatch_logs(self, mock_last_log_messages): client_mock = mock.Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock @@ -228,11 +247,24 @@ def test_check_success_tasks_raises(self): with pytest.raises(Exception) as ctx: self.ecs._check_success_task() - # Ordering of str(dict) is not guaranteed. - assert "This task is not in success state " in str(ctx.value) - assert "'name': 'foo'" in str(ctx.value) - assert "'lastStatus': 'STOPPED'" in str(ctx.value) - assert "'exitCode': 1" in str(ctx.value) + assert str(ctx.value) == ( + "This task is not in success state - last 10 logs from Cloudwatch:\n1\n2\n3\n4\n5" + ) + client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) + + @mock.patch.object(ECSOperator, '_last_log_messages', return_value=[]) + def test_check_success_tasks_raises_cloudwatch_logs_empty(self, mock_last_log_messages): + client_mock = mock.Mock() + self.ecs.arn = 'arn' + self.ecs.client = client_mock + + client_mock.describe_tasks.return_value = { + 'tasks': [{'containers': [{'name': 'foo', 'lastStatus': 'STOPPED', 'exitCode': 1}]}] + } + with pytest.raises(Exception) as ctx: + self.ecs._check_success_task() + + assert str(ctx.value) == "This task is not in success state - last 10 logs from Cloudwatch:\n" client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_pending(self): @@ -415,17 +447,17 @@ def test_reattach_save_task_arn_xcom( xcom_del_mock.assert_called_once() assert self.ecs.arn == 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' - @mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output") + @mock.patch.object(ECSOperator, '_last_log_messages', return_value=["Log output"]) def test_execute_xcom_with_log(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = True - assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value + assert self.ecs.execute(None) == "Log output" - @mock.patch.object(ECSOperator, '_last_log_message', return_value=None) + @mock.patch.object(ECSOperator, '_last_log_messages', return_value=[]) def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = True - assert self.ecs.execute(None) == mock_cloudwatch_log_message.return_value + assert self.ecs.execute(None) is None - @mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output") + @mock.patch.object(ECSOperator, '_last_log_messages', return_value=["Log output"]) def test_execute_xcom_disabled(self, mock_cloudwatch_log_message): self.ecs.do_xcom_push = False assert self.ecs.execute(None) is None