From 392abf9f1b7de77fb17c23f8d67fb42a018220bd Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Wed, 20 Nov 2024 12:25:35 +0000
Subject: [PATCH 1/5] Rewrite LocalExecutor to be simpler, and to shut down cleanly on Python 3.10+

Something changed between Python 3.7 and 3.10 such that a limited-parallelism LocalExecutor scheduler no longer shuts down cleanly on receiving a signal. On closer inspection of the limited vs. unlimited paths it appears to me that the code was "over-generalized", so the entire concept of `self.impl` has been removed, hopefully making this code much more direct and easier to understand.

The key things are now:

- When a task needs to be run, we send the message on an mp.SimpleQueue object and increment an internal counter. (We use our own counter rather than the qsize method, as qsize is not portable.)

- Inside _check_workers we check whether we think there are any outstanding messages, and create a worker if there are. The reason we do this is that on macOS (where the default mp start method is "spawn") a process will be started via `execute_async`, but it will take a second or two to pull the message off the queue, by which time the scheduler will have called `executor.sync()` again, meaning we'd over-create workers (though never above the limit). Avoiding that case is why we keep the internal `_outstanding_messages` counter -- `self.activity_queue.empty()` would return False while the worker is still booting up.

- Every time the scheduler calls the `sync()` method we read out of the result queue and decrement the internal counter.

- We remove the use of `multiprocessing.Manager` entirely -- it doesn't seem to do anything other than create queue objects, and for our use it just adds complexity.

- Almost as a side-effect, we now only create worker subprocesses on demand instead of pre-launching them. We do not currently shut down idle processes, though adding that would be quite straightforward if we wanted to in the future.

The branch name was "rewrite-local-exexc-concurrentfutures" (sic), as when it was originally opened in 2022 for 3.10 that was the plan. However, 3.12 has since come out and it now issues warnings when fork and threads are mixed, and concurrent.futures uses a thread internally, so a different approach was used.
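To make the queue-plus-counter flow described above concrete, the following is a minimal, self-contained sketch of the same pattern -- illustrative only, not part of the patch: work items go onto a multiprocessing.SimpleQueue, an internal counter tracks outstanding messages (since qsize() is not portable), workers are spawned on demand up to a parallelism cap, and a None "poison pill" tells each worker to exit. Names such as _worker and outstanding are invented for the sketch; the task handling is a stand-in for actually running a command.

    import multiprocessing
    from multiprocessing import Process, SimpleQueue


    def _worker(inputs: SimpleQueue, results: SimpleQueue) -> None:
        # Block on the input queue; a None item is the poison pill telling us to exit.
        while True:
            item = inputs.get()
            if item is None:
                return
            key, command = item
            # Stand-in for actually running the task command.
            results.put((key, f"ran {command!r}"))


    def main() -> None:
        inputs: SimpleQueue = SimpleQueue()
        results: SimpleQueue = SimpleQueue()
        outstanding = 0          # our own counter; SimpleQueue has no portable qsize()
        workers: list[Process] = []
        parallelism = 2

        # "execute_async": enqueue work, bump the counter, and start a worker on
        # demand if we have fewer workers than outstanding messages (capped).
        for key, command in [("t1", ["echo", "a"]), ("t2", ["echo", "b"]), ("t3", ["echo", "c"])]:
            inputs.put((key, command))
            outstanding += 1
            if len(workers) < outstanding and len(workers) < parallelism:
                p = Process(target=_worker, args=(inputs, results))
                p.start()
                workers.append(p)

        # "sync": drain results as they arrive and decrement the counter.
        while outstanding:
            key, outcome = results.get()
            outstanding -= 1
            print(key, outcome)

        # "end": one poison pill per worker, then join them all.
        for _ in workers:
            inputs.put(None)
        for p in workers:
            p.join()


    if __name__ == "__main__":
        main()

Because the worker function lives at module level and the entry point is guarded by the __main__ check, the sketch also runs under the "spawn" start method that macOS defaults to.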
--- airflow/executors/local_executor.py | 481 +++++++++---------------- tests/executors/test_local_executor.py | 112 +++--- 2 files changed, 228 insertions(+), 365 deletions(-) diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index cb19b57a81501..9d2d8ae197f22 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -25,192 +25,139 @@ from __future__ import annotations -import contextlib import logging +import multiprocessing import os import subprocess -from abc import abstractmethod -from multiprocessing import Manager, Process -from queue import Empty +from multiprocessing import Queue, SimpleQueue from typing import TYPE_CHECKING, Any, Optional, Tuple -from setproctitle import getproctitle, setproctitle - from airflow import settings -from airflow.exceptions import AirflowException from airflow.executors.base_executor import PARALLELISM, BaseExecutor -from airflow.traces.tracer import Trace, add_span +from airflow.traces.tracer import add_span from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from multiprocessing.managers import SyncManager - from queue import Queue - from airflow.executors.base_executor import CommandType - from airflow.models.taskinstance import TaskInstanceStateType from airflow.models.taskinstancekey import TaskInstanceKey # This is a work to be executed by a worker. # It can Key and Command - but it can also be None, None which is actually a # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly. - ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]] - + ExecutorWorkType = Optional[Tuple[TaskInstanceKey, CommandType]] + TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]] -class LocalWorkerBase(Process, LoggingMixin): - """ - LocalWorkerBase implementation to run airflow commands. - - Executes the given command and puts the result into a result queue when done, terminating execution. - :param result_queue: the queue to store result state - """ +def _run_worker(logger_name: str, input: SimpleQueue[ExecutorWorkType], output: Queue[TaskInstanceStateType]): + import signal - def __init__(self, result_queue: Queue[TaskInstanceStateType]): - super().__init__(target=self.do_work) - self.daemon: bool = True - self.result_queue: Queue[TaskInstanceStateType] = result_queue + from setproctitle import setproctitle - def run(self): - # We know we've just started a new process, so lets disconnect from the metadata db now - settings.engine.pool.dispose() - settings.engine.dispose() - setproctitle("airflow worker -- LocalExecutor") - return super().run() + # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion + signal.signal(signal.SIGINT, signal.SIG_IGN) - @add_span - def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None: - """ - Execute command received and stores result state in queue. 
- - :param key: the key to identify the task instance - :param command: the command to execute - """ - if key is None: - return + log = logging.getLogger(logger_name) - self.log.info("%s running %s", self.__class__.__name__, command) - setproctitle(f"airflow worker -- LocalExecutor: {command}") - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - state = self._execute_work_in_subprocess(command) - else: - state = self._execute_work_in_fork(command) + # We know we've just started a new process, so lets disconnect from the metadata db now + settings.engine.pool.dispose() + settings.engine.dispose() - self.result_queue.put((key, state)) - # Remove the command since the worker is done executing the task - setproctitle("airflow worker -- LocalExecutor") + setproctitle("airflow worker -- LocalExecutor: ") - @add_span - def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState: + while True: try: - subprocess.check_call(command, close_fds=True) - return TaskInstanceState.SUCCESS - except subprocess.CalledProcessError as e: - self.log.error("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - - @add_span - def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState: - pid = os.fork() - if pid: - # In parent, wait for the child - pid, ret = os.waitpid(pid, 0) - return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED - - from airflow.sentry import Sentry + item = input.get() + except EOFError: + log.info( + "Failed to read tasks from the task queue because the other " + "end has closed the connection. Terminating worker %s.", + multiprocessing.current_process().name, + ) + break + + if item is None: + # Received poison pill, no more tasks to run + return - ret = 1 + (key, command) = item try: - import signal + state = _execute_work(log, key, command) - from airflow.cli.cli_parser import get_parser + output.put((key, state, None)) + except Exception as e: + output.put((key, TaskInstanceState.FAILED, e)) - signal.signal(signal.SIGINT, signal.SIG_DFL) - signal.signal(signal.SIGTERM, signal.SIG_DFL) - signal.signal(signal.SIGUSR2, signal.SIG_DFL) - parser = get_parser() - # [1:] - remove "airflow" from the start of the command - args = parser.parse_args(command[1:]) - args.shut_down_logging = False +def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState: + """ + Execute command received and stores result state in queue. 
- setproctitle(f"airflow task supervisor: {command}") + :param key: the key to identify the task instance + :param command: the command to execute + """ + from setproctitle import setproctitle - args.func(args) - ret = 0 - return TaskInstanceState.SUCCESS - except Exception as e: - self.log.exception("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - finally: - Sentry.flush() - logging.shutdown() - os._exit(ret) + setproctitle(f"airflow worker -- LocalExecutor: {command}") + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) + try: + with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): + if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: + return _execute_work_in_subprocess(log, command) + else: + return _execute_work_in_fork(log, command) + finally: + # Remove the command since the worker is done executing the task + setproctitle("airflow worker -- LocalExecutor: ") - @abstractmethod - def do_work(self): - """Execute tasks; called in the subprocess.""" - raise NotImplementedError() +def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState: + try: + subprocess.check_call(command, close_fds=True) + return TaskInstanceState.SUCCESS + except subprocess.CalledProcessError as e: + log.error("Failed to execute task %s.", e) + return TaskInstanceState.FAILED -class LocalWorker(LocalWorkerBase): - """ - Local worker that executes the task. - :param result_queue: queue where results of the tasks are put. - :param key: key identifying task instance - :param command: Command to execute - """ +def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInstanceState: + pid = os.fork() + if pid: + # In parent, wait for the child + pid, ret = os.waitpid(pid, 0) + return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED - def __init__( - self, result_queue: Queue[TaskInstanceStateType], key: TaskInstanceKey, command: CommandType - ): - super().__init__(result_queue) - self.key: TaskInstanceKey = key - self.command: CommandType = command + from airflow.sentry import Sentry - @add_span - def do_work(self) -> None: - self.execute_work(key=self.key, command=self.command) + ret = 1 + try: + import signal + from setproctitle import setproctitle -class QueuedLocalWorker(LocalWorkerBase): - """ - LocalWorker implementation that is waiting for tasks from a queue. + from airflow.cli.cli_parser import get_parser - Will continue executing commands as they become available in the queue. - It will terminate execution once the poison token is found. + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGUSR2, signal.SIG_DFL) - :param task_queue: queue from which worker reads tasks - :param result_queue: queue where worker puts results after finishing tasks - """ + parser = get_parser() + # [1:] - remove "airflow" from the start of the command + args = parser.parse_args(command[1:]) + args.shut_down_logging = False - def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[TaskInstanceStateType]): - super().__init__(result_queue=result_queue) - self.task_queue = task_queue + setproctitle(f"airflow task supervisor: {command}") - @add_span - def do_work(self) -> None: - while True: - try: - key, command = self.task_queue.get() - except EOFError: - self.log.info( - "Failed to read tasks from the task queue because the other " - "end has closed the connection. 
Terminating worker %s.", - self.name, - ) - break - try: - if key is None or command is None: - # Received poison pill, no more tasks to run - break - self.execute_work(key=key, command=command) - finally: - self.task_queue.task_done() + args.func(args) + ret = 0 + return TaskInstanceState.SUCCESS + except Exception as e: + log.exception("Failed to execute task %s.", e) + return TaskInstanceState.FAILED + finally: + Sentry.flush() + logging.shutdown() + os._exit(ret) class LocalExecutor(BaseExecutor): @@ -228,171 +175,16 @@ class LocalExecutor(BaseExecutor): def __init__(self, parallelism: int = PARALLELISM): super().__init__(parallelism=parallelism) + self._outstanding_messages: int = 0 if self.parallelism < 0: - raise AirflowException("parallelism must be bigger than or equal to 0") - self.manager: SyncManager | None = None - self.result_queue: Queue[TaskInstanceStateType] | None = None - self.workers: list[QueuedLocalWorker] = [] - self.workers_used: int = 0 - self.workers_active: int = 0 - self.impl: None | (LocalExecutor.UnlimitedParallelism | LocalExecutor.LimitedParallelism) = None - - class UnlimitedParallelism: - """ - Implement LocalExecutor with unlimited parallelism, starting one process per command executed. - - :param executor: the executor instance to implement. - """ - - def __init__(self, executor: LocalExecutor): - self.executor: LocalExecutor = executor - - def start(self) -> None: - """Start the executor.""" - self.executor.workers_used = 0 - self.executor.workers_active = 0 - - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """ - Execute task asynchronously. - - :param key: the key to identify the task instance - :param command: the command to execute - :param queue: Name of the queue - :param executor_config: configuration for the executor - """ - if TYPE_CHECKING: - assert self.executor.result_queue - - span = Trace.get_current_span() - if span.is_recording(): - span.set_attributes( - { - "dag_id": key.dag_id, - "run_id": key.run_id, - "task_id": key.task_id, - "try_number": key.try_number, - "commands_to_run": str(command), - } - ) - - local_worker = LocalWorker(self.executor.result_queue, key=key, command=command) - self.executor.workers_used += 1 - self.executor.workers_active += 1 - local_worker.start() - - def sync(self) -> None: - """Sync will get called periodically by the heartbeat method.""" - if not self.executor.result_queue: - raise AirflowException("Executor should be started first") - while not self.executor.result_queue.empty(): - results = self.executor.result_queue.get() - self.executor.change_state(*results) - self.executor.workers_active -= 1 - - def end(self) -> None: - """Wait synchronously for the previously submitted job to complete.""" - while self.executor.workers_active > 0: - self.executor.sync() - - class LimitedParallelism: - """ - Implements LocalExecutor with limited parallelism. - - Uses a task queue to coordinate work distribution. - - :param executor: the executor instance to implement. 
- """ - - def __init__(self, executor: LocalExecutor): - self.executor: LocalExecutor = executor - self.queue: Queue[ExecutorWorkType] | None = None - - def start(self) -> None: - """Start limited parallelism implementation.""" - if TYPE_CHECKING: - assert self.executor.manager - assert self.executor.result_queue - - self.queue = self.executor.manager.Queue() - self.executor.workers = [ - QueuedLocalWorker(self.queue, self.executor.result_queue) - for _ in range(self.executor.parallelism) - ] - - self.executor.workers_used = len(self.executor.workers) - - for worker in self.executor.workers: - worker.start() - - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """ - Execute task asynchronously. - - :param key: the key to identify the task instance - :param command: the command to execute - :param queue: name of the queue - :param executor_config: configuration for the executor - """ - if TYPE_CHECKING: - assert self.queue - - self.queue.put((key, command)) - - def sync(self): - """Sync will get called periodically by the heartbeat method.""" - with contextlib.suppress(Empty): - while True: - results = self.executor.result_queue.get_nowait() - try: - self.executor.change_state(*results) - finally: - self.executor.result_queue.task_done() - - def end(self): - """ - End the executor. - - Sends the poison pill to all workers. - """ - for _ in self.executor.workers: - self.queue.put((None, None)) - - # Wait for commands to finish - self.queue.join() - self.executor.sync() + raise ValueError("parallelism must be greater than or equal to 0") + self.activity_queue: SimpleQueue[ExecutorWorkType] = SimpleQueue() + self.result_queue: SimpleQueue[TaskInstanceStateType] = SimpleQueue() + self.workers: dict[int, multiprocessing.Process] = {} def start(self) -> None: """Start the executor.""" - old_proctitle = getproctitle() - setproctitle("airflow executor -- LocalExecutor") - self.manager = Manager() - setproctitle(old_proctitle) - self.result_queue = self.manager.Queue() - self.workers = [] - self.workers_used = 0 - self.workers_active = 0 - self.impl = ( - LocalExecutor.UnlimitedParallelism(self) - if self.parallelism == 0 - else LocalExecutor.LimitedParallelism(self) - ) - - self.impl.start() + pass @add_span def execute_async( @@ -403,32 +195,91 @@ def execute_async( executor_config: Any | None = None, ) -> None: """Execute asynchronously.""" - if TYPE_CHECKING: - assert self.impl - self.validate_airflow_tasks_run_command(command) + self.activity_queue.put((key, command)) + self._outstanding_messages += 1 + self._check_workers(can_start=True) + + def _check_workers(self, can_start: bool = True): + # Reap any dead workers + to_remove = set() + for pid, proc in self.workers.items(): + if not proc.is_alive(): + to_remove.add(pid) + proc.close() + + if to_remove: + self.workers = {pid: proc for pid, proc in self.workers.items() if pid not in to_remove} + + # If we're using spawn in multiprocessing (default on macos now) to start tasks, this can get called a + # via sync() a few times before the spawned process actually starts picking up messages. Try not to + # create too much + + if self._outstanding_messages <= 0 or self.activity_queue.empty(): + # Nothing to do, should we shut down idle workers? 
+ return - self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) + need_more_workers = len(self.workers) < self._outstanding_messages + if need_more_workers and (self.parallelism == 0 or len(self.workers) < self.parallelism): + self._spawn_worker() + + def _spawn_worker(self): + p = multiprocessing.Process( + target=_run_worker, + kwargs={ + "logger_name": self.log.name, + "input": self.activity_queue, + "output": self.result_queue, + }, + ) + p.start() + if TYPE_CHECKING: + assert p.pid # Since we've called start + self.workers[p.pid] = p def sync(self) -> None: """Sync will get called periodically by the heartbeat method.""" - if TYPE_CHECKING: - assert self.impl + self._read_results() + self._check_workers() - self.impl.sync() + def _read_results(self): + while not self.result_queue.empty(): + key, state, exc = self.result_queue.get() + self._outstanding_messages = self._outstanding_messages - 1 + + if exc: + # TODO: This needs a better stacktrace, it appears from here + if hasattr(exc, "add_note"): + exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)") + raise exc + + self.change_state(key, state) def end(self) -> None: """End the executor.""" - if TYPE_CHECKING: - assert self.impl - assert self.manager - self.log.info( "Shutting down LocalExecutor" "; waiting for running tasks to finish. Signal again if you don't want to wait." ) - self.impl.end() - self.manager.shutdown() + + # We can't tell which proc will pick which close message up, so we send all the messages, and then + # wait on all the procs + + for proc in self.workers.values(): + # Send the shutdown message once for each alive worker + if proc.is_alive(): + self.activity_queue.put(None) + + for proc in self.workers.values(): + if proc.is_alive(): + proc.join() + proc.close() + + # Process any extra results before closing + self._read_results() + + self.activity_queue.close() + self.result_queue.close() def terminate(self): """Terminate the executor is not doing anything.""" diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index f6ba8dca464d3..3f1ebabe148c8 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -18,10 +18,13 @@ from __future__ import annotations import datetime +import multiprocessing +import os import subprocess from unittest import mock import pytest +from kgb import spy_on from airflow import settings from airflow.exceptions import AirflowException @@ -30,6 +33,12 @@ pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +# Runtime is fine, we just can't run the tests on macOS +skip_spawn_mp_start = pytest.mark.skipif( + multiprocessing.get_context().get_start_method() == "spawn", + reason="mock patching in test don't work with 'spawn' mode (default on macOS)", +) + class TestLocalExecutor: TEST_SUCCESS_COMMANDS = 5 @@ -44,85 +53,71 @@ def test_serve_logs_default_value(self): assert LocalExecutor.serve_logs @mock.patch("airflow.executors.local_executor.subprocess.check_call") - def execution_parallelism_subprocess(self, mock_check_call, parallelism=0): - success_command = ["airflow", "tasks", "run", "true", "some_parameter", "2020-10-07"] - fail_command = ["airflow", "tasks", "run", "false", "task_id", "2020-10-07"] + @mock.patch("airflow.cli.commands.task_command.task_run") + def _test_execute(self, mock_run, mock_check_call, parallelism=1): + success_command = ["airflow", "tasks", "run", "success", "some_parameter", 
"2020-10-07"] + fail_command = ["airflow", "tasks", "run", "failure", "task_id", "2020-10-07"] + # We just mock both styles here, only one will be hit though def fake_execute_command(command, close_fds=True): if command != success_command: raise subprocess.CalledProcessError(returncode=1, cmd=command) else: return 0 - mock_check_call.side_effect = fake_execute_command - - self._test_execute(parallelism, success_command, fail_command) - - @mock.patch("airflow.cli.commands.task_command.task_run") - def execution_parallelism_fork(self, mock_run, parallelism=0): - success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"] - fail_command = ["airflow", "tasks", "run", "failure", "some_parameter", "2020-10-07"] - def fake_task_run(args): + print(repr(args)) if args.dag_id != "success": raise AirflowException("Simulate failed task") + mock_check_call.side_effect = fake_execute_command mock_run.side_effect = fake_task_run - self._test_execute(parallelism, success_command, fail_command) - - def _test_execute(self, parallelism, success_command, fail_command): executor = LocalExecutor(parallelism=parallelism) executor.start() success_key = "success {}" assert executor.result_queue.empty() - logical_date = datetime.datetime.now() - for i in range(self.TEST_SUCCESS_COMMANDS): - key_id, command = success_key.format(i), success_command - key = key_id, "fake_ti", logical_date, 0 - executor.running.add(key) - executor.execute_async(key=key, command=command) + with spy_on(executor._spawn_worker) as spy: + run_id = "manual_" + datetime.datetime.now().isoformat() + for i in range(self.TEST_SUCCESS_COMMANDS): + key_id, command = success_key.format(i), success_command + key = key_id, "fake_ti", run_id, 0 + executor.running.add(key) + executor.execute_async(key=key, command=command) + + fail_key = "fail", "fake_ti", run_id, 0 + executor.running.add(fail_key) + executor.execute_async(key=fail_key, command=fail_command) - fail_key = "fail", "fake_ti", logical_date, 0 - executor.running.add(fail_key) - executor.execute_async(key=fail_key, command=fail_command) + executor.end() - executor.end() + expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism + assert len(spy.calls) == expected # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 + assert executor._outstanding_messages == 0 for i in range(self.TEST_SUCCESS_COMMANDS): key_id = success_key.format(i) - key = key_id, "fake_ti", logical_date, 0 + key = key_id, "fake_ti", run_id, 0 assert executor.event_buffer[key][0] == State.SUCCESS assert executor.event_buffer[fail_key][0] == State.FAILED - expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism - assert executor.workers_used == expected - - def test_execution_subprocess_unlimited_parallelism(self): - with mock.patch.object( - settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock - ) as option: - option.return_value = True - self.execution_parallelism_subprocess(parallelism=0) - - def test_execution_subprocess_limited_parallelism(self): - with mock.patch.object( - settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock - ) as option: - option.return_value = True - self.execution_parallelism_subprocess(parallelism=2) - - @mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False) - def test_execution_unlimited_parallelism_fork(self): - self.execution_parallelism_fork(parallelism=0) - - 
@mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False) - def test_execution_limited_parallelism_fork(self): - self.execution_parallelism_fork(parallelism=2) + @skip_spawn_mp_start + @pytest.mark.parametrize( + ("parallelism", "fork_or_subproc"), + [ + pytest.param(0, True, id="unlimited_subprocess"), + pytest.param(2, True, id="limited_subprocess"), + pytest.param(0, False, id="unlimited_fork"), + pytest.param(2, False, id="limited_fork"), + ], + ) + def test_execution(self, parallelism: int, fork_or_subproc: bool, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", fork_or_subproc) + self._test_execute(parallelism=parallelism) @mock.patch("airflow.executors.local_executor.LocalExecutor.sync") @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") @@ -142,3 +137,20 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock ), ] mock_stats_gauge.assert_has_calls(calls) + + @pytest.mark.execution_timeout(5) + def test_clean_stop_on_signal(self): + import signal + + executor = LocalExecutor(parallelism=2) + executor.start() + + # We want to ensure we start a worker process, as we now only create them on demand + executor._spawn_worker() + + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + pass + finally: + executor.end() From 5cbebdddf6484244043229313ec3fbba50ca22ba Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Nov 2024 14:09:51 +0000 Subject: [PATCH 2/5] fixup! Rewite LocalExecutor to be simpler, and to shutdown cleanly on Python 3.10+ --- airflow/executors/local_executor.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 9d2d8ae197f22..e1157dde195a6 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -173,18 +173,25 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True + activity_queue: SimpleQueue[ExecutorWorkType] + result_queue: SimpleQueue[TaskInstanceStateType] + workers: dict[int, multiprocessing.Process] + _outstanding_messages: int = 0 + def __init__(self, parallelism: int = PARALLELISM): super().__init__(parallelism=parallelism) - self._outstanding_messages: int = 0 if self.parallelism < 0: raise ValueError("parallelism must be greater than or equal to 0") - self.activity_queue: SimpleQueue[ExecutorWorkType] = SimpleQueue() - self.result_queue: SimpleQueue[TaskInstanceStateType] = SimpleQueue() - self.workers: dict[int, multiprocessing.Process] = {} def start(self) -> None: """Start the executor.""" - pass + # We delay opening these queues until the start method mostly for unit tests. ExecutorLoader caches + # instances, so each test reusues the same instance! (i.e. test 1 runs, closes the queues, then test 2 + # comes back and gets the same LocalExecutor instance, so we have to open new here.) + self.activity_queue = SimpleQueue() + self.result_queue = SimpleQueue() + self.workers = {} + self._outstanding_messages = 0 @add_span def execute_async( From eb4bc4adc8736c04d0f265dfb9128522ed1f5c8c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Nov 2024 14:49:51 +0000 Subject: [PATCH 3/5] fixup! 
Rewite LocalExecutor to be simpler, and to shutdown cleanly on Python 3.10+ --- tests/executors/test_local_executor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 3f1ebabe148c8..c81718a810157 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -79,7 +79,7 @@ def fake_task_run(args): success_key = "success {}" assert executor.result_queue.empty() - with spy_on(executor._spawn_worker) as spy: + with spy_on(executor._spawn_worker) as spawn_worker: run_id = "manual_" + datetime.datetime.now().isoformat() for i in range(self.TEST_SUCCESS_COMMANDS): key_id, command = success_key.format(i), success_command @@ -94,7 +94,8 @@ def fake_task_run(args): executor.end() expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism - assert len(spy.calls) == expected + # Depending on how quickly the tasks run, we might not need to create all the workers we could + assert 1 <= len(spawn_worker.calls) <= expected # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 assert executor._outstanding_messages == 0 From 2f7d6c53d26c91ec8234e54c65a382ee86c44902 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Nov 2024 16:08:36 +0000 Subject: [PATCH 4/5] fixup! Rewite LocalExecutor to be simpler, and to shutdown cleanly on Python 3.10+ --- airflow/executors/local_executor.py | 53 +++++++++++++++++--------- tests/executors/test_local_executor.py | 4 +- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index e1157dde195a6..199f2fe2459de 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -25,13 +25,17 @@ from __future__ import annotations +import ctypes import logging import multiprocessing +import multiprocessing.sharedctypes import os import subprocess from multiprocessing import Queue, SimpleQueue from typing import TYPE_CHECKING, Any, Optional, Tuple +from setproctitle import setproctitle + from airflow import settings from airflow.executors.base_executor import PARALLELISM, BaseExecutor from airflow.traces.tracer import add_span @@ -49,11 +53,14 @@ TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]] -def _run_worker(logger_name: str, input: SimpleQueue[ExecutorWorkType], output: Queue[TaskInstanceStateType]): +def _run_worker( + logger_name: str, + input: SimpleQueue[ExecutorWorkType], + output: Queue[TaskInstanceStateType], + unread_messages: multiprocessing.sharedctypes.Synchronized[int], +): import signal - from setproctitle import setproctitle - # Ignore ctrl-c in this process -- we don't want to kill _this_ one. 
we let tasks run to completion signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -80,6 +87,10 @@ def _run_worker(logger_name: str, input: SimpleQueue[ExecutorWorkType], output: # Received poison pill, no more tasks to run return + # Decrement this as soon as we pick up a message off the queue + with unread_messages: + unread_messages.value -= 1 + (key, command) = item try: state = _execute_work(log, key, command) @@ -96,8 +107,6 @@ def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandTyp :param key: the key to identify the task instance :param command: the command to execute """ - from setproctitle import setproctitle - setproctitle(f"airflow worker -- LocalExecutor: {command}") dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) try: @@ -133,8 +142,6 @@ def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInst try: import signal - from setproctitle import setproctitle - from airflow.cli.cli_parser import get_parser signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -176,7 +183,7 @@ class LocalExecutor(BaseExecutor): activity_queue: SimpleQueue[ExecutorWorkType] result_queue: SimpleQueue[TaskInstanceStateType] workers: dict[int, multiprocessing.Process] - _outstanding_messages: int = 0 + _unread_messages: multiprocessing.sharedctypes.Synchronized[int] def __init__(self, parallelism: int = PARALLELISM): super().__init__(parallelism=parallelism) @@ -191,7 +198,10 @@ def start(self) -> None: self.activity_queue = SimpleQueue() self.result_queue = SimpleQueue() self.workers = {} - self._outstanding_messages = 0 + + # Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour + # (it looks like an int to python) + self._unread_messages = multiprocessing.Value(ctypes.c_uint) # type: ignore[assignment] @add_span def execute_async( @@ -204,7 +214,8 @@ def execute_async( """Execute asynchronously.""" self.validate_airflow_tasks_run_command(command) self.activity_queue.put((key, command)) - self._outstanding_messages += 1 + with self._unread_messages: + self._unread_messages.value += 1 self._check_workers(can_start=True) def _check_workers(self, can_start: bool = True): @@ -214,20 +225,28 @@ def _check_workers(self, can_start: bool = True): if not proc.is_alive(): to_remove.add(pid) proc.close() + if proc.exitcode is not None and proc.exitcode > 0: + # The process died! + ... if to_remove: self.workers = {pid: proc for pid, proc in self.workers.items() if pid not in to_remove} - # If we're using spawn in multiprocessing (default on macos now) to start tasks, this can get called a - # via sync() a few times before the spawned process actually starts picking up messages. Try not to - # create too much + with self._unread_messages: + num_outstanding = self._unread_messages.value - if self._outstanding_messages <= 0 or self.activity_queue.empty(): - # Nothing to do, should we shut down idle workers? + if num_outstanding <= 0 or self.activity_queue.empty(): + # Nothing to do. Future enhancement if someone wants: shut down workers that have been idle for N + # seconds return - need_more_workers = len(self.workers) < self._outstanding_messages + # If we're using spawn in multiprocessing (default on macOS now) to start tasks, this can get called a + # via `sync()` a few times before the spawned process actually starts picking up messages. 
Try not to + # create too much + need_more_workers = len(self.workers) < num_outstanding if need_more_workers and (self.parallelism == 0 or len(self.workers) < self.parallelism): + # This only creates one worker, which is fine as we call this directly after putting a message on + # activity_queue in execute_async self._spawn_worker() def _spawn_worker(self): @@ -237,6 +256,7 @@ def _spawn_worker(self): "logger_name": self.log.name, "input": self.activity_queue, "output": self.result_queue, + "unread_messages": self._unread_messages, }, ) p.start() @@ -252,7 +272,6 @@ def sync(self) -> None: def _read_results(self): while not self.result_queue.empty(): key, state, exc = self.result_queue.get() - self._outstanding_messages = self._outstanding_messages - 1 if exc: # TODO: This needs a better stacktrace, it appears from here diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index c81718a810157..2545ceb7e705d 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -66,7 +66,6 @@ def fake_execute_command(command, close_fds=True): return 0 def fake_task_run(args): - print(repr(args)) if args.dag_id != "success": raise AirflowException("Simulate failed task") @@ -96,9 +95,10 @@ def fake_task_run(args): expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism # Depending on how quickly the tasks run, we might not need to create all the workers we could assert 1 <= len(spawn_worker.calls) <= expected + # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 - assert executor._outstanding_messages == 0 + assert executor._unread_messages.value == 0 for i in range(self.TEST_SUCCESS_COMMANDS): key_id = success_key.format(i) From da01d423d82008a0153c2b84af543715f7b919d4 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Nov 2024 16:11:09 +0000 Subject: [PATCH 5/5] fixup! Rewite LocalExecutor to be simpler, and to shutdown cleanly on Python 3.10+ --- airflow/executors/local_executor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 199f2fe2459de..3b8b52176db5b 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -225,9 +225,6 @@ def _check_workers(self, can_start: bool = True): if not proc.is_alive(): to_remove.add(pid) proc.close() - if proc.exitcode is not None and proc.exitcode > 0: - # The process died! - ... if to_remove: self.workers = {pid: proc for pid, proc in self.workers.items() if pid not in to_remove}
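One of the fixups above (patch 4/5) moves the outstanding-message counter into shared memory (a multiprocessing.Value wrapping a ctypes unsigned int) so that a worker can decrement it the moment it pulls a message off the queue, instead of the scheduler decrementing it only when results come back. Below is a minimal sketch of that shared-counter pattern under the same assumptions -- not the patch's code, and the drain/unread names are made up; using the Synchronized object as a context manager acquires its internal lock around the read-modify-write, exactly as the with-blocks in the patch do.

    import ctypes
    import multiprocessing
    import multiprocessing.sharedctypes
    from multiprocessing import Process, SimpleQueue


    def drain(queue: SimpleQueue, unread: multiprocessing.sharedctypes.Synchronized) -> None:
        # Worker side: pull messages and decrement the shared counter as soon as
        # each one is picked up, holding the counter's lock while mutating it.
        while (item := queue.get()) is not None:
            with unread:
                unread.value -= 1
            print("picked up", item)


    def main() -> None:
        queue: SimpleQueue = SimpleQueue()
        # An unsigned int in shared memory, guarded by an internal lock.
        unread = multiprocessing.Value(ctypes.c_uint)

        for item in ("a", "b", "c"):
            queue.put(item)
            with unread:
                unread.value += 1

        worker = Process(target=drain, args=(queue, unread))
        worker.start()

        queue.put(None)  # poison pill so the worker exits when the queue is drained
        worker.join()

        with unread:
            assert unread.value == 0  # every message was picked up by the worker


    if __name__ == "__main__":
        main()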