From c67f1cfa120f7637e1b005f03c5b7e43d77ffaf9 Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 4 Apr 2024 15:23:35 +0100 Subject: [PATCH] (#117) (#45) extend wrap to take args and kwargs --- src/ophyd_async/core/async_status.py | 16 +++--- src/ophyd_async/core/utils.py | 2 + tests/core/test_async_status_wrapper.py | 66 +++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 tests/core/test_async_status_wrapper.py diff --git a/src/ophyd_async/core/async_status.py b/src/ophyd_async/core/async_status.py index 201c7162b2..5cf0aa3212 100644 --- a/src/ophyd_async/core/async_status.py +++ b/src/ophyd_async/core/async_status.py @@ -2,11 +2,13 @@ import asyncio import functools -from typing import Awaitable, Callable, Coroutine, List, Optional, cast +from typing import Awaitable, Callable, List, Optional, Type, TypeVar, cast from bluesky.protocols import Status -from .utils import Callback, T +from .utils import Callback, P + +AS = TypeVar("AS", bound="AsyncStatus") class AsyncStatus(Status): @@ -76,12 +78,14 @@ def watch(self, watcher: Callable): self._watchers.append(watcher) @classmethod - def wrap(cls, f: Callable[[T], Coroutine]) -> Callable[[T], "AsyncStatus"]: + def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]: @functools.wraps(f) - def wrap_f(self) -> AsyncStatus: - return AsyncStatus(f(self)) + def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS: + return cls(f(*args, **kwargs)) - return wrap_f + # type is actually functools._Wrapped[P, Awaitable, P, AS] + # but functools._Wrapped is not necessarily available + return cast(Callable[P, AS], wrap_f) def __repr__(self) -> str: if self.done: diff --git a/src/ophyd_async/core/utils.py b/src/ophyd_async/core/utils.py index 42a4b5b7d4..2180cb01c0 100644 --- a/src/ophyd_async/core/utils.py +++ b/src/ophyd_async/core/utils.py @@ -9,6 +9,7 @@ Iterable, List, Optional, + ParamSpec, Type, TypeVar, Union, @@ -18,6 +19,7 @@ from bluesky.protocols import Reading T = TypeVar("T") +P = ParamSpec("P") Callback = Callable[[T], None] #: A function that will be called with the Reading and value when the diff --git a/tests/core/test_async_status_wrapper.py b/tests/core/test_async_status_wrapper.py new file mode 100644 index 0000000000..cff80e2d89 --- /dev/null +++ b/tests/core/test_async_status_wrapper.py @@ -0,0 +1,66 @@ +import asyncio + +import pytest + +from ophyd_async.core.async_status import AsyncStatus +from ophyd_async.core.signal import SignalR, SimSignalBackend +from ophyd_async.core.standard_readable import StandardReadable + + +@pytest.fixture +def loop(): + return asyncio.get_event_loop() + + +def test_asyncstatus_wraps_bare_func(loop): + async def do_test(): + @AsyncStatus.wrap + async def coro_status(): + await asyncio.sleep(0.01) + + st = coro_status() + assert isinstance(st, AsyncStatus) + await asyncio.wait_for(st.task, None) + assert st.done + + loop.run_until_complete(do_test()) + + +def test_asyncstatus_wraps_bare_func_with_args_kwargs(loop): + async def do_test(): + test_result = 5 + + @AsyncStatus.wrap + async def coro_status(x: int, y: int, *, z=False): + await asyncio.sleep(0.01) + nonlocal test_result + test_result = x * y if z else 0 + + st = coro_status(3, 4, z=True) + assert isinstance(st, AsyncStatus) + await asyncio.wait_for(st.task, None) + assert st.done + assert test_result == 12 + + loop.run_until_complete(do_test()) + + +async def test_asyncstatus_wraps_set(): + class TestDevice(StandardReadable): + def __init__(self, name: str = "") -> None: + self.sig = SignalR( + backend=SimSignalBackend(datatype=int, source="sim:TEST") + ) + super().__init__(name) + + @AsyncStatus.wrap + async def set(self, val): + await asyncio.sleep(0.01) + self.sig._backend._set_value(val) # type: ignore + + TD = TestDevice() + await TD.connect() + st = TD.set(5) + assert isinstance(st, AsyncStatus) + await st + assert (await TD.sig.get_value()) == 5