Skip to content

Commit

Permalink
Merge pull request #16 from modern-python/13-thread-safe-singletons-a…
Browse files Browse the repository at this point in the history
…nd-resources

Thread-safe singletons and resources
  • Loading branch information
lesnik512 authored Nov 17, 2024
2 parents 762f9c7 + 1c50c39 commit 66103a4
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
pull_request: {}

concurrency:
group: ${{ github.head_ref || github.run_id }} core
group: ${{ github.head_ref || github.run_id }} core lint
cancel-in-progress: true

jobs:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
- '.github/workflows/test-core.yml'

concurrency:
group: ${{ github.head_ref || github.run_id }} core
group: ${{ github.head_ref || github.run_id }} core test
cancel-in-progress: true

jobs:
Expand Down
33 changes: 27 additions & 6 deletions packages/modern-di/modern_di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,31 @@


class Container(contextlib.AbstractAsyncContextManager["Container"]):
__slots__ = "scope", "parent_container", "context", "_is_async", "_provider_states", "_overrides"
__slots__ = (
"scope",
"parent_container",
"context",
"_is_async",
"_provider_states",
"_overrides",
"_use_threading_lock",
)

def __init__(
self,
*,
scope: enum.IntEnum,
parent_container: typing.Optional["Container"] = None,
context: dict[str, typing.Any] | None = None,
use_threading_lock: bool = True,
) -> None:
self.scope = scope
self.parent_container = parent_container
self.context: dict[str, typing.Any] = context or {}
self._is_async: bool | None = None
self._provider_states: dict[str, ProviderState[typing.Any]] = {}
self._overrides: dict[str, typing.Any] = {}
self._use_threading_lock = use_threading_lock

def _exit(self) -> None:
self._is_async = None
Expand Down Expand Up @@ -74,17 +84,28 @@ def find_container(self, scope: enum.IntEnum) -> "typing_extensions.Self":
return container

def fetch_provider_state(
self, provider_id: str, is_async_resource: bool = False, is_lock_required: bool = False
self,
provider_id: str,
is_async_resource: bool = False,
use_asyncio_lock: bool = False,
use_threading_lock: bool = False,
) -> ProviderState[typing.Any]:
self._check_entered()
if is_async_resource and self._is_async is False:
msg = "Resolving async resource in sync container is not allowed"
raise RuntimeError(msg)

if provider_id not in self._provider_states:
self._provider_states[provider_id] = ProviderState(is_lock_required=is_lock_required)

return self._provider_states[provider_id]
if provider_state := self._provider_states.get(provider_id):
return provider_state

# expected to be thread-safe, because setdefault is atomic
return self._provider_states.setdefault(
provider_id,
ProviderState(
use_asyncio_lock=use_asyncio_lock,
use_threading_lock=self._use_threading_lock and use_threading_lock,
),
)

def override(self, provider_id: str, override_object: object) -> None:
self._overrides[provider_id] = override_object
Expand Down
8 changes: 5 additions & 3 deletions packages/modern-di/modern_di/provider_state.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import asyncio
import contextlib
import threading
import typing


T_co = typing.TypeVar("T_co", covariant=True)


class ProviderState(typing.Generic[T_co]):
__slots__ = "context_stack", "instance", "provider_lock"
__slots__ = "context_stack", "instance", "asyncio_lock", "threading_lock"

def __init__(self, is_lock_required: bool) -> None:
def __init__(self, use_asyncio_lock: bool, use_threading_lock: bool) -> None:
self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None
self.instance: T_co | None = None
self.provider_lock: typing.Final = asyncio.Lock() if is_lock_required else None
self.asyncio_lock: typing.Final = asyncio.Lock() if use_asyncio_lock else None
self.threading_lock: typing.Final = threading.Lock() if use_threading_lock else None

async def async_tear_down(self) -> None:
if self.context_stack is None:
Expand Down
33 changes: 23 additions & 10 deletions packages/modern-di/modern_di/providers/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ async def async_resolve(self, container: Container) -> T_co:
return typing.cast(T_co, override)

provider_state = container.fetch_provider_state(
self.provider_id, is_async_resource=self._is_async, is_lock_required=True
self.provider_id, is_async_resource=self._is_async, use_asyncio_lock=True
)
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

assert provider_state.provider_lock
await provider_state.provider_lock.acquire()
if provider_state.asyncio_lock:
await provider_state.asyncio_lock.acquire()

