diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index b642310614249..085ef99d65d92 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -17,7 +17,6 @@ # under the License. # -import os import signal from typing import Optional @@ -154,6 +153,10 @@ def handle_task_exit(self, return_code: int) -> None: # task exited by itself, so we need to check for error file # incase it failed due to runtime exception/error error = None + if self.task_instance.state == State.RUNNING: + # This is for a case where the task received a sigkill + # while running + self.task_instance.set_state(State.FAILED) if self.task_instance.state != State.SUCCESS: error = self.task_runner.deserialize_run_error() self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access @@ -184,9 +187,9 @@ def heartbeat_callback(self, session=None): ) raise AirflowException("Hostname of job runner does not match") - current_pid = os.getpid() + current_pid = self.task_runner.process.pid same_process = ti.pid == current_pid - if not same_process: + if ti.pid is not None and not same_process: self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid) raise AirflowException("PID of job runner does not match") elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'): diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d51a5af405a50..cb37e6072ffd9 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1072,7 +1072,6 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argumen if not test_mode: session.add(Log(State.RUNNING, self)) self.state = State.RUNNING - self.pid = os.getpid() self.end_date = None if not test_mode: session.merge(self) @@ -1127,7 +1126,9 @@ def _run_raw_task( self.refresh_from_db(session=session) self.job_id = job_id self.hostname = get_hostname() - + self.pid = os.getpid() + session.merge(self) + session.commit() actual_start_date = timezone.utcnow() Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}') try: diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index bb986b2b092cc..6a7bc5843bd68 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -27,6 +27,7 @@ from unittest.mock import patch import pytest +from parameterized import parameterized from airflow import settings from airflow.exceptions import AirflowException, AirflowFailException @@ -92,8 +93,7 @@ def test_localtaskjob_essential_attr(self): check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr] assert all(check_result_2) - @patch('os.getpid') - def test_localtaskjob_heartbeat(self, mock_pid): + def test_localtaskjob_heartbeat(self): session = settings.Session() dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) @@ -114,19 +114,23 @@ def test_localtaskjob_heartbeat(self, mock_pid): session.commit() job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + ti.task = op1 + ti.refresh_from_task(op1) + job1.task_runner = StandardTaskRunner(job1) + job1.task_runner.process = mock.Mock() with pytest.raises(AirflowException): job1.heartbeat_callback() # pylint: disable=no-value-for-parameter - mock_pid.return_value = 1 + job1.task_runner.process.pid = 1 ti.state = State.RUNNING ti.hostname = get_hostname() ti.pid = 1 session.merge(ti) session.commit() - + assert ti.pid != os.getpid() job1.heartbeat_callback(session=None) - mock_pid.return_value = 2 + job1.task_runner.process.pid = 2 with pytest.raises(AirflowException): job1.heartbeat_callback() # pylint: disable=no-value-for-parameter @@ -496,9 +500,15 @@ def task_function(ti): assert task_terminated_externally.value == 1 assert not process.is_alive() - def test_process_kill_call_on_failure_callback(self): + @parameterized.expand( + [ + (signal.SIGTERM,), + (signal.SIGKILL,), + ] + ) + def test_process_kill_calls_on_failure_callback(self, signal_type): """ - Test that ensures that when a task is killed with sigterm + Test that ensures that when a task is killed with sigterm or sigkill on_failure_callback gets executed """ # use shared memory value so we can properly track value change even if @@ -547,13 +557,14 @@ def task_function(ti): process = multiprocessing.Process(target=job1.run) process.start() - for _ in range(0, 10): + for _ in range(0, 20): ti.refresh_from_db() - if ti.state == State.RUNNING: + if ti.state == State.RUNNING and ti.pid is not None: break time.sleep(0.2) assert ti.state == State.RUNNING - os.kill(ti.pid, signal.SIGTERM) + assert ti.pid is not None + os.kill(ti.pid, signal_type) process.join(timeout=10) assert failure_callback_called.value == 1 assert task_terminated_externally.value == 1 @@ -584,5 +595,5 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes) mock_get_task_runner.return_value.return_code.side_effects = return_codes job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) - with assert_queries_count(13): + with assert_queries_count(15): job.run() diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 8c3b1001f7d7d..909fad5e45532 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2073,8 +2073,8 @@ def tearDown(self) -> None: @parameterized.expand( [ # Expected queries, mark_success - (10, False), - (5, True), + (12, False), + (7, True), ] ) def test_execute_queries_count(self, expected_query_count, mark_success): @@ -2110,7 +2110,7 @@ def test_execute_queries_count_store_serialized(self): session=session, ) - with assert_queries_count(10): + with assert_queries_count(12): ti._run_raw_task() def test_operator_field_with_serialization(self):