Skip to content

Commit

Permalink
Automatically reschedule stalled queued tasks in CeleryExecutor (v2) (#…
Browse files Browse the repository at this point in the history
…23690)

Celery can lose tasks on worker shutdown, causing airflow to just wait on them
indefinitely (may be related to celery/celery#7266). This PR expands the
"stalled tasks" functionality which is already in place for adopted tasks, and
adds the ability to apply it to all tasks such that these lost/hung tasks can
be automatically recovered and queued up again.

(cherry picked from commit baae70c)
  • Loading branch information
repl-chris authored and ephraimbuddy committed May 21, 2022
1 parent 43940be commit 7058eb3
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 75 deletions.
15 changes: 13 additions & 2 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1767,12 +1767,23 @@
default: "True"
- name: task_adoption_timeout
description: |
Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
stalled tasks.
Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
also applies to adopted tasks.
version_added: 2.0.0
type: integer
example: ~
default: "600"
- name: stalled_task_timeout
description: |
Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
When set to 0, automatic clearing of stalled tasks is disabled.
version_added: 2.3.1
type: integer
example: ~
default: "0"
- name: task_publish_max_retries
description: |
The Maximum number of retries for publishing task messages to the broker when failing
Expand Down
11 changes: 9 additions & 2 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -887,10 +887,17 @@ operation_timeout = 1.0
# or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob.
task_track_started = True

# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
# stalled tasks.
# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
# also applies to adopted tasks.
task_adoption_timeout = 600

# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
# When set to 0, automatic clearing of stalled tasks is disabled.
stalled_task_timeout = 0

# The Maximum number of retries for publishing task messages to the broker when failing
# due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed.
task_publish_max_retries = 3
Expand Down
143 changes: 112 additions & 31 deletions airflow/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
import subprocess
import time
import traceback
from collections import Counter, OrderedDict
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

Expand All @@ -40,6 +41,7 @@
from celery.result import AsyncResult
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle
from sqlalchemy.orm.session import Session

import airflow.settings as settings
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
Expand All @@ -50,6 +52,7 @@
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow
Expand Down Expand Up @@ -207,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs):
pass


class _CeleryPendingTaskTimeoutType(Enum):
ADOPTED = 1
STALLED = 2


class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow. It allows
Expand All @@ -230,10 +238,14 @@ def __init__(self):
self._sync_parallelism = max(1, cpu_count() - 1)
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
# Mapping of tasks we've adopted, ordered by the earliest date they timeout
self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
self.task_adoption_timeout = datetime.timedelta(
seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
self.stalled_task_timeout = datetime.timedelta(
seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0)
)
self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
self.task_adoption_timeout = (
datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600))
or self.stalled_task_timeout
)
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
Expand Down Expand Up @@ -285,6 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
result.backend = cached_celery_backend
self.running.add(key)
self.tasks[key] = result
self._set_celery_pending_task_timeout(key, _CeleryPendingTaskTimeoutType.STALLED)

# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
Expand Down Expand Up @@ -315,25 +328,47 @@ def sync(self) -> None:
self.log.debug("No task to query celery, skipping sync")
return
self.update_all_task_states()
self._check_for_timedout_adopted_tasks()
self._check_for_stalled_tasks()

def _check_for_timedout_adopted_tasks(self) -> None:
timedout_keys = self._get_timedout_ti_keys(self.adopted_task_timeouts)
if timedout_keys:
self.log.error(
"Adopted tasks were still pending after %s, assuming they never made it to celery "
"and sending back to the scheduler:\n\t%s",
self.task_adoption_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
)
self._send_stalled_tis_back_to_scheduler(timedout_keys)

if self.adopted_task_timeouts:
self._check_for_stalled_adopted_tasks()
def _check_for_stalled_tasks(self) -> None:
timedout_keys = self._get_timedout_ti_keys(self.stalled_task_timeouts)
if timedout_keys:
self.log.error(
"Tasks were still pending after %s, assuming they never made it to celery "
"and sending back to the scheduler:\n\t%s",
self.stalled_task_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
)
self._send_stalled_tis_back_to_scheduler(timedout_keys)

def _check_for_stalled_adopted_tasks(self):
def _get_timedout_ti_keys(
self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime]
) -> List[TaskInstanceKey]:
"""
See if any of the tasks we adopted from another Executor run have not
progressed after the configured timeout.
If they haven't, they likely never made it to Celery, and we should
just resend them. We do that by clearing the state and letting the
normal scheduler loop deal with that
These timeouts exist to check to see if any of our tasks have not progressed
in the expected time. This can happen for few different reasons, usually related
to race conditions while shutting down schedulers and celery workers.
It is, of course, always possible that these tasks are not actually
stalled - they could just be waiting in a long celery queue.
Unfortunately there's no way for us to know for sure, so we'll just
reschedule them and let the normal scheduler loop requeue them.
"""
now = utcnow()

sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1])

timedout_keys = []
for key, stalled_after in sorted_adopted_task_timeouts:
for key, stalled_after in task_timeouts.items():
if stalled_after > now:
# Since items are stored sorted, if we get to a stalled_after
# in the future then we can stop
Expand All @@ -343,20 +378,46 @@ def _check_for_stalled_adopted_tasks(self):
# already finished, then it will be removed from this list -- so
# the only time it's still in this list is when it a) never made it
# to celery in the first place (i.e. race condition somewhere in
# the dying executor) or b) a really long celery queue and it just
# the dying executor), b) celery lost the task before execution
# started, or c) a really long celery queue and it just
# hasn't started yet -- better cancel it and let the scheduler
# re-queue rather than have this task risk stalling for ever
timedout_keys.append(key)
return timedout_keys

if timedout_keys:
self.log.error(
"Adopted tasks were still pending after %s, assuming they never made it to celery and "
"clearing:\n\t%s",
self.task_adoption_timeout,
"\n\t".join(repr(x) for x in timedout_keys),
@provide_session
def _send_stalled_tis_back_to_scheduler(
self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION
) -> None:
try:
session.query(TaskInstance).filter(
TaskInstance.filter_for_tis(keys),
TaskInstance.state == State.QUEUED,
TaskInstance.queued_by_job_id == self.job_id,
).update(
{
TaskInstance.state: State.SCHEDULED,
TaskInstance.queued_dttm: None,
TaskInstance.queued_by_job_id: None,
TaskInstance.external_executor_id: None,
},
synchronize_session=False,
)
for key in timedout_keys:
self.change_state(key, State.FAILED)
session.commit()
except Exception:
self.log.exception("Error sending tasks back to scheduler")
session.rollback()
return

for key in keys:
self._set_celery_pending_task_timeout(key, None)
self.running.discard(key)
celery_async_result = self.tasks.pop(key, None)
if celery_async_result:
try:
app.control.revoke(celery_async_result.task_id)
except Exception as ex:
self.log.error("Error revoking task instance %s from celery: %s", key, ex)

def debug_dump(self) -> None:
"""Called in response to SIGUSR2 by the scheduler"""
Expand All @@ -369,6 +430,11 @@ def debug_dump(self) -> None:
len(self.adopted_task_timeouts),
"\n\t".join(map(repr, self.adopted_task_timeouts.items())),
)
self.log.info(
"executor.stalled_task_timeouts (%d)\n\t%s",
len(self.stalled_task_timeouts),
"\n\t".join(map(repr, self.stalled_task_timeouts.items())),
)

def update_all_task_states(self) -> None:
"""Updates states of the tasks."""
Expand All @@ -384,7 +450,7 @@ def update_all_task_states(self) -> None:
def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)
self.adopted_task_timeouts.pop(key, None)
self._set_celery_pending_task_timeout(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
"""Updates state of a single task."""
Expand All @@ -394,8 +460,8 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
elif state == celery_states.STARTED:
# It's now actually running, so know it made it to celery okay!
self.adopted_task_timeouts.pop(key, None)
# It's now actually running, so we know it made it to celery okay!
self._set_celery_pending_task_timeout(key, None)
elif state == celery_states.PENDING:
pass
else:
Expand Down Expand Up @@ -455,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance

# Set the correct elements of the state dicts, then update this
# like we just queried it.
self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
self._set_celery_pending_task_timeout(ti.key, _CeleryPendingTaskTimeoutType.ADOPTED)
self.tasks[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
Expand All @@ -469,6 +535,21 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance

return not_adopted_tis

def _set_celery_pending_task_timeout(
self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType]
) -> None:
"""
We use the fact that dicts maintain insertion order, and the the timeout for a
task is always "now + delta" to maintain the property that oldest item = first to
time out.
"""
self.adopted_task_timeouts.pop(key, None)
self.stalled_task_timeouts.pop(key, None)
if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout:
self.adopted_task_timeouts[key] = utcnow() + self.task_adoption_timeout
elif timeout_type == _CeleryPendingTaskTimeoutType.STALLED and self.stalled_task_timeout:
self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout


def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
"""
Expand Down
Loading

0 comments on commit 7058eb3

Please sign in to comment.