Skip to content

Commit

Permalink
(#117) (#45) add watchable status and some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 5, 2024
1 parent c67f1cf commit cdf6883
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 30 deletions.
59 changes: 46 additions & 13 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@

import asyncio
import functools
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, cast
import time
from dataclasses import replace
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
TypeVar,
cast,
)

from bluesky.protocols import Status

from .utils import Callback, P
from .utils import Callback, P, T, WatcherUpdate

AS = TypeVar("AS", bound="AsyncStatus")
AS = TypeVar("AS")


class AsyncStatus(Status):
Expand All @@ -17,15 +29,13 @@ class AsyncStatus(Status):
def __init__(
self,
awaitable: Awaitable,
watchers: Optional[List[Callable]] = None,
):
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(awaitable) # type: ignore
self.task.add_done_callback(self._run_callbacks)
self._callbacks = cast(List[Callback[Status]], [])
self._watchers = watchers

def __await__(self):
return self.task.__await__()
Expand Down Expand Up @@ -69,14 +79,6 @@ def success(self) -> bool:
and self.task.exception() is None
)

def watch(self, watcher: Callable):
"""Add watcher to the list of interested parties.
Arguments as per Bluesky :external+bluesky:meth:`watch` protocol.
"""
if self._watchers is not None:
self._watchers.append(watcher)

@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
Expand All @@ -98,3 +100,34 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"

__str__ = __repr__


class WatchableAsyncStatus(AsyncStatus, Generic[T]):
def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Callable] = []
self._start = time.monotonic()
self._last_update: Optional[WatcherUpdate[T]] = None
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Callable, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))

def watch(self, watcher: Callable):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[AS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]]
) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))

return cast(Callable[P, AS], wrap_f)
13 changes: 13 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import asyncio
import logging
from dataclasses import dataclass
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -79,6 +81,17 @@ def __str__(self) -> str:
return self.format_error_string(indent="")


@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
name: str
current: T
initial: T
target: T
units: str
precision: float
time_elapsed_s: float


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down
85 changes: 68 additions & 17 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,57 @@
import asyncio

import bluesky.plan_stubs as bps
import pytest

from ophyd_async.core.async_status import AsyncStatus
from ophyd_async.core.async_status import AsyncStatus, WatchableAsyncStatus
from ophyd_async.core.signal import SignalR, SimSignalBackend
from ophyd_async.core.standard_readable import StandardReadable
from ophyd_async.core.utils import WatcherUpdate


class TestDevice(StandardReadable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(backend=SimSignalBackend(datatype=int, source="sim:TEST"))
super().__init__(name)


class TestDeviceSingleSet(TestDevice):
@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore


class TestDeviceIteratorSet(TestDevice):
def __init__(self, name: str = "", values=[1, 2, 3, 4, 5]) -> None:
self.values = values
super().__init__(name)

@WatchableAsyncStatus.wrap
async def set(self, val):
self._initial = await self.sig.get_value()
for point in self.values:
await asyncio.sleep(0.01)
yield WatcherUpdate(
name=self.name,
current=point,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
)
self.sig._backend._set_value(val) # type: ignore
yield WatcherUpdate(
name=self.name,
current=val,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
)
return


@pytest.fixture
Expand Down Expand Up @@ -45,22 +92,26 @@ async def coro_status(x: int, y: int, *, z=False):
loop.run_until_complete(do_test())


async def test_asyncstatus_wraps_set():
class TestDevice(StandardReadable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(
backend=SimSignalBackend(datatype=int, source="sim:TEST")
)
super().__init__(name)
async def test_asyncstatus_wraps_set(RE):
td = TestDeviceSingleSet()
await td.connect()
st = td.set(5)
assert isinstance(st, AsyncStatus)
await st
assert (await td.sig.get_value()) == 5
RE(bps.abs_set(td, 3, wait=True))
assert (await td.sig.get_value()) == 3

@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore

TD = TestDevice()
await TD.connect()
st = TD.set(5)
assert isinstance(st, AsyncStatus)
async def test_asyncstatus_wraps_set_iterator(RE):
td = TestDeviceIteratorSet()
await td.connect()
st = td.set(5)
updates = []

def watcher(update):
updates.append(update)

st.watch(watcher)
await st
assert (await TD.sig.get_value()) == 5
assert len(updates) == 6

0 comments on commit cdf6883

Please sign in to comment.