From 8d74ee8dcd1b3ad0291ef666835edcffb24265ae Mon Sep 17 00:00:00 2001
From: Pavan Sharma
Date: Sun, 25 Feb 2024 12:59:24 +0530
Subject: [PATCH] Remove pid arg from celery option to fix duplicate pid
 issue; move celery command to provider package (#36794)

---
 airflow/cli/commands/celery_command.py        |  17 +-
 airflow/providers/celery/cli/__init__.py      |  16 +
 .../providers/celery/cli/celery_command.py    | 258 +++++++++++++
 .../celery/executors/celery_executor.py       |   6 +-
 tests/cli/commands/test_celery_command.py     |   6 +-
 tests/providers/celery/cli/__init__.py        |  16 +
 .../celery/cli/test_celery_command.py         | 353 ++++++++++++++++++
 7 files changed, 659 insertions(+), 13 deletions(-)
 create mode 100644 airflow/providers/celery/cli/__init__.py
 create mode 100644 airflow/providers/celery/cli/celery_command.py
 create mode 100644 tests/providers/celery/cli/__init__.py
 create mode 100644 tests/providers/celery/cli/test_celery_command.py

diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index 5e3e01042a070..f29641309d741 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -15,11 +15,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+
+# DO NOT MODIFY THIS FILE unless it is a serious bugfix - all new celery commands should be added in the celery provider.
+# This file is kept for backward compatibility only.
 """Celery command."""
 from __future__ import annotations
 
 import logging
 import sys
+import warnings
 from contextlib import contextmanager
 from multiprocessing import Process
 
@@ -40,6 +45,10 @@
 
 WORKER_PROCESS_NAME = "worker"
 
+warnings.warn(
+    "Use celery command from the celery provider package; requires celery provider > 3.5.2", DeprecationWarning, stacklevel=2
+)
+
 
 @cli_utils.action_cli
 @providers_configuration_loaded
@@ -153,9 +162,6 @@ def worker(args):
     if not celery_log_level:
         celery_log_level = conf.get("logging", "LOGGING_LEVEL")
 
-    # Setup pid file location
-    worker_pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME, pid=args.pid)
-
     # Setup Celery worker
     options = [
         "worker",
@@ -169,8 +175,6 @@ def worker(args):
         args.celery_hostname,
         "--loglevel",
         celery_log_level,
-        "--pidfile",
-        worker_pid_file_path,
     ]
     if autoscale:
         options.extend(["--autoscale", autoscale])
@@ -189,11 +193,12 @@ def worker(args):
         # executed.
         maybe_patch_concurrency(["-P", pool])
 
-    _, stdout, stderr, log_file = setup_locations(
+    worker_pid_file_path, stdout, stderr, log_file = setup_locations(
         process=WORKER_PROCESS_NAME,
         stdout=args.stdout,
         stderr=args.stderr,
         log=args.log_file,
+        pid=args.pid,
     )
 
     def run_celery_worker():
diff --git a/airflow/providers/celery/cli/__init__.py b/airflow/providers/celery/cli/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/celery/cli/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/celery/cli/celery_command.py b/airflow/providers/celery/cli/celery_command.py
new file mode 100644
index 0000000000000..217433f811d00
--- /dev/null
+++ b/airflow/providers/celery/cli/celery_command.py
@@ -0,0 +1,258 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Celery command."""
+from __future__ import annotations
+
+import logging
+import sys
+from contextlib import contextmanager
+from multiprocessing import Process
+
+import psutil
+import sqlalchemy.exc
+from celery import maybe_patch_concurrency  # type: ignore[attr-defined]
+from celery.app.defaults import DEFAULT_TASK_LOG_FMT
+from celery.signals import after_setup_logger
+from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
+
+from airflow import settings
+from airflow.configuration import conf
+from airflow.utils import cli as cli_utils
+from airflow.utils.cli import setup_locations
+from airflow.utils.serve_logs import serve_logs
+
+WORKER_PROCESS_NAME = "worker"
+
+
+def _run_command_with_daemon_option(*args, **kwargs):
+    try:
+        from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
+
+        run_command_with_daemon_option(*args, **kwargs)
+    except ImportError:
+        from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+        raise AirflowOptionalProviderFeatureException(
+            "Failed to import run_command_with_daemon_option. This feature is only available in Airflow versions >= 2.8.0"
+        )
+
+
+def _providers_configuration_loaded(func):
+    def wrapper(*args, **kwargs):
+        try:
+            from airflow.utils.providers_configuration_loader import providers_configuration_loaded
+
+            providers_configuration_loaded(func)(*args, **kwargs)
+        except ImportError:
+            from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(
+                "Failed to import providers_configuration_loaded. This feature is only available in Airflow versions >= 2.8.0"
+            )
+
+    return wrapper
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def flower(args):
+    """Start Flower, the Celery monitoring tool."""
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    options = [
+        "flower",
+        conf.get("celery", "BROKER_URL"),
+        f"--address={args.hostname}",
+        f"--port={args.port}",
+    ]
+
+    if args.broker_api:
+        options.append(f"--broker-api={args.broker_api}")
+
+    if args.url_prefix:
+        options.append(f"--url-prefix={args.url_prefix}")
+
+    if args.basic_auth:
+        options.append(f"--basic-auth={args.basic_auth}")
+
+    if args.flower_conf:
+        options.append(f"--conf={args.flower_conf}")
+
+    _run_command_with_daemon_option(
+        args=args, process_name="flower", callback=lambda: celery_app.start(options)
+    )
+
+
+@contextmanager
+def _serve_logs(skip_serve_logs: bool = False):
+    """Start serve_logs sub-process."""
+    sub_proc = None
+    if skip_serve_logs is False:
+        sub_proc = Process(target=serve_logs)
+        sub_proc.start()
+    yield
+    if sub_proc:
+        sub_proc.terminate()
+
+
+@after_setup_logger.connect()
+@_providers_configuration_loaded
+def logger_setup_handler(logger, **kwargs):
+    """
+    Reconfigure the logger.
+
+    * remove any previously configured handlers
+    * logs of severity ERROR and above go to stderr,
+    * logs of severity below ERROR go to stdout.
+    """
+    if conf.getboolean("logging", "celery_stdout_stderr_separation", fallback=False):
+        celery_formatter = logging.Formatter(DEFAULT_TASK_LOG_FMT)
+
+        class NoErrorOrAboveFilter(logging.Filter):
+            """Allow only logs with level *lower* than ERROR to be reported."""
+
+            def filter(self, record):
+                return record.levelno < logging.ERROR
+
+        below_error_handler = logging.StreamHandler(sys.stdout)
+        below_error_handler.addFilter(NoErrorOrAboveFilter())
+        below_error_handler.setFormatter(celery_formatter)
+
+        from_error_handler = logging.StreamHandler(sys.stderr)
+        from_error_handler.setLevel(logging.ERROR)
+        from_error_handler.setFormatter(celery_formatter)
+
+        logger.handlers[:] = [below_error_handler, from_error_handler]
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def worker(args):
+    """Start Airflow Celery worker."""
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    # Disable connection pool so that celery worker does not hold an unnecessary db connection
+    settings.reconfigure_orm(disable_connection_pool=True)
+    if not settings.validate_session():
+        raise SystemExit("Worker exiting, database connection precheck failed.")
+
+    autoscale = args.autoscale
+    skip_serve_logs = args.skip_serve_logs
+
+    if autoscale is None and conf.has_option("celery", "worker_autoscale"):
+        autoscale = conf.get("celery", "worker_autoscale")
+
+    if hasattr(celery_app.backend, "ResultSession"):
+        # Pre-create the database tables now, otherwise SQLA via Celery has a
+        # race condition where one of the subprocesses can die with "Table
+        # already exists" error, because SQLA checks for which tables exist,
+        # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
+        # EXISTS
+        try:
+            session = celery_app.backend.ResultSession()
+            session.close()
+        except sqlalchemy.exc.IntegrityError:
+            # At least on postgres, trying to create a table that already exists
+            # gives a unique constraint violation on the
+            # "pg_type_typname_nsp_index" index. If this happens we can ignore
+            # it, we raced to create the tables and lost.
+            pass
+
+    # backwards-compatible: https://github.com/apache/airflow/pull/21506#pullrequestreview-879893763
+    celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
+    if not celery_log_level:
+        celery_log_level = conf.get("logging", "LOGGING_LEVEL")
+
+    # Setup Celery worker
+    options = [
+        "worker",
+        "-O",
+        "fair",
+        "--queues",
+        args.queues,
+        "--concurrency",
+        args.concurrency,
+        "--hostname",
+        args.celery_hostname,
+        "--loglevel",
+        celery_log_level,
+    ]
+    if autoscale:
+        options.extend(["--autoscale", autoscale])
+    if args.without_mingle:
+        options.append("--without-mingle")
+    if args.without_gossip:
+        options.append("--without-gossip")
+
+    if conf.has_option("celery", "pool"):
+        pool = conf.get("celery", "pool")
+        options.extend(["--pool", pool])
+        # Celery pools of type eventlet and gevent use greenlets, which
+        # requires monkey patching the app:
+        # https://eventlet.net/doc/patching.html#monkey-patch
+        # Otherwise task instances hang on the workers and are never
+        # executed.
+        maybe_patch_concurrency(["-P", pool])
+
+    worker_pid_file_path, stdout, stderr, log_file = setup_locations(
+        process=WORKER_PROCESS_NAME,
+        stdout=args.stdout,
+        stderr=args.stderr,
+        log=args.log_file,
+        pid=args.pid,
+    )
+
+    def run_celery_worker():
+        with _serve_logs(skip_serve_logs):
+            celery_app.worker_main(options)
+
+    if args.umask:
+        umask = args.umask
+    else:
+        umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)
+
+    _run_command_with_daemon_option(
+        args=args,
+        process_name=WORKER_PROCESS_NAME,
+        callback=run_celery_worker,
+        should_setup_logging=True,
+        umask=umask,
+        pid_file=worker_pid_file_path,
+    )
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def stop_worker(args):
+    """Send SIGTERM to Celery worker."""
+    # Read PID from file
+    if args.pid:
+        pid_file_path = args.pid
+    else:
+        pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME)
+    pid = read_pid_from_pidfile(pid_file_path)
+
+    # Send SIGTERM
+    if pid:
+        worker_process = psutil.Process(pid)
+        worker_process.terminate()
+
+    # Remove pid file
+    remove_existing_pidfile(pid_file_path)
diff --git a/airflow/providers/celery/executors/celery_executor.py b/airflow/providers/celery/executors/celery_executor.py
index f4aa2a9b75df0..a8a4e320ec6a5 100644
--- a/airflow/providers/celery/executors/celery_executor.py
+++ b/airflow/providers/celery/executors/celery_executor.py
@@ -181,7 +181,7 @@ def __getattr__(name):
     ActionCommand(
         name="worker",
         help="Start a Celery worker node",
-        func=lazy_load_command("airflow.cli.commands.celery_command.worker"),
+        func=lazy_load_command("airflow.providers.celery.cli.celery_command.worker"),
         args=(
             ARG_QUEUES,
             ARG_CONCURRENCY,
@@ -202,7 +202,7 @@ def __getattr__(name):
     ActionCommand(
         name="flower",
         help="Start a Celery Flower",
-        func=lazy_load_command("airflow.cli.commands.celery_command.flower"),
+        func=lazy_load_command("airflow.providers.celery.cli.celery_command.flower"),
         args=(
             ARG_FLOWER_HOSTNAME,
             ARG_FLOWER_PORT,
@@ -221,7 +221,7 @@ def __getattr__(name):
     ActionCommand(
         name="stop",
         help="Stop the Celery worker gracefully",
-        func=lazy_load_command("airflow.cli.commands.celery_command.stop_worker"),
+        func=lazy_load_command("airflow.providers.celery.cli.celery_command.stop_worker"),
         args=(ARG_PID, ARG_VERBOSE),
     ),
 )
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index cbd87fea588d3..13cfff5275618 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -106,7 +106,7 @@ def test_same_pid_file_is_used_in_start_and_stop(
         assert mock_celery_app.worker_main.call_args
         args, _ = mock_celery_app.worker_main.call_args
         args_str = " ".join(map(str, args[0]))
-        assert f"--pidfile {pid_file}" in args_str
+        assert f"--pidfile {pid_file}" not in args_str
 
         # Call stop
         stop_args = self.parser.parse_args(["celery", "stop"])
@@ -134,7 +134,7 @@ def test_custom_pid_file_is_used_in_start_and_stop(
         assert mock_celery_app.worker_main.call_args
         args, _ = mock_celery_app.worker_main.call_args
         args_str = " ".join(map(str, args[0]))
-        assert f"--pidfile {pid_file}" in args_str
+        assert f"--pidfile {pid_file}" not in args_str
 
         stop_args = self.parser.parse_args(["celery", "stop", "--pid", pid_file])
         celery_command.stop_worker(stop_args)
@@ -194,8 +194,6 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope
                 celery_hostname,
                 "--loglevel",
                 conf.get("logging", "CELERY_LOGGING_LEVEL"),
-                "--pidfile",
-                pid_file,
                 "--autoscale",
                 autoscale,
                 "--without-mingle",
diff --git a/tests/providers/celery/cli/__init__.py b/tests/providers/celery/cli/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/celery/cli/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/celery/cli/test_celery_command.py b/tests/providers/celery/cli/test_celery_command.py
new file mode 100644
index 0000000000000..c896f5d87d244
--- /dev/null
+++ b/tests/providers/celery/cli/test_celery_command.py
@@ -0,0 +1,353 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import importlib
+import os
+from argparse import Namespace
+from unittest import mock
+
+import pytest
+import sqlalchemy
+
+import airflow
+from airflow.cli import cli_parser
+from airflow.configuration import conf
+from airflow.providers.celery.cli import celery_command
+from tests.test_utils.config import conf_vars
+
+pytestmark = pytest.mark.db_test
+
+
+class TestWorkerPrecheck:
+    @mock.patch("airflow.settings.validate_session")
+    def test_error(self, mock_validate_session):
+        """
+        Test to verify the exit mechanism of the airflow-worker CLI
+        by mocking the validate_session method.
+        """
+        mock_validate_session.return_value = False
+        with pytest.raises(SystemExit) as ctx, conf_vars({("core", "executor"): "CeleryExecutor"}):
+            celery_command.worker(Namespace(queues=1, concurrency=1))
+        assert str(ctx.value) == "Worker exiting, database connection precheck failed."
+
+    @conf_vars({("celery", "worker_precheck"): "False"})
+    def test_worker_precheck_exception(self):
+        """
+        Test to check the behaviour of the validate_session method
+        when worker_precheck is absent from the airflow configuration.
+        """
+        assert airflow.settings.validate_session()
+
+    @mock.patch("sqlalchemy.orm.session.Session.execute")
+    @conf_vars({("celery", "worker_precheck"): "True"})
+    def test_validate_session_dbapi_exception(self, mock_session):
+        """
+        Test to validate the connection failure scenario on a SELECT 1 query.
+        """
+        mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4")
+        assert airflow.settings.validate_session() is False
+
+
+@pytest.mark.backend("mysql", "postgres")
+class TestCeleryStopCommand:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
+
+    @mock.patch("airflow.providers.celery.cli.celery_command.setup_locations")
+    @mock.patch("airflow.providers.celery.cli.celery_command.psutil.Process")
+    def test_if_right_pid_is_read(self, mock_process, mock_setup_locations, tmp_path):
+        args = self.parser.parse_args(["celery", "stop"])
+        pid = "123"
+        path = tmp_path / "testfile"
+        # Create pid file
+        path.write_text(pid)
+        # Setup mock
+        mock_setup_locations.return_value = (os.fspath(path), None, None, None)
+
+        # Calling stop_worker should delete the temporary pid file
+        celery_command.stop_worker(args)
+        # Check it works as expected
+        assert not path.exists()
+        mock_process.assert_called_once_with(int(pid))
+        mock_process.return_value.terminate.assert_called_once_with()
+
+    @mock.patch("airflow.providers.celery.cli.celery_command.read_pid_from_pidfile")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    @mock.patch("airflow.providers.celery.cli.celery_command.setup_locations")
+    def test_same_pid_file_is_used_in_start_and_stop(
+        self, mock_setup_locations, mock_celery_app, mock_read_pid_from_pidfile
+    ):
+        pid_file = "test_pid_file"
+        mock_setup_locations.return_value = (pid_file, None, None, None)
+        mock_read_pid_from_pidfile.return_value = None
+
+        # Call worker
+        worker_args = self.parser.parse_args(["celery", "worker", "--skip-serve-logs"])
+        celery_command.worker(worker_args)
+        assert mock_celery_app.worker_main.call_args
+        args, _ = mock_celery_app.worker_main.call_args
+        args_str = " ".join(map(str, args[0]))
+        assert f"--pidfile {pid_file}" not in args_str
+
+        # Call stop
+        stop_args = self.parser.parse_args(["celery", "stop"])
+        celery_command.stop_worker(stop_args)
+
+        mock_read_pid_from_pidfile.assert_called_once_with(pid_file)
+
+    @mock.patch("airflow.providers.celery.cli.celery_command.remove_existing_pidfile")
+    @mock.patch("airflow.providers.celery.cli.celery_command.read_pid_from_pidfile")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    @mock.patch("airflow.providers.celery.cli.celery_command.psutil.Process")
+    @mock.patch("airflow.providers.celery.cli.celery_command.setup_locations")
+    def test_custom_pid_file_is_used_in_start_and_stop(
+        self,
+        mock_setup_locations,
+        mock_process,
+        mock_celery_app,
+        mock_read_pid_from_pidfile,
+        mock_remove_existing_pidfile,
+    ):
+        pid_file = "custom_test_pid_file"
+        mock_setup_locations.return_value = (pid_file, None, None, None)
+        # Call worker
+        worker_args = self.parser.parse_args(["celery", "worker", "--skip-serve-logs", "--pid", pid_file])
+        celery_command.worker(worker_args)
+        assert mock_celery_app.worker_main.call_args
+        args, _ = mock_celery_app.worker_main.call_args
+        args_str = " ".join(map(str, args[0]))
+        assert f"--pidfile {pid_file}" not in args_str
+
+        stop_args = self.parser.parse_args(["celery", "stop", "--pid", pid_file])
+        celery_command.stop_worker(stop_args)
+
+        mock_read_pid_from_pidfile.assert_called_once_with(pid_file)
+        mock_process.return_value.terminate.assert_called()
+        mock_remove_existing_pidfile.assert_called_once_with(pid_file)
+
+
+@pytest.mark.backend("mysql", "postgres")
+class TestWorkerStart:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
+
+    @mock.patch("airflow.providers.celery.cli.celery_command.setup_locations")
+    @mock.patch("airflow.providers.celery.cli.celery_command.Process")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    def test_worker_started_with_required_arguments(self, mock_celery_app, mock_popen, mock_locations):
+        pid_file = "pid_file"
+        mock_locations.return_value = (pid_file, None, None, None)
+        concurrency = "1"
+        celery_hostname = "celery_hostname"
+        queues = "queue"
+        autoscale = "2,5"
+        args = self.parser.parse_args(
+            [
+                "celery",
+                "worker",
+                "--autoscale",
+                autoscale,
+                "--concurrency",
+                concurrency,
+                "--celery-hostname",
+                celery_hostname,
+                "--queues",
+                queues,
+                "--without-mingle",
+                "--without-gossip",
+            ]
+        )
+
+        celery_command.worker(args)
+
+        mock_celery_app.worker_main.assert_called_once_with(
+            [
+                "worker",
+                "-O",
+                "fair",
+                "--queues",
+                queues,
+                "--concurrency",
+                int(concurrency),
+                "--hostname",
+                celery_hostname,
+                "--loglevel",
+                conf.get("logging", "CELERY_LOGGING_LEVEL"),
+                "--autoscale",
+                autoscale,
+                "--without-mingle",
+                "--without-gossip",
+                "--pool",
+                "prefork",
+            ]
+        )
+
+
+@pytest.mark.backend("mysql", "postgres")
+class TestWorkerFailure:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
+
+    @mock.patch("airflow.providers.celery.cli.celery_command.Process")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    def test_worker_failure_graceful_shutdown(self, mock_celery_app, mock_popen):
+        args = self.parser.parse_args(["celery", "worker"])
+        mock_celery_app.run.side_effect = Exception("Mock exception to trigger runtime error")
+        try:
+            celery_command.worker(args)
+        finally:
+            mock_popen().terminate.assert_called()
+
+
+@pytest.mark.backend("mysql", "postgres")
+class TestFlowerCommand:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
+
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    def test_run_command(self, mock_celery_app):
+        args = self.parser.parse_args(
+            [
+                "celery",
+                "flower",
+                "--basic-auth",
+                "admin:admin",
+                "--broker-api",
+                "http://username:password@rabbitmq-server-name:15672/api/",
+                "--flower-conf",
+                "flower_config",
+                "--hostname",
+                "my-hostname",
+                "--port",
+                "3333",
+                "--url-prefix",
+                "flower-monitoring",
+            ]
+        )
+
+        celery_command.flower(args)
+        mock_celery_app.start.assert_called_once_with(
+            [
+                "flower",
+                conf.get("celery", "BROKER_URL"),
+                "--address=my-hostname",
+                "--port=3333",
+                "--broker-api=http://username:password@rabbitmq-server-name:15672/api/",
+                "--url-prefix=flower-monitoring",
+                "--basic-auth=admin:admin",
+                "--conf=flower_config",
+            ]
+        )
+
+    @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile")
+    @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
+    @mock.patch("airflow.cli.commands.daemon_utils.daemon")
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file):
+        mock_setup_locations.return_value = (
+            mock.MagicMock(name="pidfile"),
+            mock.MagicMock(name="stdout"),
+            mock.MagicMock(name="stderr"),
+            mock.MagicMock(name="INVALID"),
+        )
+        args = self.parser.parse_args(
+            [
+                "celery",
+                "flower",
+                "--basic-auth",
+                "admin:admin",
+                "--broker-api",
+                "http://username:password@rabbitmq-server-name:15672/api/",
+                "--flower-conf",
+                "flower_config",
+                "--hostname",
+                "my-hostname",
+                "--log-file",
+                "/tmp/flower.log",
+                "--pid",
+                "/tmp/flower.pid",
+                "--port",
+                "3333",
+                "--stderr",
+                "/tmp/flower-stderr.log",
+                "--stdout",
+                "/tmp/flower-stdout.log",
+                "--url-prefix",
+                "flower-monitoring",
+                "--daemon",
+            ]
+        )
+        mock_open = mock.mock_open()
+        with mock.patch("airflow.cli.commands.daemon_utils.open", mock_open):
+            celery_command.flower(args)
+
+        mock_celery_app.start.assert_called_once_with(
+            [
+                "flower",
+                conf.get("celery", "BROKER_URL"),
+                "--address=my-hostname",
+                "--port=3333",
+                "--broker-api=http://username:password@rabbitmq-server-name:15672/api/",
+                "--url-prefix=flower-monitoring",
+                "--basic-auth=admin:admin",
+                "--conf=flower_config",
+            ]
+        )
+        assert mock_daemon.mock_calls[:3] == [
+            mock.call.DaemonContext(
+                pidfile=mock_pid_file.return_value,
+                files_preserve=None,
+                stdout=mock_open.return_value,
+                stderr=mock_open.return_value,
+                umask=0o077,
+            ),
+            mock.call.DaemonContext().__enter__(),
+            mock.call.DaemonContext().__exit__(None, None, None),
+        ]
+
+        assert mock_setup_locations.mock_calls == [
+            mock.call(
+                process="flower",
+                stdout="/tmp/flower-stdout.log",
+                stderr="/tmp/flower-stderr.log",
+                log="/tmp/flower.log",
+            )
+        ]
+        mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)])
+        assert mock_open.mock_calls == [
+            mock.call(mock_setup_locations.return_value[1], "a"),
+            mock.call().__enter__(),
+            mock.call(mock_setup_locations.return_value[2], "a"),
+            mock.call().__enter__(),
+            mock.call().truncate(0),
+            mock.call().truncate(0),
+            mock.call().__exit__(None, None, None),
+            mock.call().__exit__(None, None, None),
+        ]
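
Reviewer note on the fix itself (an illustrative sketch, not part of the patch).
As I read the change, the duplicate-pid problem came from handing the same pid
path to two owners: Airflow's daemon machinery (via setup_locations) and the
Celery worker (via the now-removed --pidfile option). After this patch the pid
file is passed only to _run_command_with_daemon_option(..., pid_file=...), so a
single owner manages it. The sketch below reproduces the collision with the
same lockfile primitives the provider module already imports; the script name
and pid path are hypothetical.

# duplicate_pidfile_sketch.py - standalone demonstration, not provider code.
import os
import tempfile

from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile

pid_path = os.path.join(tempfile.mkdtemp(), "airflow-worker.pid")  # hypothetical path

write_pid_to_pidfile(pid_path)  # first owner (e.g. the daemon context) claims the file
try:
    write_pid_to_pidfile(pid_path)  # second owner (e.g. celery --pidfile) collides
except OSError as err:
    print(f"second pid-file owner fails: {err}")

print(f"pid file still holds {read_pid_from_pidfile(pid_path)}")
remove_existing_pidfile(pid_path)  # the same cleanup stop_worker performs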
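
Reviewer note on logger_setup_handler (a runnable sketch under assumptions: a
plain logging format stands in for celery's DEFAULT_TASK_LOG_FMT, and the demo
logger name is chosen here). When [logging] celery_stdout_stderr_separation is
enabled, the handler added in this patch routes records below ERROR to stdout
and ERROR and above to stderr; the same arrangement in isolation:

# stdout_stderr_split_sketch.py - standalone demonstration, not provider code.
import logging
import sys


class NoErrorOrAboveFilter(logging.Filter):
    """Allow only records with a level lower than ERROR."""

    def filter(self, record):
        return record.levelno < logging.ERROR


logger = logging.getLogger("celery_split_demo")  # hypothetical demo logger
logger.setLevel(logging.DEBUG)
logger.propagate = False  # keep demo output away from the root logger

below_error = logging.StreamHandler(sys.stdout)  # INFO, WARNING, ...
below_error.addFilter(NoErrorOrAboveFilter())

from_error = logging.StreamHandler(sys.stderr)  # ERROR, CRITICAL
from_error.setLevel(logging.ERROR)

logger.handlers[:] = [below_error, from_error]

logger.info("INFO and WARNING land on stdout")
logger.error("ERROR and above land on stderr")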
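
Reviewer note on the executor rewiring. The three ActionCommand entries now
resolve through lazy_load_command("airflow.providers.celery.cli..."), which
defers the provider import until the command actually runs. Roughly - this is
an approximation for illustration, not Airflow's exact implementation - the
helper behaves like:

# lazy_load_sketch.py - approximation of the deferred-import pattern.
from importlib import import_module


def lazy_load(import_path):
    module_path, _, attr = import_path.rpartition(".")

    def command(*args, **kwargs):
        # The import happens only on first invocation, so merely building
        # the CLI parser never pulls in the provider package.
        func = getattr(import_module(module_path), attr)
        return func(*args, **kwargs)

    return command


worker_cmd = lazy_load("airflow.providers.celery.cli.celery_command.worker")
# worker_cmd(args) would import the provider module and then call worker(args).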