Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(typing): improve typing of WrappedFn #390

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 65 additions & 32 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import sys
import threading
Expand Down Expand Up @@ -91,37 +92,8 @@
from .wait import WaitBaseT


WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
_RetValT = t.TypeVar("_RetValT")


def retry(*dargs: t.Any, **dkw: t.Any) -> t.Union[WrappedFn, t.Callable[[WrappedFn], WrappedFn]]: # noqa
"""Wrap a function with a new `Retrying` object.

:param dargs: positional arguments passed to Retrying object
:param dkw: keyword arguments passed to the Retrying object
"""
# support both @retry and @retry() as valid syntax
if len(dargs) == 1 and callable(dargs[0]):
return retry()(dargs[0])
else:

def wrap(f: WrappedFn) -> WrappedFn:
if isinstance(f, retry_base):
warnings.warn(
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
if iscoroutinefunction(f):
r: "BaseRetrying" = AsyncRetrying(*dargs, **dkw)
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
r = TornadoRetrying(*dargs, **dkw)
else:
r = Retrying(*dargs, **dkw)

return r.wraps(f)

return wrap


class TryAgain(Exception):
Expand Down Expand Up @@ -382,14 +354,24 @@ def __iter__(self) -> t.Generator[AttemptManager, None, None]:
break

@abstractmethod
def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
def __call__(
self,
fn: t.Callable[..., WrappedFnReturnT],
*args: t.Any,
**kwargs: t.Any,
) -> WrappedFnReturnT:
pass


class Retrying(BaseRetrying):
"""Retrying controller."""

def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
def __call__(
self,
fn: t.Callable[..., WrappedFnReturnT],
*args: t.Any,
**kwargs: t.Any,
) -> WrappedFnReturnT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
Expand Down Expand Up @@ -510,6 +492,57 @@ def __repr__(self) -> str:
return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>"


@t.overload
def retry(func: WrappedFn) -> WrappedFn:
...


@t.overload
def retry(
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
stop: "StopBaseT" = stop_never,
wait: "WaitBaseT" = wait_none(),
retry: "RetryBaseT" = retry_if_exception_type(),
before: t.Callable[["RetryCallState"], None] = before_nothing,
after: t.Callable[["RetryCallState"], None] = after_nothing,
before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None,
reraise: bool = False,
retry_error_cls: t.Type["RetryError"] = RetryError,
retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None,
) -> t.Callable[[WrappedFn], WrappedFn]:
...


def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any:
"""Wrap a function with a new `Retrying` object.

:param dargs: positional arguments passed to Retrying object
:param dkw: keyword arguments passed to the Retrying object
"""
# support both @retry and @retry() as valid syntax
if len(dargs) == 1 and callable(dargs[0]):
return retry()(dargs[0])
else:

def wrap(f: WrappedFn) -> WrappedFn:
if isinstance(f, retry_base):
warnings.warn(
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
)
r: "BaseRetrying"
if iscoroutinefunction(f):
r = AsyncRetrying(*dargs, **dkw)
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
r = TornadoRetrying(*dargs, **dkw)
else:
r = Retrying(*dargs, **dkw)

return r.wraps(f)

return wrap


from tenacity._asyncio import AsyncRetrying # noqa:E402,I100

if tornado:
Expand Down
24 changes: 10 additions & 14 deletions tenacity/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import functools
import sys
import typing
import typing as t
from asyncio import sleep

from tenacity import AttemptManager
Expand All @@ -26,24 +26,20 @@
from tenacity import DoSleep
from tenacity import RetryCallState


WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable[..., typing.Any])
_RetValT = typing.TypeVar("_RetValT")
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])


class AsyncRetrying(BaseRetrying):
def __init__(
self, sleep: typing.Callable[[float], typing.Awaitable[typing.Any]] = sleep, **kwargs: typing.Any
) -> None:
sleep: t.Callable[[float], t.Awaitable[t.Any]]

def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None:
super().__init__(**kwargs)
self.sleep = sleep

async def __call__( # type: ignore[override]
self,
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
*args: typing.Any,
**kwargs: typing.Any,
) -> _RetValT:
self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any
) -> WrappedFnReturnT:
self.begin()

retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
Expand All @@ -62,7 +58,7 @@ async def __call__( # type: ignore[override]
else:
return do # type: ignore[no-any-return]

def __iter__(self) -> typing.Generator[AttemptManager, None, None]:
def __iter__(self) -> t.Generator[AttemptManager, None, None]:
raise TypeError("AsyncRetrying object is not iterable")

def __aiter__(self) -> "AsyncRetrying":
Expand All @@ -88,7 +84,7 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(fn)
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
return await fn(*args, **kwargs)

# Preserve attributes
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ commands =

[testenv:mypy]
deps =
mypy
mypy>=1.0.0
commands =
mypy tenacity

Expand Down