diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index dd1868a1920cb..ea84ad6dab646 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -53,6 +53,7 @@ from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.singleton import Singleton from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: @@ -73,17 +74,11 @@ POD_EXECUTOR_DONE_KEY = "airflow_executor_done" -class ResourceVersion: +class ResourceVersion(metaclass=Singleton): """Singleton for tracking resourceVersion from Kubernetes.""" - _instance: ResourceVersion | None = None resource_version: dict[str, str] = {} - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): """Watches for Kubernetes jobs.""" diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index e2508f6dc09d8..0563abeaf572f 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -39,6 +39,7 @@ from airflow.utils.entry_points import entry_points_with_dist from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.singleton import Singleton log = logging.getLogger(__name__) @@ -355,7 +356,7 @@ def wrapped_function(*args, **kwargs): return provider_info_cache_decorator -class ProvidersManager(LoggingMixin): +class ProvidersManager(LoggingMixin, metaclass=Singleton): """ Manages all provider packages. @@ -364,14 +365,8 @@ class ProvidersManager(LoggingMixin): local source folders (if airflow is run from sources). """ - _instance = None resource_version = "0" - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): """Initializes the manager.""" super().__init__() diff --git a/airflow/utils/singleton.py b/airflow/utils/singleton.py new file mode 100644 index 0000000000000..cfc97eddbfcfc --- /dev/null +++ b/airflow/utils/singleton.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class Singleton(type, Generic[T]): + """Metaclass that allows to implement singleton pattern.""" + + _instances: dict[Singleton[T], T] = {} + + def __call__(cls: Singleton[T], *args, **kwargs) -> T: + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py index 106755dc4b516..b99dbcb84fa33 100644 --- a/tests/always/test_providers_manager.py +++ b/tests/always/test_providers_manager.py @@ -36,6 +36,11 @@ class TestProviderManager: def inject_fixtures(self, caplog): self._caplog = caplog + @pytest.fixture(autouse=True, scope="function") + def clean(self): + """The tests depend on a clean state of a ProvidersManager.""" + ProvidersManager().__init__() + def test_providers_are_loaded(self): with self._caplog.at_level(logging.WARNING): provider_manager = ProvidersManager() diff --git a/tests/utils/test_singleton.py b/tests/utils/test_singleton.py new file mode 100644 index 0000000000000..57145fe7b97ba --- /dev/null +++ b/tests/utils/test_singleton.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.utils.singleton import Singleton + + +class A(metaclass=Singleton): + pass + + +class Counter(metaclass=Singleton): + """Singleton class that counts how much __init__ and count was called.""" + + counter = 0 + + def __init__(self): + self.counter += 1 + + def count(self): + self.counter += 1 + + +def test_singleton_refers_to_same_instance(): + a, b = A(), A() + assert a is b + + +def test_singleton_after_out_of_context_does_refer_to_same_instance(): + # check if setting something on singleton is preserved after instance goes out of context + def x(): + a = A() + a.a = "a" + + x() + b = A() + assert b.a == "a" + + +def test_singleton_does_not_call_init_second_time(): + # first creation of Counter, check if __init__ is called + c = Counter() + assert c.counter == 1 + + # check if "new instance" calls __init__ - it shouldn't + d = Counter() + assert c.counter == 1 + + # check if incrementing "new instance" increments counter on previous one + d.count() + assert c.counter == 2