Skip to content

Commit

Permalink
fix bugged singleton implementation (#32218)
Browse files Browse the repository at this point in the history
* fix bugged singleton implementation

Signed-off-by: Maciej Obuchowski <[email protected]>
  • Loading branch information
mobuchowski authored Jun 29, 2023
1 parent 6b4350e commit e2e707c
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 14 deletions.
9 changes: 2 additions & 7 deletions airflow/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.log.logging_mixin import LoggingMixin, remove_escape_codes
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.singleton import Singleton
from airflow.utils.state import State, TaskInstanceState

if TYPE_CHECKING:
Expand All @@ -73,17 +74,11 @@
POD_EXECUTOR_DONE_KEY = "airflow_executor_done"


class ResourceVersion:
class ResourceVersion(metaclass=Singleton):
"""Singleton for tracking resourceVersion from Kubernetes."""

_instance: ResourceVersion | None = None
resource_version: dict[str, str] = {}

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance


class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
"""Watches for Kubernetes jobs."""
Expand Down
9 changes: 2 additions & 7 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
from airflow.utils.singleton import Singleton

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -355,7 +356,7 @@ def wrapped_function(*args, **kwargs):
return provider_info_cache_decorator


class ProvidersManager(LoggingMixin):
class ProvidersManager(LoggingMixin, metaclass=Singleton):
"""
Manages all provider packages.
Expand All @@ -364,14 +365,8 @@ class ProvidersManager(LoggingMixin):
local source folders (if airflow is run from sources).
"""

_instance = None
resource_version = "0"

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self):
"""Initializes the manager."""
super().__init__()
Expand Down
33 changes: 33 additions & 0 deletions airflow/utils/singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Generic, TypeVar

T = TypeVar("T")


class Singleton(type, Generic[T]):
"""Metaclass that allows to implement singleton pattern."""

_instances: dict[Singleton[T], T] = {}

def __call__(cls: Singleton[T], *args, **kwargs) -> T:
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
5 changes: 5 additions & 0 deletions tests/always/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class TestProviderManager:
def inject_fixtures(self, caplog):
self._caplog = caplog

@pytest.fixture(autouse=True, scope="function")
def clean(self):
"""The tests depend on a clean state of a ProvidersManager."""
ProvidersManager().__init__()

def test_providers_are_loaded(self):
with self._caplog.at_level(logging.WARNING):
provider_manager = ProvidersManager()
Expand Down
65 changes: 65 additions & 0 deletions tests/utils/test_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.utils.singleton import Singleton


class A(metaclass=Singleton):
pass


class Counter(metaclass=Singleton):
"""Singleton class that counts how much __init__ and count was called."""

counter = 0

def __init__(self):
self.counter += 1

def count(self):
self.counter += 1


def test_singleton_refers_to_same_instance():
a, b = A(), A()
assert a is b


def test_singleton_after_out_of_context_does_refer_to_same_instance():
# check if setting something on singleton is preserved after instance goes out of context
def x():
a = A()
a.a = "a"

x()
b = A()
assert b.a == "a"


def test_singleton_does_not_call_init_second_time():
# first creation of Counter, check if __init__ is called
c = Counter()
assert c.counter == 1

# check if "new instance" calls __init__ - it shouldn't
d = Counter()
assert c.counter == 1

# check if incrementing "new instance" increments counter on previous one
d.count()
assert c.counter == 2

0 comments on commit e2e707c

Please sign in to comment.