diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index cce4e64176981..70ef60c154921 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -78,12 +78,9 @@ def _execute(self): def signal_handler(signum, frame): """Setting kill signal handler""" self.log.error("Received SIGTERM. Terminating subprocesses") - self.on_kill() - self.task_instance.refresh_from_db() - if self.task_instance.state not in State.finished: - self.task_instance.set_state(State.FAILED) - self.task_instance._run_finished_callback(error="task received sigterm") - raise AirflowException("LocalTaskJob received SIGTERM signal") + self.task_runner.terminate() + self.handle_task_exit(128 + signum) + return signal.signal(signal.SIGTERM, signal_handler) @@ -148,16 +145,19 @@ def signal_handler(signum, frame): self.on_kill() def handle_task_exit(self, return_code: int) -> None: - """Handle case where self.task_runner exits by itself""" + """Handle case where self.task_runner exits by itself or is externally killed""" + # Without setting this, heartbeat may get us + self.terminating = True self.log.info("Task exited with return code %s", return_code) self.task_instance.refresh_from_db() - # task exited by itself, so we need to check for error file + + if self.task_instance.state == State.RUNNING: + # This is for a case where the task received a SIGKILL + # while running or the task runner received a sigterm + self.task_instance.handle_failure(error=None) + # We need to check for error file # in case it failed due to runtime exception/error error = None - if self.task_instance.state == State.RUNNING: - # This is for a case where the task received a sigkill - # while running - self.task_instance.set_state(State.FAILED) if self.task_instance.state != State.SUCCESS: error = self.task_runner.deserialize_run_error() self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index d9f1398d9e7b2..94f894d134d6d 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -21,6 +21,7 @@ import signal import time import uuid +from datetime import timedelta from multiprocessing import Lock, Value from unittest import mock from unittest.mock import patch @@ -272,7 +273,6 @@ def test_heartbeat_failed_fast(self): delta = (time2 - time1).total_seconds() assert abs(delta - job.heartrate) < 0.5 - @pytest.mark.quarantined def test_mark_success_no_kill(self): """ Test that ensures that mark_success in the UI doesn't cause @@ -300,7 +300,6 @@ def test_mark_success_no_kill(self): job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) process = multiprocessing.Process(target=job1.run) process.start() - ti.refresh_from_db() for _ in range(0, 50): if ti.state == State.RUNNING: break @@ -510,7 +509,6 @@ def dummy_return_code(*args, **kwargs): assert ti.state == State.FAILED # task exits with failure state assert failure_callback_called.value == 1 - @pytest.mark.quarantined def test_mark_success_on_success_callback(self, dag_maker): """ Test that ensures that where a task is marked success in the UI @@ -567,15 +565,9 @@ def task_function(ti): assert task_terminated_externally.value == 1 assert not process.is_alive() - @parameterized.expand( - [ - (signal.SIGTERM,), - (signal.SIGKILL,), - ] - ) - def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker): + def test_task_sigkill_calls_on_failure_callback(self, dag_maker): """ - Test that ensures that when a task is killed with sigterm or sigkill + Test that ensures that when a task is killed with sigkill on_failure_callback gets executed """ # use shared memory value so we can properly track value change even if @@ -587,10 +579,50 @@ def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker): def failure_callback(context): with shared_mem_lock: failure_callback_called.value += 1 - assert context['dag_run'].dag_id == 'test_mark_failure' + assert context['dag_run'].dag_id == 'test_send_sigkill' def task_function(ti): + os.kill(os.getpid(), signal.SIGKILL) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker(dag_id='test_send_sigkill'): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + on_failure_callback=failure_callback, + ) + + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + time.sleep(0.3) + process.join(timeout=10) + assert failure_callback_called.value == 1 + assert task_terminated_externally.value == 1 + assert not process.is_alive() + + def test_process_sigterm_calls_on_failure_callback(self, dag_maker): + """ + Test that ensures that when a task runner is killed with sigterm + on_failure_callback gets executed + """ + # use shared memory value so we can properly track value change even if + # it's been updated across processes. + failure_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + def failure_callback(context): + with shared_mem_lock: + failure_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure' + + def task_function(ti): time.sleep(60) # This should not happen -- the state change should be noticed and the task should get killed with shared_mem_lock: @@ -605,20 +637,16 @@ def task_function(ti): ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) ti.refresh_from_db() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) - job1.task_runner = StandardTaskRunner(job1) - settings.engine.dispose() process = multiprocessing.Process(target=job1.run) process.start() - - for _ in range(0, 20): + for _ in range(0, 25): ti.refresh_from_db() - if ti.state == State.RUNNING and ti.pid is not None: + if ti.state == State.RUNNING: break time.sleep(0.2) - assert ti.pid is not None - assert ti.state == State.RUNNING - os.kill(ti.pid, signal_type) + os.kill(process.pid, signal.SIGTERM) + ti.refresh_from_db() process.join(timeout=10) assert failure_callback_called.value == 1 assert task_terminated_externally.value == 1 @@ -726,6 +754,102 @@ def test_fast_follow( if scheduler_job.processor_agent: scheduler_job.processor_agent.end() + def test_task_sigkill_works_with_retries(self, dag_maker): + """ + Test that ensures that tasks are retried when they receive sigkill + """ + # use shared memory value so we can properly track value change even if + # it's been updated across processes. + retry_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + + def retry_callback(context): + with shared_mem_lock: + retry_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure_2' + + def task_function(ti): + os.kill(os.getpid(), signal.SIGKILL) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker( + dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'} + ): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=timedelta(seconds=2), + on_retry_callback=retry_callback, + ) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.start() + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + time.sleep(0.4) + process.join() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + assert retry_callback_called.value == 1 + assert task_terminated_externally.value == 1 + + def test_process_sigterm_works_with_retries(self, dag_maker): + """ + Test that ensures that task runner sets tasks to retry when they(task runner) + receive sigterm + """ + # use shared memory value so we can properly track value change even if + # it's been updated across processes. + retry_callback_called = Value('i', 0) + task_terminated_externally = Value('i', 1) + shared_mem_lock = Lock() + + def retry_callback(context): + with shared_mem_lock: + retry_callback_called.value += 1 + assert context['dag_run'].dag_id == 'test_mark_failure_2' + + def task_function(ti): + time.sleep(60) + # This should not happen -- the state change should be noticed and the task should get killed + with shared_mem_lock: + task_terminated_externally.value = 0 + + with dag_maker(dag_id='test_mark_failure_2'): + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=timedelta(seconds=2), + on_retry_callback=retry_callback, + ) + ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.start() + settings.engine.dispose() + process = multiprocessing.Process(target=job1.run) + process.start() + for _ in range(0, 25): + ti.refresh_from_db() + if ti.state == State.RUNNING and ti.pid is not None: + break + time.sleep(0.2) + os.kill(process.pid, signal.SIGTERM) + process.join() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + assert retry_callback_called.value == 1 + assert task_terminated_externally.value == 1 + def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self): """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) @@ -788,5 +912,5 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes) mock_get_task_runner.return_value.return_code.side_effects = return_codes job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) - with assert_queries_count(16): + with assert_queries_count(18): job.run() diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index c1882e11ded7a..db232714205a4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -18,6 +18,7 @@ import datetime import os +import signal import time import unittest import urllib @@ -522,6 +523,37 @@ def raise_skip_exception(): ti.run() assert State.SKIPPED == ti.state + def test_task_sigterm_works_with_retries(self): + """ + Test that ensures that tasks are retried when they receive sigterm + """ + dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) + + def task_function(ti): + # pylint: disable=unused-argument + os.kill(ti.pid, signal.SIGTERM) + + task = PythonOperator( + task_id='test_on_failure', + python_callable=task_function, + retries=1, + retry_delay=datetime.timedelta(seconds=2), + dag=dag, + ) + + dag.create_dagrun( + run_id="test", + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + ) + ti = TI(task=task, execution_date=DEFAULT_DATE) + ti.refresh_from_db() + with self.assertRaises(AirflowException): + ti.run() + ti.refresh_from_db() + assert ti.state == State.UP_FOR_RETRY + def test_retry_delay(self): """ Test that retry delays are respected