Skip to content

Commit

Permalink
(#117) (#45) extend wrap to take args and kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 4, 2024
1 parent 61886f9 commit c67f1cf
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import asyncio
import functools
from typing import Awaitable, Callable, Coroutine, List, Optional, cast
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, cast

from bluesky.protocols import Status

from .utils import Callback, T
from .utils import Callback, P

AS = TypeVar("AS", bound="AsyncStatus")


class AsyncStatus(Status):
Expand Down Expand Up @@ -76,12 +78,14 @@ def watch(self, watcher: Callable):
self._watchers.append(watcher)

@classmethod
def wrap(cls, f: Callable[[T], Coroutine]) -> Callable[[T], "AsyncStatus"]:
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(self) -> AsyncStatus:
return AsyncStatus(f(self))
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))

return wrap_f
# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
return cast(Callable[P, AS], wrap_f)

def __repr__(self) -> str:
if self.done:
Expand Down
2 changes: 2 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Iterable,
List,
Optional,
ParamSpec,
Type,
TypeVar,
Union,
Expand All @@ -18,6 +19,7 @@
from bluesky.protocols import Reading

T = TypeVar("T")
P = ParamSpec("P")
Callback = Callable[[T], None]

#: A function that will be called with the Reading and value when the
Expand Down
66 changes: 66 additions & 0 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import asyncio

import pytest

from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.signal import SignalR, SimSignalBackend
from ophyd_async.core.standard_readable import StandardReadable


@pytest.fixture
def loop():
return asyncio.get_event_loop()


def test_asyncstatus_wraps_bare_func(loop):
async def do_test():
@AsyncStatus.wrap
async def coro_status():
await asyncio.sleep(0.01)

st = coro_status()
assert isinstance(st, AsyncStatus)
await asyncio.wait_for(st.task, None)
assert st.done

loop.run_until_complete(do_test())


def test_asyncstatus_wraps_bare_func_with_args_kwargs(loop):
async def do_test():
test_result = 5

@AsyncStatus.wrap
async def coro_status(x: int, y: int, *, z=False):
await asyncio.sleep(0.01)
nonlocal test_result
test_result = x * y if z else 0

st = coro_status(3, 4, z=True)
assert isinstance(st, AsyncStatus)
await asyncio.wait_for(st.task, None)
assert st.done
assert test_result == 12

loop.run_until_complete(do_test())


async def test_asyncstatus_wraps_set():
class TestDevice(StandardReadable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(
backend=SimSignalBackend(datatype=int, source="sim:TEST")
)
super().__init__(name)

@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore

TD = TestDevice()
await TD.connect()
st = TD.set(5)
assert isinstance(st, AsyncStatus)
await st
assert (await TD.sig.get_value()) == 5

0 comments on commit c67f1cf

Please sign in to comment.