Skip to content

Commit

Permalink
(#117) (#45) improve timeouts and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 11, 2024
1 parent 394a632 commit 82e7f14
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 36 deletions.
33 changes: 18 additions & 15 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import functools
import time
from dataclasses import asdict, replace
from functools import partial
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
Sequence,
SupportsFloat,
Type,
TypeVar,
cast,
Expand Down Expand Up @@ -107,10 +106,9 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
timeout_s: float,
watchers: Sequence[Watcher],
timeout_s: float = 0.0,
):
self._watchers: list[Watcher] = list(watchers)
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._timeout = self._start + timeout_s if timeout_s else None
self._last_update: WatcherUpdate[T] | None = None
Expand All @@ -127,7 +125,10 @@ async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]])
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(**asdict(update))
vals = asdict(
update, dict_factory=lambda d: {k: v for k, v in d if v is not None}
)
watcher(**vals)

def watch(self, watcher: Watcher):
self._watchers.append(watcher)
Expand All @@ -138,17 +139,19 @@ def watch(self, watcher: Watcher):
def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
timeout_s: float = 0.0,
watchers: Sequence[Watcher] = [],
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout_s' as an argument, this must be a float and it will be propagated
to the status."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
return cls(f(*args, **kwargs), timeout_s=timeout_s, watchers=watchers)
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
_timeout = kwargs.get("timeout_s")
assert isinstance(_timeout, SupportsFloat) or _timeout is None
timeout = _timeout or 0.0
return cls(f(*args, **kwargs), timeout_s=float(timeout))

return cast(Callable[P, WAS], wrap_f)

@classmethod
def wrap_with(
cls: Type[WAS], timeout_s: float = 0.0, watchers: Sequence[Watcher] = []
):
return partial(cls.wrap, timeout_s=timeout_s, watchers=watchers)
7 changes: 4 additions & 3 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,12 @@ async def _prepare(self, value: T) -> None:
exposure=self._trigger_info.livetime,
)

async def kickoff(self):
def kickoff(self, timeout_s=0.0):
self._fly_start = time.monotonic()
return WatchableAsyncStatus(
self._observe_writer_indicies(self._last_frame), self._watchers
self._fly_status = WatchableAsyncStatus(
self._observe_writer_indicies(self._last_frame), timeout_s
)
return self._fly_status

async def _observe_writer_indicies(self, end_observation: int):
async for index in self.writer.observe_indices_written(
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/flyer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ async def _prepare(self, value: T) -> None:
# Move to start and setup the flyscan
await self._trigger_logic.prepare(value)

@AsyncStatus.wrap
async def kickoff(self) -> None:
def kickoff(self) -> AsyncStatus:
self._fly_status = AsyncStatus(self._trigger_logic.start())
return self._fly_status

def complete(self) -> AsyncStatus:
assert self._fly_status, "Kickoff not run"
Expand Down
5 changes: 4 additions & 1 deletion src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import logging
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -98,6 +98,9 @@ class WatcherUpdate(Generic[T]):
time_elapsed: float | None = None
time_remaining: float | None = None

def as_dict(self) -> dict[str, T | str | float]:
return {k: v for k, v in asdict(self) if v is not None}


C = TypeVar("C", contravariant=True)

Expand Down
1 change: 1 addition & 0 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ async def _move(self, new_position: float):
await self.setpoint.set(new_position, wait=False)
if not self._set_success:
raise RuntimeError("Motor was stopped")
# return a template to set() which it can use to yield progress updates
return WatcherUpdate(
initial=old_position,
current=old_position,
Expand Down
15 changes: 9 additions & 6 deletions src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def _move(self, new_position: float) -> WatcherUpdate[float]:
self.units.get_value(),
self.precision.get_value(),
)
await self.setpoint.set(new_position)
await self.setpoint.set(new_position, wait=False)
if not self._set_success:
raise RuntimeError("Motor was stopped")
return WatcherUpdate(
Expand All @@ -66,12 +66,15 @@ def move(self, new_position: float, timeout: Optional[float] = None):
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

@WatchableAsyncStatus.wrap
async def set(self, new_position: float):
start_time = time.monotonic()
update: WatcherUpdate[float] = await self._move(new_position)
async for readback in observe_value(self.readback):
async def set(self, new_position: float, timeout_s: float = 0.0):
update = await self._move(new_position)
start = time.monotonic()
async for current_position in observe_value(self.readback):
yield replace(
update, current=readback, time_elapsed=start_time - time.monotonic()
update,
name=self.name,
current=current_position,
time_elapsed=time.monotonic() - start,
)
if await self.done_moving.get_value():
return
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def __init__(
self.complete_set = complete_set
super().__init__(name)

@WatchableAsyncStatus.wrap_with(timeout_s=5, watchers=[])
async def set(self, val) -> AsyncIterator:
@WatchableAsyncStatus.wrap
async def set(self, val, timeout_s=5) -> AsyncIterator:
assert self._staged
self._initial = await self.sig.get_value()
for point in self.values:
Expand Down
1 change: 1 addition & 0 deletions tests/core/test_device_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_async_device_connector_run_engine_same_event_loop():
async def set_up_device():
async with DeviceCollector(sim=True):
sim_motor = motor.Motor("BLxxI-MO-TABLE-01:X")
sim_motor.set
return sim_motor

loop = asyncio.new_event_loop()
Expand Down
2 changes: 0 additions & 2 deletions tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def __call__(
name=name,
unit=unit,
precision=precision,
fraction=fraction,
time_elapsed=time_elapsed,
time_remaining=time_remaining,
**kwargs,
)
self._event.set()
Expand Down
22 changes: 17 additions & 5 deletions tests/epics/motion/test_motor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import time
from typing import Dict
from unittest.mock import Mock, call

Expand All @@ -24,18 +25,27 @@ async def sim_motor():
set_sim_value(sim_motor.units, "mm")
set_sim_value(sim_motor.precision, 3)
set_sim_value(sim_motor.velocity, 1)
set_sim_value(sim_motor.done_moving, True)
yield sim_motor


async def wait_for_eq(item, attribute, comparison, timeout):
timeout_time = time.monotonic() + timeout
while getattr(item, attribute) != comparison:
await asyncio.sleep(A_BIT)
if time.monotonic() > timeout_time:
raise TimeoutError


async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
set_sim_put_proceeds(sim_motor.setpoint, False)
s = sim_motor.set(0.55)
set_sim_value(sim_motor.done_moving, False)
s = sim_motor.set(0.55, timeout_s=1)
watcher = Mock(spec=Watcher)
s.watch(watcher)
done = Mock()
s.add_callback(done)
await asyncio.sleep(A_BIT)
assert watcher.call_count == 1
await wait_for_eq(watcher, "call_count", 1, 1)
assert watcher.call_args == call(
name="sim_motor",
current=0.0,
Expand All @@ -50,7 +60,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
assert not s.done
await asyncio.sleep(0.1)
set_sim_value(sim_motor.readback, 0.1)
assert watcher.call_count == 1
await wait_for_eq(watcher, "call_count", 1, 1)
assert watcher.call_args == call(
name="sim_motor",
current=0.1,
Expand All @@ -61,8 +71,10 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None:
time_elapsed=pytest.approx(0.1, abs=0.05),
)
set_sim_put_proceeds(sim_motor.setpoint, True)
set_sim_value(sim_motor.done_moving, True)
set_sim_value(sim_motor.readback, 0.55)
await asyncio.sleep(A_BIT)
assert s.done
await wait_for_eq(s, "done", True, 1)
done.assert_called_once_with(s)


Expand Down

0 comments on commit 82e7f14

Please sign in to comment.