Skip to content

Commit

Permalink
(#117) (#45) make Watcher match bluesky spec
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 9, 2024
1 parent 2f0e284 commit a4bdf2d
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 42 deletions.
40 changes: 14 additions & 26 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,8 @@
import asyncio
import functools
import time
from dataclasses import replace
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
Sequence,
Type,
TypeVar,
cast,
)
from dataclasses import asdict, replace
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast

from bluesky.protocols import Status

Expand Down Expand Up @@ -105,37 +96,34 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):

def __init__(
self,
iterator_or_awaitable: Awaitable | AsyncIterator[WatcherUpdate[T]],
iterator: AsyncIterator[WatcherUpdate[T]],
watchers: list[Watcher] = [],
):
self._watchers: list[Watcher] = watchers
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
awaitable = (
iterator_or_awaitable
if isinstance(iterator_or_awaitable, Awaitable)
else self._notify_watchers_from(iterator_or_awaitable)
)
super().__init__(awaitable)
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
async for update in iterator:
self._last_update = replace(
update, time_elapsed=time.monotonic() - self._start
)
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))
watcher(**asdict(update))

def watch(self, watchers: Sequence[Watcher]):
for watcher in watchers:
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)
def watch(self, watcher: Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[WAS],
f: Callable[P, Awaitable] | Callable[P, AsyncIterator[WatcherUpdate[T]]],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
Expand Down
30 changes: 25 additions & 5 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
List,
Optional,
ParamSpec,
Protocol,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand Down Expand Up @@ -85,16 +85,36 @@ def __str__(self) -> str:

@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
"""A dataclass such that, when expanded, it provides the kwargs for a watcher"""

name: str
current: T
initial: T
target: T
units: str
precision: float
time_elapsed_s: float


Watcher: TypeAlias = Callable[[WatcherUpdate[T]], Any]
fraction: float
time_elapsed: float
time_remaining: float


C = TypeVar("C", contravariant=True)


class Watcher(Protocol, Generic[C]):
@staticmethod
def __call__(
*,
name: str,
current: C,
initial: C,
target: C,
units: str,
precision: float,
fraction: float,
time_elapsed: float,
time_remaining: float,
) -> Any: ...


async def wait_for_connection(**coros: Awaitable[None]):
Expand Down
71 changes: 60 additions & 11 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from functools import partial
from typing import AsyncIterator

import bluesky.plan_stubs as bps
Expand All @@ -15,6 +16,53 @@ class SetFailed(Exception):
pass


def watcher_test(
storage: list[WatcherUpdate],
*,
name: str,
current: int,
initial: int,
target: int,
units: str,
precision: float,
fraction: float,
time_elapsed: float,
time_remaining: float,
):
storage.append(
WatcherUpdate(
name=name,
current=current,
initial=initial,
target=target,
units=units,
precision=precision,
fraction=fraction,
time_elapsed=time_elapsed,
time_remaining=time_remaining,
)
)


class TWatcher:
updates: list[int] = []

def __call__(
self,
*,
name: str,
current: int,
initial: int,
target: int,
units: str,
precision: float,
fraction: float,
time_elapsed: float,
time_remaining: float,
) -> None:
self.updates.append(current)


class ASTestDevice(StandardReadable, Movable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(backend=SimSignalBackend(datatype=int, source="sim:TEST"))
Expand Down Expand Up @@ -48,7 +96,9 @@ async def set(self, val) -> AsyncIterator:
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=0,
time_elapsed=0,
time_remaining=0,
fraction=0,
)
if self.complete_set:
self.sig._backend._set_value(val) # type: ignore
Expand All @@ -59,7 +109,9 @@ async def set(self, val) -> AsyncIterator:
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
time_elapsed=0,
time_remaining=0,
fraction=0,
)
else:
raise SetFailed
Expand Down Expand Up @@ -115,20 +167,20 @@ async def test_asyncstatus_wraps_set(RE):
assert (await td.sig.get_value()) == 3


async def test_asyncstatus_wraps_set_iterator(RE):
async def test_asyncstatus_wraps_set_iterator_with_class_or_func_watcher(RE):
td = ASTestDeviceIteratorSet()
await td.connect()
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch([watcher])
w = TWatcher()
st.watch(partial(watcher_test, updates))
st.watch(w)
await st
assert st.done
assert st.success
assert len(updates) == 6
assert sum(w.updates) == 21


async def test_asyncstatus_wraps_failing_set_iterator_(RE):
Expand All @@ -137,10 +189,7 @@ async def test_asyncstatus_wraps_failing_set_iterator_(RE):
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch([watcher])
st.watch(partial(watcher_test, updates))
try:
await st
except Exception:
Expand Down

0 comments on commit a4bdf2d

Please sign in to comment.