try:
if provider_state.instance is not None:
Expand All @@ -79,7 +79,8 @@ async def async_resolve(self, container: Container) -> T_co:
provider_state.context_stack = contextlib.ExitStack()
provider_state.instance = provider_state.context_stack.enter_context(_intermediate_)
finally:
provider_state.provider_lock.release()
if provider_state.asyncio_lock:
provider_state.asyncio_lock.release()

return typing.cast(T_co, provider_state.instance)

Expand All @@ -88,19 +89,31 @@ def sync_resolve(self, container: Container) -> T_co:
if (override := container.fetch_override(self.provider_id)) is not None:
return typing.cast(T_co, override)

provider_state = container.fetch_provider_state(self.provider_id)
provider_state = container.fetch_provider_state(
self.provider_id, is_async_resource=self._is_async, use_threading_lock=True
)
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

if self._is_async:
msg = "Async resource cannot be resolved synchronously"
raise RuntimeError(msg)

_intermediate_ = self._sync_build_creator(container)
if provider_state.threading_lock:
provider_state.threading_lock.acquire()

provider_state.context_stack = contextlib.ExitStack()
provider_state.instance = provider_state.context_stack.enter_context(
typing.cast(contextlib.AbstractContextManager[typing.Any], _intermediate_)
)
try:
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

_intermediate_ = self._sync_build_creator(container)

provider_state.context_stack = contextlib.ExitStack()
provider_state.instance = provider_state.context_stack.enter_context(
typing.cast(contextlib.AbstractContextManager[typing.Any], _intermediate_)
)
finally:
if provider_state.threading_lock:
provider_state.threading_lock.release()

return typing.cast(T_co, provider_state.instance)
23 changes: 17 additions & 6 deletions packages/modern-di/modern_di/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ async def async_resolve(self, container: Container) -> T_co:
if (override := container.fetch_override(self.provider_id)) is not None:
return typing.cast(T_co, override)

provider_state = container.fetch_provider_state(self.provider_id, is_lock_required=True)
provider_state = container.fetch_provider_state(self.provider_id, use_asyncio_lock=True)
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

assert provider_state.provider_lock
await provider_state.provider_lock.acquire()
assert provider_state.asyncio_lock
await provider_state.asyncio_lock.acquire()

try:
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

provider_state.instance = typing.cast(T_co, await self._async_build_creator(container))
finally:
provider_state.provider_lock.release()
provider_state.asyncio_lock.release()

return provider_state.instance

Expand All @@ -48,9 +48,20 @@ def sync_resolve(self, container: Container) -> T_co:
if (override := container.fetch_override(self.provider_id)) is not None:
return typing.cast(T_co, override)

provider_state = container.fetch_provider_state(self.provider_id)
provider_state = container.fetch_provider_state(self.provider_id, use_threading_lock=True)
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

provider_state.instance = self._sync_build_creator(container)
if provider_state.threading_lock:
provider_state.threading_lock.acquire()

try:
if provider_state.instance is not None:
return typing.cast(T_co, provider_state.instance)

provider_state.instance = self._sync_build_creator(container)
finally:
if provider_state.threading_lock:
provider_state.threading_lock.release()

return typing.cast(T_co, provider_state.instance)
1 change: 1 addition & 0 deletions packages/modern-di/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"pytest-repeat",
"ruff",
"mypy",
"typing-extensions",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ async def test_injected_async_factory_in_sync_mode() -> None:
with pytest.raises(RuntimeError, match="Resolving async resource in sync container is not allowed"):
await request_async_factory.factory_provider.async_resolve(request_container)

with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"):
with pytest.raises(RuntimeError, match="Resolving async resource in sync container is not allowed"):
request_async_factory.factory_provider.sync_resolve(request_container)
50 changes: 42 additions & 8 deletions packages/modern-di/tests_core/providers/test_resource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import threading
import time
import typing
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from modern_di import Container, Scope, providers
Expand Down Expand Up @@ -121,7 +124,19 @@ async def test_sync_resource_overridden() -> None:
assert sync_resource4 is sync_resource1


async def test_async_resource_race_condition() -> None:
async def test_resource_unsupported_creator() -> None:
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.Resource(Scope.APP, None) # type: ignore[arg-type]


async def test_async_resource_sync_resolve() -> None:
async with Container(scope=Scope.APP) as app_container:
with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"):
async_resource.sync_resolve(app_container)


