Skip to content

Commit

Permalink
(#117) (#45) make AS and WAS both derive from base
Browse files Browse the repository at this point in the history
so that they can both use `wrap` with different types
  • Loading branch information
dperl-dls committed Apr 12, 2024
1 parent e2c822a commit 13ce75b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import functools
import time
from dataclasses import replace
from typing import Any, AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast
from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast

from bluesky.protocols import Status

from .utils import Callback, P, T, WatcherUpdate
from .utils import Callback, P, T, Watcher, WatcherUpdate

AS = TypeVar("AS")


class AsyncStatus(Status):
class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(
Expand Down Expand Up @@ -65,16 +65,6 @@ def success(self) -> bool:
and self.task.exception() is None
)

@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
return cast(Callable[P, AS], wrap_f)

def __repr__(self) -> str:
if self.done:
if e := self.exception():
Expand All @@ -88,11 +78,23 @@ def __repr__(self) -> str:
__str__ = __repr__


class WatchableAsyncStatus(AsyncStatus, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface"""
class AsyncStatus(AsyncStatusBase):
@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
return cls(f(*args, **kwargs))

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
return cast(Callable[P, AS], wrap_f)


class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(self, iterator: AsyncIterator[WatcherUpdate[T]]):
self._watchers: list[Callable] = []
self._watchers: list[Watcher]
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator))
Expand All @@ -102,12 +104,10 @@ async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]])
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(
self, watcher: Callable[[WatcherUpdate[T]], Any], update: WatcherUpdate[T]
):
def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
watcher(replace(update, time_elapsed_s=time.monotonic() - self._start))

def watch(self, watcher: Callable[[WatcherUpdate[T]], Any]):
def watch(self, watcher: Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)
Expand Down
5 changes: 5 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from dataclasses import dataclass
from typing import (
Any,
Awaitable,
Callable,
Dict,
Expand All @@ -13,6 +14,7 @@
Optional,
ParamSpec,
Type,
TypeAlias,
TypeVar,
Union,
)
Expand Down Expand Up @@ -92,6 +94,9 @@ class WatcherUpdate(Generic[T]):
time_elapsed_s: float


Watcher: TypeAlias = Callable[[WatcherUpdate[T]], Any]


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down

0 comments on commit 13ce75b

Please sign in to comment.