diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index cb19b57a81501..9d2d8ae197f22 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -25,192 +25,139 @@
 from __future__ import annotations

-import contextlib
 import logging
+import multiprocessing
 import os
 import subprocess
-from abc import abstractmethod
-from multiprocessing import Manager, Process
-from queue import Empty
+from multiprocessing import Queue, SimpleQueue
 from typing import TYPE_CHECKING, Any, Optional, Tuple

-from setproctitle import getproctitle, setproctitle
-
 from airflow import settings
-from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import PARALLELISM, BaseExecutor
-from airflow.traces.tracer import Trace, add_span
+from airflow.traces.tracer import add_span
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
-from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import TaskInstanceState

 if TYPE_CHECKING:
-    from multiprocessing.managers import SyncManager
-    from queue import Queue
-
     from airflow.executors.base_executor import CommandType
-    from airflow.models.taskinstance import TaskInstanceStateType
     from airflow.models.taskinstancekey import TaskInstanceKey

-    # This is a work to be executed by a worker.
-    # It can Key and Command - but it can also be None, None which is actually a
-    # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
-    ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]]
-
+    # This is a work item to be executed by a worker.
+    # It can be Key and Command - but it can also be None, which is a
+    # "Poison Pill" - a worker seeing the Poison Pill should take the pill and ... die instantly.
+    ExecutorWorkType = Optional[Tuple[TaskInstanceKey, CommandType]]
+    TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]]


-class LocalWorkerBase(Process, LoggingMixin):
-    """
-    LocalWorkerBase implementation to run airflow commands.
-
-    Executes the given command and puts the result into a result queue when done, terminating execution.
-    :param result_queue: the queue to store result state
-    """
+def _run_worker(logger_name: str, input: SimpleQueue[ExecutorWorkType], output: Queue[TaskInstanceStateType]):
+    import signal

-    def __init__(self, result_queue: Queue[TaskInstanceStateType]):
-        super().__init__(target=self.do_work)
-        self.daemon: bool = True
-        self.result_queue: Queue[TaskInstanceStateType] = result_queue
+    from setproctitle import setproctitle

-    def run(self):
-        # We know we've just started a new process, so lets disconnect from the metadata db now
-        settings.engine.pool.dispose()
-        settings.engine.dispose()
-        setproctitle("airflow worker -- LocalExecutor")
-        return super().run()
+    # Ignore ctrl-c in this process -- we don't want to kill _this_ one. We let tasks run to completion.
+    signal.signal(signal.SIGINT, signal.SIG_IGN)

-    @add_span
-    def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None:
-        """
-        Execute command received and stores result state in queue.
-
-        :param key: the key to identify the task instance
-        :param command: the command to execute
-        """
-        if key is None:
-            return
+    log = logging.getLogger(logger_name)

-        self.log.info("%s running %s", self.__class__.__name__, command)
-        setproctitle(f"airflow worker -- LocalExecutor: {command}")
-        dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
-        with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
-            if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
-                state = self._execute_work_in_subprocess(command)
-            else:
-                state = self._execute_work_in_fork(command)
+    # We know we've just started a new process, so let's disconnect from the metadata db now
+    settings.engine.pool.dispose()
+    settings.engine.dispose()

-        self.result_queue.put((key, state))
-        # Remove the command since the worker is done executing the task
-        setproctitle("airflow worker -- LocalExecutor")
+    setproctitle("airflow worker -- LocalExecutor: ")

-    @add_span
-    def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState:
+    while True:
         try:
-            subprocess.check_call(command, close_fds=True)
-            return TaskInstanceState.SUCCESS
-        except subprocess.CalledProcessError as e:
-            self.log.error("Failed to execute task %s.", e)
-            return TaskInstanceState.FAILED
-
-    @add_span
-    def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState:
-        pid = os.fork()
-        if pid:
-            # In parent, wait for the child
-            pid, ret = os.waitpid(pid, 0)
-            return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED
-
-        from airflow.sentry import Sentry
+            item = input.get()
+        except EOFError:
+            log.info(
+                "Failed to read tasks from the task queue because the other "
+                "end has closed the connection. Terminating worker %s.",
+                multiprocessing.current_process().name,
+            )
+            break
+
+        if item is None:
+            # Received poison pill, no more tasks to run
+            return

-        ret = 1
+        (key, command) = item
         try:
-            import signal
+            state = _execute_work(log, key, command)

-            from airflow.cli.cli_parser import get_parser
+            output.put((key, state, None))
+        except Exception as e:
+            output.put((key, TaskInstanceState.FAILED, e))

-            signal.signal(signal.SIGINT, signal.SIG_DFL)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-            signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-            parser = get_parser()
-            # [1:] - remove "airflow" from the start of the command
-            args = parser.parse_args(command[1:])
-            args.shut_down_logging = False

+def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState:
+    """
+    Execute the received command and return the resulting task state.
-            setproctitle(f"airflow task supervisor: {command}")
+    :param key: the key to identify the task instance
+    :param command: the command to execute
+    """
+    from setproctitle import setproctitle

-            args.func(args)
-            ret = 0
-            return TaskInstanceState.SUCCESS
-        except Exception as e:
-            self.log.exception("Failed to execute task %s.", e)
-            return TaskInstanceState.FAILED
-        finally:
-            Sentry.flush()
-            logging.shutdown()
-            os._exit(ret)
+    setproctitle(f"airflow worker -- LocalExecutor: {command}")
+    dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
+    try:
+        with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
+            if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
+                return _execute_work_in_subprocess(log, command)
+            else:
+                return _execute_work_in_fork(log, command)
+    finally:
+        # Remove the command since the worker is done executing the task
+        setproctitle("airflow worker -- LocalExecutor: ")

-    @abstractmethod
-    def do_work(self):
-        """Execute tasks; called in the subprocess."""
-        raise NotImplementedError()

+def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState:
+    try:
+        subprocess.check_call(command, close_fds=True)
+        return TaskInstanceState.SUCCESS
+    except subprocess.CalledProcessError as e:
+        log.error("Failed to execute task %s.", e)
+        return TaskInstanceState.FAILED

-class LocalWorker(LocalWorkerBase):
-    """
-    Local worker that executes the task.
-
-    :param result_queue: queue where results of the tasks are put.
-    :param key: key identifying task instance
-    :param command: Command to execute
-    """
+def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInstanceState:
+    pid = os.fork()
+    if pid:
+        # In parent, wait for the child
+        pid, ret = os.waitpid(pid, 0)
+        return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED

-    def __init__(
-        self, result_queue: Queue[TaskInstanceStateType], key: TaskInstanceKey, command: CommandType
-    ):
-        super().__init__(result_queue)
-        self.key: TaskInstanceKey = key
-        self.command: CommandType = command
+    from airflow.sentry import Sentry

-    @add_span
-    def do_work(self) -> None:
-        self.execute_work(key=self.key, command=self.command)
+    ret = 1
+    try:
+        import signal

+        from setproctitle import setproctitle

-class QueuedLocalWorker(LocalWorkerBase):
-    """
-    LocalWorker implementation that is waiting for tasks from a queue.
+        from airflow.cli.cli_parser import get_parser

-    Will continue executing commands as they become available in the queue.
-    It will terminate execution once the poison token is found.
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+        signal.signal(signal.SIGTERM, signal.SIG_DFL)
+        signal.signal(signal.SIGUSR2, signal.SIG_DFL)

-    :param task_queue: queue from which worker reads tasks
-    :param result_queue: queue where worker puts results after finishing tasks
-    """
+        parser = get_parser()
+        # [1:] - remove "airflow" from the start of the command
+        args = parser.parse_args(command[1:])
+        args.shut_down_logging = False

-    def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[TaskInstanceStateType]):
-        super().__init__(result_queue=result_queue)
-        self.task_queue = task_queue
+        setproctitle(f"airflow task supervisor: {command}")

-    @add_span
-    def do_work(self) -> None:
-        while True:
-            try:
-                key, command = self.task_queue.get()
-            except EOFError:
-                self.log.info(
-                    "Failed to read tasks from the task queue because the other "
-                    "end has closed the connection. Terminating worker %s.",
-                    self.name,
-                )
-                break
-            try:
-                if key is None or command is None:
-                    # Received poison pill, no more tasks to run
-                    break
-                self.execute_work(key=key, command=command)
-            finally:
-                self.task_queue.task_done()
+        args.func(args)
+        ret = 0
+        return TaskInstanceState.SUCCESS
+    except Exception as e:
+        log.exception("Failed to execute task %s.", e)
+        return TaskInstanceState.FAILED
+    finally:
+        Sentry.flush()
+        logging.shutdown()
+        os._exit(ret)


 class LocalExecutor(BaseExecutor):
@@ -228,171 +175,16 @@ class LocalExecutor(BaseExecutor):
     def __init__(self, parallelism: int = PARALLELISM):
         super().__init__(parallelism=parallelism)
+        self._outstanding_messages: int = 0
         if self.parallelism < 0:
-            raise AirflowException("parallelism must be bigger than or equal to 0")
-        self.manager: SyncManager | None = None
-        self.result_queue: Queue[TaskInstanceStateType] | None = None
-        self.workers: list[QueuedLocalWorker] = []
-        self.workers_used: int = 0
-        self.workers_active: int = 0
-        self.impl: None | (LocalExecutor.UnlimitedParallelism | LocalExecutor.LimitedParallelism) = None
-
-    class UnlimitedParallelism:
-        """
-        Implement LocalExecutor with unlimited parallelism, starting one process per command executed.
-
-        :param executor: the executor instance to implement.
-        """
-
-        def __init__(self, executor: LocalExecutor):
-            self.executor: LocalExecutor = executor
-
-        def start(self) -> None:
-            """Start the executor."""
-            self.executor.workers_used = 0
-            self.executor.workers_active = 0
-
-        @add_span
-        def execute_async(
-            self,
-            key: TaskInstanceKey,
-            command: CommandType,
-            queue: str | None = None,
-            executor_config: Any | None = None,
-        ) -> None:
-            """
-            Execute task asynchronously.
-
-            :param key: the key to identify the task instance
-            :param command: the command to execute
-            :param queue: Name of the queue
-            :param executor_config: configuration for the executor
-            """
-            if TYPE_CHECKING:
-                assert self.executor.result_queue
-
-            span = Trace.get_current_span()
-            if span.is_recording():
-                span.set_attributes(
-                    {
-                        "dag_id": key.dag_id,
-                        "run_id": key.run_id,
-                        "task_id": key.task_id,
-                        "try_number": key.try_number,
-                        "commands_to_run": str(command),
-                    }
-                )
-
-            local_worker = LocalWorker(self.executor.result_queue, key=key, command=command)
-            self.executor.workers_used += 1
-            self.executor.workers_active += 1
-            local_worker.start()
-
-        def sync(self) -> None:
-            """Sync will get called periodically by the heartbeat method."""
-            if not self.executor.result_queue:
-                raise AirflowException("Executor should be started first")
-            while not self.executor.result_queue.empty():
-                results = self.executor.result_queue.get()
-                self.executor.change_state(*results)
-                self.executor.workers_active -= 1
-
-        def end(self) -> None:
-            """Wait synchronously for the previously submitted job to complete."""
-            while self.executor.workers_active > 0:
-                self.executor.sync()
-
-    class LimitedParallelism:
-        """
-        Implements LocalExecutor with limited parallelism.
-
-        Uses a task queue to coordinate work distribution.
-
-        :param executor: the executor instance to implement.
- """ - - def __init__(self, executor: LocalExecutor): - self.executor: LocalExecutor = executor - self.queue: Queue[ExecutorWorkType] | None = None - - def start(self) -> None: - """Start limited parallelism implementation.""" - if TYPE_CHECKING: - assert self.executor.manager - assert self.executor.result_queue - - self.queue = self.executor.manager.Queue() - self.executor.workers = [ - QueuedLocalWorker(self.queue, self.executor.result_queue) - for _ in range(self.executor.parallelism) - ] - - self.executor.workers_used = len(self.executor.workers) - - for worker in self.executor.workers: - worker.start() - - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """ - Execute task asynchronously. - - :param key: the key to identify the task instance - :param command: the command to execute - :param queue: name of the queue - :param executor_config: configuration for the executor - """ - if TYPE_CHECKING: - assert self.queue - - self.queue.put((key, command)) - - def sync(self): - """Sync will get called periodically by the heartbeat method.""" - with contextlib.suppress(Empty): - while True: - results = self.executor.result_queue.get_nowait() - try: - self.executor.change_state(*results) - finally: - self.executor.result_queue.task_done() - - def end(self): - """ - End the executor. - - Sends the poison pill to all workers. - """ - for _ in self.executor.workers: - self.queue.put((None, None)) - - # Wait for commands to finish - self.queue.join() - self.executor.sync() + raise ValueError("parallelism must be greater than or equal to 0") + self.activity_queue: SimpleQueue[ExecutorWorkType] = SimpleQueue() + self.result_queue: SimpleQueue[TaskInstanceStateType] = SimpleQueue() + self.workers: dict[int, multiprocessing.Process] = {} def start(self) -> None: """Start the executor.""" - old_proctitle = getproctitle() - setproctitle("airflow executor -- LocalExecutor") - self.manager = Manager() - setproctitle(old_proctitle) - self.result_queue = self.manager.Queue() - self.workers = [] - self.workers_used = 0 - self.workers_active = 0 - self.impl = ( - LocalExecutor.UnlimitedParallelism(self) - if self.parallelism == 0 - else LocalExecutor.LimitedParallelism(self) - ) - - self.impl.start() + pass @add_span def execute_async( @@ -403,32 +195,91 @@ def execute_async( executor_config: Any | None = None, ) -> None: """Execute asynchronously.""" - if TYPE_CHECKING: - assert self.impl - self.validate_airflow_tasks_run_command(command) + self.activity_queue.put((key, command)) + self._outstanding_messages += 1 + self._check_workers(can_start=True) + + def _check_workers(self, can_start: bool = True): + # Reap any dead workers + to_remove = set() + for pid, proc in self.workers.items(): + if not proc.is_alive(): + to_remove.add(pid) + proc.close() + + if to_remove: + self.workers = {pid: proc for pid, proc in self.workers.items() if pid not in to_remove} + + # If we're using spawn in multiprocessing (default on macos now) to start tasks, this can get called a + # via sync() a few times before the spawned process actually starts picking up messages. Try not to + # create too much + + if self._outstanding_messages <= 0 or self.activity_queue.empty(): + # Nothing to do, should we shut down idle workers? 
+            return

-        self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
+        need_more_workers = len(self.workers) < self._outstanding_messages
+        if need_more_workers and (self.parallelism == 0 or len(self.workers) < self.parallelism):
+            self._spawn_worker()
+
+    def _spawn_worker(self):
+        p = multiprocessing.Process(
+            target=_run_worker,
+            kwargs={
+                "logger_name": self.log.name,
+                "input": self.activity_queue,
+                "output": self.result_queue,
+            },
+        )
+        p.start()
+        if TYPE_CHECKING:
+            assert p.pid  # Since we've called start()
+        self.workers[p.pid] = p

     def sync(self) -> None:
         """Sync will get called periodically by the heartbeat method."""
-        if TYPE_CHECKING:
-            assert self.impl
+        self._read_results()
+        self._check_workers()

-        self.impl.sync()
+    def _read_results(self):
+        while not self.result_queue.empty():
+            key, state, exc = self.result_queue.get()
+            self._outstanding_messages = self._outstanding_messages - 1
+
+            if exc:
+                # TODO: This needs a better stacktrace; it appears to originate from here
+                if hasattr(exc, "add_note"):
+                    exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)")
+                raise exc
+
+            self.change_state(key, state)

     def end(self) -> None:
         """End the executor."""
-        if TYPE_CHECKING:
-            assert self.impl
-            assert self.manager
-
         self.log.info(
             "Shutting down LocalExecutor"
             "; waiting for running tasks to finish. Signal again if you don't want to wait."
         )
-        self.impl.end()
-        self.manager.shutdown()
+
+        # We can't tell which proc will pick up which close message, so we send all the messages, and then
+        # wait on all the procs
+
+        for proc in self.workers.values():
+            # Send the shutdown message once for each alive worker
+            if proc.is_alive():
+                self.activity_queue.put(None)
+
+        for proc in self.workers.values():
+            if proc.is_alive():
+                proc.join()
+                proc.close()
+
+        # Process any extra results before closing
+        self._read_results()
+
+        self.activity_queue.close()
+        self.result_queue.close()

     def terminate(self):
         """Terminate the executor is not doing anything."""
diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py
index f6ba8dca464d3..3f1ebabe148c8 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -18,10 +18,13 @@
 from __future__ import annotations

 import datetime
+import multiprocessing
+import os
 import subprocess
 from unittest import mock

 import pytest
+from kgb import spy_on

 from airflow import settings
 from airflow.exceptions import AirflowException
@@ -30,6 +33,12 @@
 pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]

+# Runtime is fine, we just can't run the tests on macOS
+skip_spawn_mp_start = pytest.mark.skipif(
+    multiprocessing.get_context().get_start_method() == "spawn",
+    reason="mock patching in tests doesn't work with 'spawn' mode (the default on macOS)",
+)
+

 class TestLocalExecutor:
     TEST_SUCCESS_COMMANDS = 5
@@ -44,85 +53,71 @@ def test_serve_logs_default_value(self):
         assert LocalExecutor.serve_logs

     @mock.patch("airflow.executors.local_executor.subprocess.check_call")
-    def execution_parallelism_subprocess(self, mock_check_call, parallelism=0):
-        success_command = ["airflow", "tasks", "run", "true", "some_parameter", "2020-10-07"]
-        fail_command = ["airflow", "tasks", "run", "false", "task_id", "2020-10-07"]
+    @mock.patch("airflow.cli.commands.task_command.task_run")
+    def _test_execute(self, mock_run, mock_check_call, parallelism=1):
+        success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"]
+        fail_command = ["airflow", "tasks", "run", "failure", "task_id", "2020-10-07"]

+        # We just mock both styles here; only one will be hit, though
         def fake_execute_command(command, close_fds=True):
             if command != success_command:
                 raise subprocess.CalledProcessError(returncode=1, cmd=command)
             else:
                 return 0

-        mock_check_call.side_effect = fake_execute_command
-
-        self._test_execute(parallelism, success_command, fail_command)
-
-    @mock.patch("airflow.cli.commands.task_command.task_run")
-    def execution_parallelism_fork(self, mock_run, parallelism=0):
-        success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"]
-        fail_command = ["airflow", "tasks", "run", "failure", "some_parameter", "2020-10-07"]
-
         def fake_task_run(args):
+            print(repr(args))
             if args.dag_id != "success":
                 raise AirflowException("Simulate failed task")

+        mock_check_call.side_effect = fake_execute_command
         mock_run.side_effect = fake_task_run

-        self._test_execute(parallelism, success_command, fail_command)
-
-    def _test_execute(self, parallelism, success_command, fail_command):
         executor = LocalExecutor(parallelism=parallelism)
         executor.start()

         success_key = "success {}"
         assert executor.result_queue.empty()

-        logical_date = datetime.datetime.now()
-        for i in range(self.TEST_SUCCESS_COMMANDS):
-            key_id, command = success_key.format(i), success_command
-            key = key_id, "fake_ti", logical_date, 0
-            executor.running.add(key)
-            executor.execute_async(key=key, command=command)
+        with spy_on(executor._spawn_worker) as spy:
+            run_id = "manual_" + datetime.datetime.now().isoformat()
+            for i in range(self.TEST_SUCCESS_COMMANDS):
+                key_id, command = success_key.format(i), success_command
+                key = key_id, "fake_ti", run_id, 0
+                executor.running.add(key)
+                executor.execute_async(key=key, command=command)
+
+            fail_key = "fail", "fake_ti", run_id, 0
+            executor.running.add(fail_key)
+            executor.execute_async(key=fail_key, command=fail_command)

-        fail_key = "fail", "fake_ti", logical_date, 0
-        executor.running.add(fail_key)
-        executor.execute_async(key=fail_key, command=fail_command)
+            executor.end()

-        executor.end()
+        expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
+        assert len(spy.calls) == expected

         # By that time Queues are already shutdown so we cannot check if they are empty
         assert len(executor.running) == 0
+        assert executor._outstanding_messages == 0

         for i in range(self.TEST_SUCCESS_COMMANDS):
             key_id = success_key.format(i)
-            key = key_id, "fake_ti", logical_date, 0
+            key = key_id, "fake_ti", run_id, 0
             assert executor.event_buffer[key][0] == State.SUCCESS
         assert executor.event_buffer[fail_key][0] == State.FAILED

-        expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism
-        assert executor.workers_used == expected
-
-    def test_execution_subprocess_unlimited_parallelism(self):
-        with mock.patch.object(
-            settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock
-        ) as option:
-            option.return_value = True
-            self.execution_parallelism_subprocess(parallelism=0)
-
-    def test_execution_subprocess_limited_parallelism(self):
-        with mock.patch.object(
-            settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock
-        ) as option:
-            option.return_value = True
-            self.execution_parallelism_subprocess(parallelism=2)
-
-    @mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False)
-    def test_execution_unlimited_parallelism_fork(self):
-        self.execution_parallelism_fork(parallelism=0)
-
-    @mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False)
-    def test_execution_limited_parallelism_fork(self):
-        self.execution_parallelism_fork(parallelism=2)
+    @skip_spawn_mp_start
+    @pytest.mark.parametrize(
+        ("parallelism", "fork_or_subproc"),
+        [
+            pytest.param(0, True, id="unlimited_subprocess"),
+            pytest.param(2, True, id="limited_subprocess"),
+            pytest.param(0, False, id="unlimited_fork"),
+            pytest.param(2, False, id="limited_fork"),
+        ],
+    )
+    def test_execution(self, parallelism: int, fork_or_subproc: bool, monkeypatch: pytest.MonkeyPatch):
+        monkeypatch.setattr(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", fork_or_subproc)
+        self._test_execute(parallelism=parallelism)

     @mock.patch("airflow.executors.local_executor.LocalExecutor.sync")
     @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks")
@@ -142,3 +137,20 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
             ),
         ]
         mock_stats_gauge.assert_has_calls(calls)
+
+    @pytest.mark.execution_timeout(5)
+    def test_clean_stop_on_signal(self):
+        import signal
+
+        executor = LocalExecutor(parallelism=2)
+        executor.start()
+
+        # We want to ensure we start a worker process, as we now only create them on demand
+        executor._spawn_worker()
+
+        try:
+            os.kill(os.getpid(), signal.SIGINT)
+        except KeyboardInterrupt:
+            pass
+        finally:
+            executor.end()
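
Note: the worker lifecycle introduced by this patch is a standard multiprocessing pattern -- a shared
input queue, a shared result queue, and one None "poison pill" per live worker at shutdown. The
following minimal sketch shows the same protocol in isolation, using only the standard library; the
_demo_worker name and the dummy key/command payloads are invented for illustration and are not code
from this patch.

    import multiprocessing
    from multiprocessing import Queue, SimpleQueue


    def _demo_worker(input: SimpleQueue, output: Queue) -> None:
        # Block on the shared queue until work (or the poison pill) arrives,
        # mirroring _run_worker() above.
        while True:
            item = input.get()
            if item is None:
                # Poison pill: a single None tells this worker to exit cleanly.
                return
            key, command = item
            # The real worker would execute the task here; we just report success.
            output.put((key, "success", None))


    if __name__ == "__main__":
        activity_queue: SimpleQueue = SimpleQueue()
        result_queue: Queue = Queue()
        workers = [
            multiprocessing.Process(target=_demo_worker, args=(activity_queue, result_queue))
            for _ in range(2)
        ]
        for w in workers:
            w.start()

        activity_queue.put(("task-1", ["echo", "hello"]))
        print(result_queue.get())  # -> ("task-1", "success", None)

        for _ in workers:
            activity_queue.put(None)  # one poison pill per worker
        for w in workers:
            w.join()

One pill per live worker suffices because each worker exits immediately after consuming a single
None, so no pill is read twice -- the same reasoning end() relies on when it puts one None on the
activity queue for every process that is still alive.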