@pytest.mark.repeat(10)
async def test_resource_async_resolve_race_condition() -> None:
calls: int = 0

async def create_resource() -> typing.AsyncIterator[str]:
Expand All @@ -141,12 +156,31 @@ async def resolve_resource(container: Container) -> str:
assert calls == 1


async def test_resource_unsupported_creator() -> None:
with pytest.raises(TypeError, match="Unsupported resource type"):
providers.Resource(Scope.APP, None) # type: ignore[arg-type]
@pytest.mark.repeat(10)
def test_resource_sync_resolve_race_condition() -> None:
calls: int = 0
lock = threading.Lock()

def create_resource() -> typing.Iterator[str]:
nonlocal calls
with lock:
calls += 1
time.sleep(0.01)
yield ""

async def test_async_resource_sync_resolve() -> None:
async with Container(scope=Scope.APP) as app_container:
with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"):
async_resource.sync_resolve(app_container)
resource = providers.Resource(Scope.APP, create_resource)

def resolve_resource(container: Container) -> str:
return resource.sync_resolve(container)

with Container(scope=Scope.APP) as app_container, ThreadPoolExecutor(max_workers=4) as pool:
tasks = [
pool.submit(resolve_resource, app_container),
pool.submit(resolve_resource, app_container),
pool.submit(resolve_resource, app_container),
pool.submit(resolve_resource, app_container),
]
results = [x.result() for x in as_completed(tasks)]

assert results == ["", "", "", ""]
assert calls == 1
47 changes: 41 additions & 6 deletions packages/modern-di/tests_core/providers/test_singleton.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import dataclasses
import threading
import time
import typing
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
from modern_di import Container, Scope, providers
Expand Down Expand Up @@ -84,7 +87,16 @@ async def test_singleton_overridden() -> None:
assert singleton4 is singleton1


async def test_singleton_race_condition() -> None:
async def test_singleton_wrong_dependency_scope() -> None:
def some_factory(_: SimpleCreator) -> None: ...

request_singleton_ = providers.Singleton(Scope.REQUEST, SimpleCreator, dep1="original")
with pytest.raises(RuntimeError, match="Scope of dependency cannot be more than scope of dependent"):
providers.Singleton(Scope.APP, some_factory, request_singleton_.cast)


@pytest.mark.repeat(10)
async def test_singleton_async_resolve_race_condition() -> None:
calls: int = 0

async def create_resource() -> typing.AsyncIterator[str]:
Expand All @@ -106,9 +118,32 @@ async def resolve_factory(container: Container) -> SimpleCreator:
assert calls == 1


async def test_singleton_wrong_dependency_scope() -> None:
def some_factory(_: SimpleCreator) -> None: ...
@pytest.mark.repeat(10)
def test_resource_sync_resolve_race_condition() -> None:
calls: int = 0
lock = threading.Lock()

request_singleton_ = providers.Singleton(Scope.REQUEST, SimpleCreator, dep1="original")
with pytest.raises(RuntimeError, match="Scope of dependency cannot be more than scope of dependent"):
providers.Singleton(Scope.APP, some_factory, request_singleton_.cast)
def create_resource() -> typing.Iterator[str]:
nonlocal calls
with lock:
calls += 1
time.sleep(0.01)
yield ""

resource = providers.Resource(Scope.APP, create_resource)
factory_with_resource = providers.Singleton(Scope.APP, SimpleCreator, dep1=resource.cast)

def resolve_factory(container: Container) -> SimpleCreator:
return factory_with_resource.sync_resolve(container)

with Container(scope=Scope.APP) as app_container, ThreadPoolExecutor(max_workers=4) as pool:
tasks = [
pool.submit(resolve_factory, app_container),
pool.submit(resolve_factory, app_container),
pool.submit(resolve_factory, app_container),
pool.submit(resolve_factory, app_container),
]
results = [x.result() for x in as_completed(tasks)]

assert all(isinstance(x, SimpleCreator) for x in results)
assert calls == 1
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ addopts = "--cov=. --cov-report term-missing"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.report]
exclude_also = ["if typing.TYPE_CHECKING:"]
[tool.coverage]
run.concurrency = ["thread"]
report.exclude_also = ["if typing.TYPE_CHECKING:"]

0 comments on commit 66103a4

Please sign in to comment.