Skip to content

Commit

Permalink
(#117) (#45) add watchable status and some tests
Browse files Browse the repository at this point in the history
(#117) (#45) add watchable status and some tests
  • Loading branch information
dperl-dls committed Apr 19, 2024
1 parent 0a5533e commit 8b926b4
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 36 deletions.
59 changes: 40 additions & 19 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import asyncio
import functools
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, cast
import time
from dataclasses import replace
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast

from bluesky.protocols import Status

from .utils import Callback, P
from .utils import Callback, P, T, WatcherUpdate

AS = TypeVar("AS", bound="AsyncStatus")
AS = TypeVar("AS")


class AsyncStatus(Status):
Expand All @@ -17,15 +19,13 @@ class AsyncStatus(Status):
def __init__(
self,
awaitable: Awaitable,
watchers: Optional[List[Callable]] = None,
):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(awaitable) # type: ignore
self.task.add_done_callback(self._run_callbacks)
self._callbacks = cast(List[Callback[Status]], [])
self._watchers = watchers
self._callbacks = cast(list[Callback[Status]], [])

def __await__(self):
return self.task.__await__()
Expand All @@ -41,15 +41,11 @@ def _run_callbacks(self, task: asyncio.Task):
for callback in self._callbacks:
callback(self)

# TODO: remove ignore and bump min version when bluesky v1.12.0 is released
def exception( # type: ignore
self, timeout: Optional[float] = 0.0
) -> Optional[BaseException]:
def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
raise Exception(
"cannot honour any timeout other than 0 in an asynchronous function"
)

if self.task.done():
try:
return self.task.exception()
Expand All @@ -69,14 +65,6 @@ def success(self) -> bool:
and self.task.exception() is None
)

def watch(self, watcher: Callable):
"""Add watcher to the list of interested parties.
Arguments as per Bluesky :external+bluesky:meth:`watch` protocol.
"""
if self._watchers is not None:
self._watchers.append(watcher)

@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
Expand All @@ -98,3 +86,36 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"

__str__ = __repr__


class WatchableAsyncStatus(AsyncStatus, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface"""

def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Callable] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Callable, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))

def watch(self, watcher: Callable):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[AS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]]
) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))

return cast(Callable[P, AS], wrap_f)
13 changes: 13 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import asyncio
import logging
from dataclasses import dataclass
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -79,6 +81,17 @@ def __str__(self) -> str:
return self.format_error_string(indent="")


@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
name: str
current: T
initial: T
target: T
units: str
precision: float
time_elapsed_s: float


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down
118 changes: 101 additions & 17 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,68 @@
import asyncio

import bluesky.plan_stubs as bps
import pytest
from bluesky.protocols import Movable

from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.async_status import AsyncStatus, WatchableAsyncStatus
from ophyd_async.core.signal import SignalR, SimSignalBackend
from ophyd_async.core.standard_readable import StandardReadable
from ophyd_async.core.utils import WatcherUpdate


class SetFailed(Exception):
pass


class ASTestDevice(StandardReadable, Movable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(backend=SimSignalBackend(datatype=int, source="sim:TEST"))
super().__init__(name)


class ASTestDeviceSingleSet(ASTestDevice):
@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore


class ASTestDeviceIteratorSet(ASTestDevice):
def __init__(
self, name: str = "", values=[1, 2, 3, 4, 5], complete_set: bool = True
) -> None:
self.values = values
self.complete_set = complete_set
super().__init__(name)

@WatchableAsyncStatus.wrap
async def set(self, val):
self._initial = await self.sig.get_value()
for point in self.values:
await asyncio.sleep(0.01)
yield WatcherUpdate(
name=self.name,
current=point,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=0,
)
if self.complete_set:
self.sig._backend._set_value(val) # type: ignore
yield WatcherUpdate(
name=self.name,
current=val,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
)
else:
raise SetFailed
return


@pytest.fixture
Expand Down Expand Up @@ -45,22 +103,48 @@ async def coro_status(x: int, y: int, *, z=False):
loop.run_until_complete(do_test())


async def test_asyncstatus_wraps_set():
class TestDevice(StandardReadable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(
backend=SimSignalBackend(datatype=int, source="sim:TEST")
)
super().__init__(name)
async def test_asyncstatus_wraps_set(RE):
td = ASTestDeviceSingleSet()
await td.connect()
st = td.set(5)
assert isinstance(st, AsyncStatus)
await st
assert (await td.sig.get_value()) == 5
RE(bps.abs_set(td, 3, wait=True))
assert (await td.sig.get_value()) == 3

@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore

TD = TestDevice()
await TD.connect()
st = TD.set(5)
assert isinstance(st, AsyncStatus)
async def test_asyncstatus_wraps_set_iterator(RE):
td = ASTestDeviceIteratorSet()
await td.connect()
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch(watcher)
await st
assert (await TD.sig.get_value()) == 5
assert st.done
assert st.success
assert len(updates) == 6


async def test_asyncstatus_wraps_failing_set_iterator_(RE):
td = ASTestDeviceIteratorSet(values=[1, 2, 3], complete_set=False)
await td.connect()
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch(watcher)
try:
await st
except Exception:
...
assert st.done
assert not st.success
assert isinstance(st.exception(), SetFailed)
assert len(updates) == 3

0 comments on commit 8b926b4

Please sign in to comment.