Skip to content

Commit

Permalink
(#117) (#45) add timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
dperl-dls committed Apr 11, 2024
1 parent e9d073f commit 394a632
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 35 deletions.
30 changes: 26 additions & 4 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import functools
import time
from dataclasses import asdict, replace
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast
from functools import partial
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
Sequence,
Type,
TypeVar,
cast,
)

from bluesky.protocols import Status

Expand Down Expand Up @@ -97,15 +107,19 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
watchers: list[Watcher] = [],
timeout_s: float,
watchers: Sequence[Watcher],
):
self._watchers: list[Watcher] = watchers
self._watchers: list[Watcher] = list(watchers)
self._start = time.monotonic()
self._timeout = self._start + timeout_s if timeout_s else None
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
if self._timeout and time.monotonic() > self._timeout:
raise TimeoutError()
self._last_update = replace(
update, time_elapsed=time.monotonic() - self._start
)
Expand All @@ -124,9 +138,17 @@ def watch(self, watcher: Watcher):
def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
timeout_s: float = 0.0,
watchers: Sequence[Watcher] = [],
) -> Callable[P, WAS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
return cls(f(*args, **kwargs))
return cls(f(*args, **kwargs), timeout_s=timeout_s, watchers=watchers)

return cast(Callable[P, WAS], wrap_f)

@classmethod
def wrap_with(
cls: Type[WAS], timeout_s: float = 0.0, watchers: Sequence[Watcher] = []
):
return partial(cls.wrap, timeout_s=timeout_s, watchers=watchers)
22 changes: 21 additions & 1 deletion src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ParamSpec,
Protocol,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand Down Expand Up @@ -101,7 +102,7 @@ class WatcherUpdate(Generic[T]):
C = TypeVar("C", contravariant=True)


class Watcher(Protocol, Generic[C]):
class _ClsWatcher(Protocol, Generic[C]):
@staticmethod
def __call__(
*,
Expand All @@ -117,6 +118,25 @@ def __call__(
) -> Any: ...


class _InsWatcher(Protocol, Generic[C]):
def __call__(
self,
*,
current: C,
initial: C,
target: C,
name: str | None,
unit: str | None,
precision: float | None,
fraction: float | None,
time_elapsed: float | None,
time_remaining: float | None,
) -> Any: ...


Watcher: TypeAlias = _ClsWatcher | _InsWatcher


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down
47 changes: 24 additions & 23 deletions src/ophyd_async/epics/demo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import subprocess
import sys
import time
from dataclasses import replace
from enum import Enum
from pathlib import Path
from typing import Callable, List, Optional
from typing import Optional

import numpy as np
from bluesky.protocols import Movable, Stoppable
Expand All @@ -20,6 +21,7 @@
WatchableAsyncStatus,
observe_value,
)
from ophyd_async.core.utils import WatcherUpdate

from ..signal.signal import epics_signal_r, epics_signal_rw, epics_signal_x

Expand Down Expand Up @@ -74,32 +76,25 @@ def set_name(self, name: str):
# Readback should be named the same as its parent in read()
self.readback.set_name(name)

async def _move(self, new_position: float, watchers: List[Callable] = []):
async def _move(self, new_position: float):
self._set_success = True
# time.monotonic won't go backwards in case of NTP corrections
start = time.monotonic()
old_position, units, precision = await asyncio.gather(
self.setpoint.get_value(),
self.units.get_value(),
self.precision.get_value(),
)
# Wait for the value to set, but don't wait for put completion callback
await self.setpoint.set(new_position, wait=False)
async for current_position in observe_value(self.readback):
for watcher in watchers:
watcher(
name=self.name,
current=current_position,
initial=old_position,
target=new_position,
unit=units,
precision=precision,
time_elapsed=time.monotonic() - start,
)
if np.isclose(current_position, new_position):
break
if not self._set_success:
raise RuntimeError("Motor was stopped")
return WatcherUpdate(
initial=old_position,
current=old_position,
target=new_position,
unit=units,
precision=precision,
)

def move(self, new_position: float, timeout: Optional[float] = None):
"""Commandline only synchronous move of a Motor"""
Expand All @@ -109,13 +104,19 @@ def move(self, new_position: float, timeout: Optional[float] = None):
raise RuntimeError("Will deadlock run engine if run in a plan")
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

# TODO: this fails if we call from the cli, but works if we "ipython await" it
def set(
self, new_position: float, timeout: Optional[float] = None
) -> WatchableAsyncStatus:
watchers: List[Callable] = []
coro = asyncio.wait_for(self._move(new_position, watchers), timeout=timeout)
return WatchableAsyncStatus(coro, watchers)
@WatchableAsyncStatus.wrap
async def set(self, new_position: float):
update = await self._move(new_position)
start = time.monotonic()
async for current_position in observe_value(self.readback):
yield replace(
update,
name=self.name,
current=current_position,
time_elapsed=time.monotonic() - start,
)
if np.isclose(current_position, new_position):
return

async def stop(self, success=True):
self._set_success = success
Expand Down
5 changes: 4 additions & 1 deletion src/ophyd_async/epics/motion/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, prefix: str, name="") -> None:
self.setpoint = epics_signal_rw(float, prefix + ".VAL")
self.readback = epics_signal_r(float, prefix + ".RBV")
self.velocity = epics_signal_rw(float, prefix + ".VELO")
self.done_moving = epics_signal_r(bool, prefix + ".DMOV")
self.units = epics_signal_r(str, prefix + ".EGU")
self.precision = epics_signal_r(int, prefix + ".PREC")
# Signals that collide with standard methods should have a trailing underscore
Expand Down Expand Up @@ -65,13 +66,15 @@ def move(self, new_position: float, timeout: Optional[float] = None):
call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore

@WatchableAsyncStatus.wrap
async def set(self, new_position: float, timeout: Optional[float] = None):
async def set(self, new_position: float):
start_time = time.monotonic()
update: WatcherUpdate[float] = await self._move(new_position)
async for readback in observe_value(self.readback):
yield replace(
update, current=readback, time_elapsed=start_time - time.monotonic()
)
if await self.done_moving.get_value():
return

async def stop(self, success=False):
self._set_success = success
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_async_status_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self.complete_set = complete_set
super().__init__(name)

@WatchableAsyncStatus.wrap
@WatchableAsyncStatus.wrap_with(timeout_s=5, watchers=[])
async def set(self, val) -> AsyncIterator:
assert self._staged
self._initial = await self.sig.get_value()
Expand Down
35 changes: 30 additions & 5 deletions tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,38 @@ async def sim_sensor():
yield sim_sensor


class Watcher:
class DemoWatcher:
def __init__(self) -> None:
self._event = asyncio.Event()
self._mock = Mock()

def __call__(self, *args, **kwargs):
self._mock(*args, **kwargs)
def __call__(
self,
*args,
current: float,
initial: float,
target: float,
name: str | None = None,
unit: str | None = None,
precision: float | None = None,
fraction: float | None = None,
time_elapsed: float | None = None,
time_remaining: float | None = None,
**kwargs,
):
self._mock(
*args,
current=current,
initial=initial,
target=target,
name=name,
unit=unit,
precision=precision,
fraction=fraction,
time_elapsed=time_elapsed,
time_remaining=time_remaining,
**kwargs,
)
self._event.set()

async def wait_for_call(self, *args, **kwargs):
Expand All @@ -60,8 +85,8 @@ async def wait_for_call(self, *args, **kwargs):

async def test_mover_moving_well(sim_mover: demo.Mover) -> None:
s = sim_mover.set(0.55)
watcher = Watcher()
s.watch([watcher])
watcher = DemoWatcher()
s.watch(watcher)
done = Mock()
s.add_callback(done)
await watcher.wait_for_call(
Expand Down

0 comments on commit 394a632

Please sign in to comment.