Skip to content

Commit

Permalink
(#117) (#45) test completion and success
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 5, 2024
1 parent cdf6883 commit 4f9c467
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 37 deletions.
30 changes: 10 additions & 20 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,7 @@
import functools
import time
from dataclasses import replace
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
TypeVar,
cast,
)
from typing import Any, AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast

from bluesky.protocols import Status

Expand All @@ -35,7 +25,7 @@ def __init__(
else:
self.task = asyncio.create_task(awaitable) # type: ignore
self.task.add_done_callback(self._run_callbacks)
self._callbacks = cast(List[Callback[Status]], [])
self._callbacks = cast(list[Callback[Status]], [])

def __await__(self):
return self.task.__await__()
Expand All @@ -51,15 +41,11 @@ def _run_callbacks(self, task: asyncio.Task):
for callback in self._callbacks:
callback(self)

# TODO: remove ignore and bump min version when bluesky v1.12.0 is released
def exception( # type: ignore
self, timeout: Optional[float] = 0.0
) -> Optional[BaseException]:
def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
raise Exception(
"cannot honour any timeout other than 0 in an asynchronous function"
)

if self.task.done():
try:
return self.task.exception()
Expand Down Expand Up @@ -103,21 +89,25 @@ def __repr__(self) -> str:


class WatchableAsyncStatus(AsyncStatus, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface"""

def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Callable] = []
self._start = time.monotonic()
self._last_update: Optional[WatcherUpdate[T]] = None
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for self._last_update in iterator:
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Callable, update: WatcherUpdate[T]):
def _update_watcher(
self, watcher: Callable[[WatcherUpdate[T]], Any], update: WatcherUpdate[T]
):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))

def watch(self, watcher: Callable):
def watch(self, watcher: Callable[[WatcherUpdate[T]], Any]):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)
Expand Down
67 changes: 50 additions & 17 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,37 @@

import bluesky.plan_stubs as bps
import pytest
from bluesky.protocols import Movable

from ophyd_async.core.async_status import AsyncStatus, WatchableAsyncStatus
from ophyd_async.core.signal import SignalR, SimSignalBackend
from ophyd_async.core.standard_readable import StandardReadable
from ophyd_async.core.utils import WatcherUpdate


class TestDevice(StandardReadable):
class SetFailed(Exception):
pass


class ASTestDevice(StandardReadable, Movable):
def __init__(self, name: str = "") -> None:
self.sig = SignalR(backend=SimSignalBackend(datatype=int, source="sim:TEST"))
super().__init__(name)


class TestDeviceSingleSet(TestDevice):
class ASTestDeviceSingleSet(ASTestDevice):
@AsyncStatus.wrap
async def set(self, val):
await asyncio.sleep(0.01)
self.sig._backend._set_value(val) # type: ignore


class TestDeviceIteratorSet(TestDevice):
def __init__(self, name: str = "", values=[1, 2, 3, 4, 5]) -> None:
class ASTestDeviceIteratorSet(ASTestDevice):
def __init__(
self, name: str = "", values=[1, 2, 3, 4, 5], complete_set: bool = True
) -> None:
self.values = values
self.complete_set = complete_set
super().__init__(name)

@WatchableAsyncStatus.wrap
Expand All @@ -39,18 +47,21 @@ async def set(self, val):
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=0,
)
if self.complete_set:
self.sig._backend._set_value(val) # type: ignore
yield WatcherUpdate(
name=self.name,
current=val,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
)
self.sig._backend._set_value(val) # type: ignore
yield WatcherUpdate(
name=self.name,
current=val,
initial=self._initial,
target=val,
units="dimensionless",
precision=0.0,
time_elapsed_s=point,
)
else:
raise SetFailed
return


Expand Down Expand Up @@ -93,7 +104,7 @@ async def coro_status(x: int, y: int, *, z=False):


async def test_asyncstatus_wraps_set(RE):
td = TestDeviceSingleSet()
td = ASTestDeviceSingleSet()
await td.connect()
st = td.set(5)
assert isinstance(st, AsyncStatus)
Expand All @@ -104,14 +115,36 @@ async def test_asyncstatus_wraps_set(RE):


async def test_asyncstatus_wraps_set_iterator(RE):
td = TestDeviceIteratorSet()
td = ASTestDeviceIteratorSet()
await td.connect()
st = td.set(5)
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch(watcher)
await st
assert st.done
assert st.success
assert len(updates) == 6


async def test_asyncstatus_wraps_failing_set_iterator_(RE):
td = ASTestDeviceIteratorSet(values=[1, 2, 3], complete_set=False)
await td.connect()
st = td.set(6)
updates = []

def watcher(update):
updates.append(update)

st.watch(watcher)
try:
await st
except Exception:
...
assert st.done
assert not st.success
assert isinstance(st.exception(), SetFailed)
assert len(updates) == 3

0 comments on commit 4f9c467

Please sign in to comment.