Skip to content

Commit

Permalink
[AIRFLOW-5102] Worker jobs should terminate themselves if they can't …
Browse files Browse the repository at this point in the history
…heartbeat (apache#6284)

If a LocalTaskJob fails to heartbeat for
scheduler_zombie_task_threshold, it should shut itself down.

However, at some point, a change was made to catch exceptions inside the
heartbeat, so the LocalTaskJob thought it had managed to heartbeat
successfully.

This effectively means that zombie tasks don't shut themselves down.
When the scheduler reschedules the job, this means we could have two
instances of the task running concurrently.
  • Loading branch information
ashb authored Oct 8, 2019
1 parent 93bb5e4 commit 68b8ec5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
24 changes: 15 additions & 9 deletions airflow/jobs/base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,37 @@ def heartbeat(self):
heart rate. If you go over 60 seconds before calling it, it won't
sleep at all.
"""
previous_heartbeat = self.latest_heartbeat

try:
with create_session() as session:
job = session.query(BaseJob).filter_by(id=self.id).one()
make_transient(job)
session.commit()
# This will cause it to load from the db
session.merge(self)
previous_heartbeat = self.latest_heartbeat

if job.state == State.SHUTDOWN:
if self.state == State.SHUTDOWN:
self.kill()

is_unit_test = conf.getboolean('core', 'unit_test_mode')
if not is_unit_test:
# Figure out how long to sleep for
sleep_for = 0
if job.latest_heartbeat:
if self.latest_heartbeat:
seconds_remaining = self.heartrate - \
(timezone.utcnow() - job.latest_heartbeat)\
(timezone.utcnow() - self.latest_heartbeat)\
.total_seconds()
sleep_for = max(0, seconds_remaining)

sleep(sleep_for)

# Update last heartbeat time
with create_session() as session:
job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
job.latest_heartbeat = timezone.utcnow()
session.merge(job)
# Make the session aware of this object
session.merge(self)
self.latest_heartbeat = timezone.utcnow()
session.commit()
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat

self.heartbeat_callback(session=session)
self.log.debug('[heartbeat]')
Expand All @@ -195,6 +199,8 @@ def heartbeat(self):
convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1,
1)
self.log.exception("%s heartbeat got an exception", self.__class__.__name__)
# We didn't manage to heartbeat, so make sure that the timestamp isn't updated
self.latest_heartbeat = previous_heartbeat

def run(self):
Stats.incr(self.__class__.__name__.lower() + '_start', 1, 1)
Expand Down
7 changes: 3 additions & 4 deletions airflow/jobs/local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.jobs.base_job import BaseJob
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.utils import timezone
from airflow.utils.db import provide_session
from airflow.utils.net import get_hostname
from airflow.utils.state import State
Expand Down Expand Up @@ -98,14 +99,12 @@ def signal_handler(signum, frame):
self.log.info("Task exited with return code %s", return_code)
return

# Periodically heartbeat so that the scheduler doesn't think this
# is a zombie
last_heartbeat_time = time.time()
self.heartbeat()

# If it's been too long since we've heartbeat, then it's possible that
# the scheduler rescheduled this task, so kill launched processes.
time_since_last_heartbeat = time.time() - last_heartbeat_time
# This can only really happen if the worker can't reach the DB for a long time
time_since_last_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
if time_since_last_heartbeat > heartbeat_time_limit:
Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
self.log.error("Heartbeat time limited exceeded!")
Expand Down
19 changes: 19 additions & 0 deletions tests/jobs/test_base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import datetime
import unittest

from sqlalchemy.exc import OperationalError

from airflow.jobs import BaseJob
from airflow.utils import timezone
from airflow.utils.db import create_session
from airflow.utils.state import State
from tests.compat import Mock, patch


class TestBaseJob(unittest.TestCase):
Expand Down Expand Up @@ -96,3 +99,19 @@ def test_is_alive(self):
job.state = State.SUCCESS
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
self.assertFalse(job.is_alive(), "Completed jobs even with recent heartbeat should not be alive")

@patch('airflow.jobs.base_job.create_session')
def test_heartbeat_failed(self, mock_create_session):
    """Check that a failed heartbeat leaves ``latest_heartbeat`` untouched.

    ``create_session`` is patched where ``airflow.jobs.base_job`` looks it
    up, so the session that ``heartbeat()`` obtains is a mock whose
    ``commit`` raises ``OperationalError``. After calling ``heartbeat()``
    the job's ``latest_heartbeat`` must still equal the value assigned
    before the call.
    """
    when = timezone.utcnow() - datetime.timedelta(seconds=60)

    # ``create_session`` here is the name imported into this test module, not
    # the patched attribute on base_job; it is used only so the mock can be
    # given the spec of a session object.
    # NOTE(review): nesting of the statements below under this ``with`` is
    # reconstructed; the indentation was lost in the page scrape -- confirm.
    with create_session() as session:
        mock_session = Mock(spec_set=session, name="MockSession")
        mock_create_session.return_value.__enter__.return_value = mock_session

        job = self.TestJob(None, heartrate=10, state=State.RUNNING)
        job.latest_heartbeat = when

        # Make the commit issued on the mocked session fail.
        mock_session.commit.side_effect = OperationalError("Force fail", {}, None)

        job.heartbeat()

        self.assertEqual(job.latest_heartbeat, when, "attribute not updated when heartbeat fails")
7 changes: 3 additions & 4 deletions tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def test_localtaskjob_heartbeat(self, mock_pid):
session.merge(ti)
session.commit()

ret = job1.heartbeat_callback()
self.assertEqual(ret, None)
job1.heartbeat_callback()

mock_pid.return_value = 2
self.assertRaises(AirflowException, job1.heartbeat_callback)
Expand All @@ -126,7 +125,7 @@ def test_heartbeat_failed_fast(self, mock_getpid):

heartbeat_records = []

def heartbeat_recorder():
def heartbeat_recorder(**kwargs):
heartbeat_records.append(timezone.utcnow())

with create_session() as session:
Expand All @@ -153,7 +152,7 @@ def heartbeat_recorder():

job = LocalTaskJob(task_instance=ti, executor=TestExecutor(do_update=False))
job.heartrate = 2
job.heartbeat = heartbeat_recorder
job.heartbeat_callback = heartbeat_recorder
job._execute()
self.assertGreater(len(heartbeat_records), 1)
for i in range(1, len(heartbeat_records)):
Expand Down

0 comments on commit 68b8ec5

Please sign in to comment.