From 8c94a790432937a73fa43660d20d93e707de5cd9 Mon Sep 17 00:00:00 2001 From: Saul Vargas Date: Sun, 19 May 2024 11:56:54 +0300 Subject: [PATCH] fix varioius type hinting issues --- tests/test_injection.py | 9 +++++++++ that_depends/__init__.py | 1 + that_depends/container.py | 2 +- that_depends/injection.py | 4 +++- that_depends/providers/base.py | 10 ++++++---- that_depends/providers/collections.py | 8 ++++---- that_depends/providers/resources.py | 2 +- 7 files changed, 25 insertions(+), 11 deletions(-) diff --git a/tests/test_injection.py b/tests/test_injection.py index 20b6111..16c8689 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -1,3 +1,4 @@ +import asyncio import datetime import pytest @@ -34,3 +35,11 @@ async def inner( with pytest.raises(RuntimeError, match="Injected arguments must not be redefined"): await inner(_=container.SimpleFactory(dep1="1", dep2=2)) + + +def test_type_check() -> None: + @inject + async def main() -> None: + pass + + asyncio.run(main()) diff --git a/that_depends/__init__.py b/that_depends/__init__.py index 14fe7ce..f5a7e01 100644 --- a/that_depends/__init__.py +++ b/that_depends/__init__.py @@ -1,3 +1,4 @@ +from that_depends import providers from that_depends.container import BaseContainer from that_depends.injection import Provide, inject diff --git a/that_depends/container.py b/that_depends/container.py index 6d56441..d5c0f1f 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -45,7 +45,7 @@ async def _inner() -> T: return _inner @classmethod - async def resolve(cls, object_to_resolve: type[T] | typing.Callable[P, T]) -> T: + async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) -> T: signature = inspect.signature(object_to_resolve) kwargs = {} providers = cls.get_providers() diff --git a/that_depends/injection.py b/that_depends/injection.py index c7fcb7b..7c9f238 100644 --- a/that_depends/injection.py +++ b/that_depends/injection.py @@ -9,7 +9,9 @@ T = typing.TypeVar("T") -def inject(func: typing.Callable[P, typing.Awaitable[T]]) -> typing.Callable[P, typing.Awaitable[T]]: +def inject( + func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]], +) -> typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]]: signature = inspect.signature(func) @functools.wraps(func) diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index e32b2dc..59d6da0 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -3,20 +3,22 @@ T = typing.TypeVar("T") +R = typing.TypeVar("R") +T_co = typing.TypeVar("T_co", covariant=True) -class AbstractProvider(typing.Generic[T], abc.ABC): +class AbstractProvider(typing.Generic[T_co], abc.ABC): """Abstract Provider Class.""" @abc.abstractmethod - async def async_resolve(self) -> T: + async def async_resolve(self) -> T_co: """Resolve dependency asynchronously.""" @abc.abstractmethod - def sync_resolve(self) -> T: + def sync_resolve(self) -> T_co: """Resolve dependency synchronously.""" - async def __call__(self) -> T: + async def __call__(self) -> T_co: return await self.async_resolve() def override(self, mock: object) -> None: diff --git a/that_depends/providers/collections.py b/that_depends/providers/collections.py index c2eb022..5e2b1c5 100644 --- a/that_depends/providers/collections.py +++ b/that_depends/providers/collections.py @@ -6,15 +6,15 @@ T = typing.TypeVar("T") -class List(AbstractProvider[T]): +class List(AbstractProvider[list[T]]): def __init__(self, *providers: AbstractProvider[T]) -> None: self._providers = providers - async def async_resolve(self) -> list[T]: # type: ignore[override] + async def async_resolve(self) -> list[T]: return [await x.async_resolve() for x in self._providers] - def sync_resolve(self) -> list[T]: # type: ignore[override] + def sync_resolve(self) -> list[T]: return [x.sync_resolve() for x in self._providers] - async def __call__(self) -> list[T]: # type: ignore[override] + async def __call__(self) -> list[T]: return await self.async_resolve() diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index 58db825..840a85f 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -12,7 +12,7 @@ class Resource(AbstractResource[T]): def __init__( self, - creator: typing.Callable[..., typing.Iterator[T]], + creator: typing.Callable[P, typing.Iterator[T]], *args: P.args, **kwargs: P.kwargs, ) -> None: