Fix on_failure_callback when task receives SIGKILL #15537

Merged · 4 commits · May 5, 2021
9 changes: 6 additions & 3 deletions airflow/jobs/local_task_job.py
@@ -17,7 +17,6 @@
# under the License.
#

import os
import signal
from typing import Optional

@@ -154,6 +153,10 @@ def handle_task_exit(self, return_code: int) -> None:
# task exited by itself, so we need to check for error file
# incase it failed due to runtime exception/error
error = None
if self.task_instance.state == State.RUNNING:
# This is for a case where the task received a sigkill
# while running
self.task_instance.set_state(State.FAILED)
if self.task_instance.state != State.SUCCESS:
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access
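
This hunk is the core of the fix: a task process that receives SIGKILL dies before it can record its own failure, so the task instance is still RUNNING when the runner exits. Forcing it to FAILED here lets _run_finished_callback deliver the error, which in turn runs on_failure_callback. A hedged illustration of the behaviour this restores (the dag id, task id, and callables below are ours, not part of this PR):

```python
# Hedged illustration only: names below (dag id, task id, callables) are ours.
# With this fix, a task process killed with SIGKILL (e.g. by the OOM killer)
# is marked FAILED by the supervising LocalTaskJob, so on_failure_callback fires.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator


def notify(context):
    # Airflow calls this with the task context once the task is marked failed.
    ti = context["task_instance"]
    print(f"task {ti.task_id} failed on try {ti.try_number}")


def allocate_until_killed():
    # Stand-in for work that may be OOM-killed, i.e. receive an uncatchable SIGKILL.
    chunks = []
    while True:
        chunks.append(bytearray(10 ** 7))


with DAG("sigkill_callback_demo", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    PythonOperator(
        task_id="might_get_killed",
        python_callable=allocate_until_killed,
        on_failure_callback=notify,
    )
```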
@@ -184,9 +187,9 @@ def heartbeat_callback(self, session=None):
)
raise AirflowException("Hostname of job runner does not match")

current_pid = os.getpid()
current_pid = self.task_runner.process.pid
same_process = ti.pid == current_pid
if not same_process:
if ti.pid is not None and not same_process:
self.log.warning("Recorded pid %s does not match " "the current pid %s", ti.pid, current_pid)
raise AirflowException("PID of job runner does not match")
elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'):
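
For the heartbeat check above: ti.pid is now written by the forked task process itself (see the taskinstance.py change below), so comparing it with os.getpid() in the supervising LocalTaskJob would always mismatch. The comparison target becomes the task runner's child pid, and the check is skipped while ti.pid is still None because the child has not committed it yet. A minimal sketch of that guard, using simplified stand-ins rather than Airflow's API:

```python
# Simplified stand-in for the guard above; not Airflow's actual classes.
class HeartbeatMismatch(Exception):
    pass


def check_recorded_pid(recorded_pid, runner_child_pid):
    """recorded_pid comes from the task process; runner_child_pid from task_runner.process.pid."""
    if recorded_pid is None:
        # Too early: the task process has not committed its pid to the DB yet.
        return
    if recorded_pid != runner_child_pid:
        raise HeartbeatMismatch(
            f"Recorded pid {recorded_pid} does not match the current pid {runner_child_pid}"
        )
```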
5 changes: 3 additions & 2 deletions airflow/models/taskinstance.py
@@ -1072,7 +1072,6 @@ def check_and_change_state_before_execution( # pylint: disable=too-many-argumen
if not test_mode:
session.add(Log(State.RUNNING, self))
self.state = State.RUNNING
self.pid = os.getpid()
self.end_date = None
if not test_mode:
session.merge(self)
@@ -1127,7 +1126,9 @@ def _run_raw_task(
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()

self.pid = os.getpid()
session.merge(self)
session.commit()
actual_start_date = timezone.utcnow()
Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}')
try:
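
These two taskinstance.py hunks move the pid bookkeeping: the supervising process no longer stamps its own os.getpid() onto the task instance in check_and_change_state_before_execution; instead _run_raw_task, which runs inside the forked task process, records and commits its own pid. The extra merge/commit is also what bumps the expected query counts in the tests further down (13 to 15, 10 to 12, 5 to 7). A tiny standalone illustration (not Airflow code) of why the two processes report different pids:

```python
# Standalone POSIX illustration, not Airflow code: after a fork, os.getpid()
# differs between the supervisor and the task process, so only the child can
# record the pid that signals such as SIGTERM/SIGKILL are actually sent to.
import os

child = os.fork()
if child == 0:
    # Task-process side: this is the pid worth persisting on the task instance.
    print("task process pid:", os.getpid())
    os._exit(0)
else:
    # Supervisor side: os.getpid() here is not the task's pid.
    print("supervisor pid:", os.getpid(), "spawned child pid:", child)
    os.waitpid(child, 0)
```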
33 changes: 22 additions & 11 deletions tests/jobs/test_local_task_job.py
@@ -27,6 +27,7 @@
from unittest.mock import patch

import pytest
from parameterized import parameterized

from airflow import settings
from airflow.exceptions import AirflowException, AirflowFailException
@@ -92,8 +93,7 @@ def test_localtaskjob_essential_attr(self):
check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
assert all(check_result_2)

@patch('os.getpid')
def test_localtaskjob_heartbeat(self, mock_pid):
def test_localtaskjob_heartbeat(self):
session = settings.Session()
dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

@@ -114,19 +114,23 @@ def test_localtaskjob_heartbeat(self, mock_pid):
session.commit()

job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
ti.task = op1
ti.refresh_from_task(op1)
job1.task_runner = StandardTaskRunner(job1)
job1.task_runner.process = mock.Mock()
with pytest.raises(AirflowException):
job1.heartbeat_callback() # pylint: disable=no-value-for-parameter

mock_pid.return_value = 1
job1.task_runner.process.pid = 1
ti.state = State.RUNNING
ti.hostname = get_hostname()
ti.pid = 1
session.merge(ti)
session.commit()

assert ti.pid != os.getpid()
job1.heartbeat_callback(session=None)

mock_pid.return_value = 2
job1.task_runner.process.pid = 2
with pytest.raises(AirflowException):
job1.heartbeat_callback() # pylint: disable=no-value-for-parameter
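
In test_localtaskjob_heartbeat the @patch('os.getpid') decorator is gone; the test now hangs a Mock off the task runner and drives the pid through job1.task_runner.process.pid, matching what heartbeat_callback reads after this change. The pattern in isolation (illustrative, outside the test):

```python
# Illustrative, standalone version of the mocking pattern used in the test.
from unittest import mock

runner = mock.Mock()          # stands in for StandardTaskRunner(job1)
runner.process = mock.Mock()  # stands in for the spawned task process
runner.process.pid = 1        # the value heartbeat_callback now compares against

assert runner.process.pid == 1
```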

@@ -496,9 +500,15 @@ def task_function(ti):
assert task_terminated_externally.value == 1
assert not process.is_alive()

def test_process_kill_call_on_failure_callback(self):
@parameterized.expand(
[
(signal.SIGTERM,),
(signal.SIGKILL,),
]
)
def test_process_kill_calls_on_failure_callback(self, signal_type):
"""
Test that ensures that when a task is killed with sigterm
Test that ensures that when a task is killed with sigterm or sigkill
on_failure_callback gets executed
"""
# use shared memory value so we can properly track value change even if
@@ -547,13 +557,14 @@ def task_function(ti):
process = multiprocessing.Process(target=job1.run)
process.start()

for _ in range(0, 10):
for _ in range(0, 20):
ti.refresh_from_db()
if ti.state == State.RUNNING:
if ti.state == State.RUNNING and ti.pid is not None:
break
time.sleep(0.2)
assert ti.state == State.RUNNING
os.kill(ti.pid, signal.SIGTERM)
assert ti.pid is not None
os.kill(ti.pid, signal_type)
process.join(timeout=10)
assert failure_callback_called.value == 1
assert task_terminated_externally.value == 1
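
The test is now parameterized over SIGTERM and SIGKILL and waits not only for the RUNNING state but also for a non-None pid, since the pid is committed only once the forked task process reaches _run_raw_task; the retry budget grows from 10 to 20 iterations (roughly 4 seconds) to cover that. A hedged, standalone sketch of the poll-then-signal pattern (the helper name is ours):

```python
# Standalone sketch of the poll-then-signal pattern; wait_until is our name.
import time


def wait_until(predicate, attempts=20, delay=0.2):
    """Retry predicate until it returns a truthy value or the attempts run out."""
    for _ in range(attempts):
        result = predicate()
        if result:
            return result
        time.sleep(delay)
    return None


# In the test, the predicate refreshes the TaskInstance from the DB and returns
# ti.pid once ti.state == State.RUNNING and ti.pid is not None; the returned pid
# is then signalled with os.kill(pid, signal_type) for each parameterized signal.
```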
@@ -584,5 +595,5 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes)
mock_get_task_runner.return_value.return_code.side_effects = return_codes

job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
with assert_queries_count(13):
with assert_queries_count(15):
job.run()
6 changes: 3 additions & 3 deletions tests/models/test_taskinstance.py
@@ -2073,8 +2073,8 @@ def tearDown(self) -> None:
@parameterized.expand(
[
# Expected queries, mark_success
(10, False),
(5, True),
(12, False),
(7, True),
]
)
def test_execute_queries_count(self, expected_query_count, mark_success):
@@ -2110,7 +2110,7 @@ def test_execute_queries_count_store_serialized(self):
session=session,
)

with assert_queries_count(10):
with assert_queries_count(12):
ti._run_raw_task()

def test_operator_field_with_serialization(self):