Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify error handling for timed out signals when connecting. #97

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 27 additions & 37 deletions src/ophyd_async/core/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@

from __future__ import annotations

import asyncio
import logging
import sys
from contextlib import suppress
from typing import Any, Dict, Generator, Iterator, Optional, Set, Tuple, TypeVar
from typing import (
Any,
Coroutine,
Dict,
Generator,
Iterator,
Optional,
Set,
Tuple,
TypeVar,
)

from bluesky.protocols import HasName
from bluesky.run_engine import call_in_bluesky_event_loop

from .utils import NotConnected, wait_for_connection
from .utils import DEFAULT_TIMEOUT, wait_for_connection


class Device(HasName):
Expand Down Expand Up @@ -51,16 +58,21 @@ def set_name(self, name: str):
child.set_name(child_name)
child.parent = self

async def connect(self, sim: bool = False):
async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT):
"""Connect self and all child Devices.

Contains a timeout that gets propagated to child.connect methods.

Parameters
----------
sim:
If True then connect in simulation mode.
timeout:
Time to wait before failing with a TimeoutError.
"""
coros = {
name: child_device.connect(sim) for name, child_device in self.children()
name: child_device.connect(sim, timeout=timeout)
for name, child_device in self.children()
}
if coros:
await wait_for_connection(**coros)
Expand Down Expand Up @@ -141,41 +153,19 @@ async def __aenter__(self) -> "DeviceCollector":

async def _on_exit(self) -> None:
# Name and kick off connect for devices
tasks: Dict[asyncio.Task, str] = {}
connect_coroutines: Dict[str, Coroutine] = {}
for name, obj in self._objects_on_exit.items():
if name not in self._names_on_enter and isinstance(obj, Device):
if self._set_name and not obj.name:
obj.set_name(name)
if self._connect:
task = asyncio.create_task(obj.connect(self._sim))
tasks[task] = name
# Wait for all the signals to have finished
if tasks:
await self._wait_for_tasks(tasks)

async def _wait_for_tasks(self, tasks: Dict[asyncio.Task, str]):
done, pending = await asyncio.wait(tasks, timeout=self._timeout)
if pending:
msg = f"{len(pending)} Devices did not connect:"
for t in pending:
t.cancel()
with suppress(Exception):
await t
e = t.exception()
msg += f"\n {tasks[t]}: {type(e).__name__}"
lines = str(e).splitlines()
if len(lines) <= 1:
msg += f": {e}"
else:
msg += "".join(f"\n {line}" for line in lines)
logging.error(msg)
raised = [t for t in done if t.exception()]
if raised:
logging.error(f"{len(raised)} Devices raised an error:")
for t in raised:
logging.exception(f" {tasks[t]}:", exc_info=t.exception())
if pending or raised:
raise NotConnected("Not all Devices connected")
connect_coroutines[name] = obj.connect(
self._sim, timeout=self._timeout
)

# Connect to all the devices
if connect_coroutines:
await wait_for_connection(**connect_coroutines)

async def __aexit__(self, type, value, traceback):
self._objects_on_exit = self._caller_locals()
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def name(self) -> str:
def set_name(self, name: str = ""):
self._name = name

async def connect(self, sim=False):
async def connect(self, sim=False, timeout=DEFAULT_TIMEOUT):
if sim:
self._backend = SimSignalBackend(
datatype=self._init_backend.datatype, source=self._init_backend.source
Expand All @@ -67,7 +67,7 @@ async def connect(self, sim=False):
else:
self._backend = self._init_backend
_sim_backends.pop(self, None)
await self._backend.connect()
await self._backend.connect(timeout=timeout)

@property
def source(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from bluesky.protocols import Descriptor, Reading

from .utils import ReadingValueCallback, T
from .utils import DEFAULT_TIMEOUT, ReadingValueCallback, T


class SignalBackend(Generic[T]):
Expand All @@ -16,7 +16,7 @@ class SignalBackend(Generic[T]):
source: str = ""

@abstractmethod
async def connect(self):
async def connect(self, timeout: float = DEFAULT_TIMEOUT):
"""Connect to underlying hardware"""

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/sim_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from bluesky.protocols import Descriptor, Dtype, Reading

from .signal_backend import SignalBackend
from .utils import ReadingValueCallback, T, get_dtype
from .utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype

primitive_dtypes: Dict[type, Dtype] = {
str: "string",
Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(self, datatype: Optional[Type[T]], source: str) -> None:
self.put_proceeds.set()
self.callback: Optional[ReadingValueCallback[T]] = None

async def connect(self) -> None:
async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
self.converter = make_converter(self.datatype)
self._initial_value = self.converter.make_initial_value(self.datatype)
self._severity = 0
Expand Down
104 changes: 75 additions & 29 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from __future__ import annotations

import asyncio
from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Type, TypeVar
import logging
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
TypeVar,
Union,
)

import numpy as np
from bluesky.protocols import Reading
Expand All @@ -11,46 +24,79 @@
#: monitor updates
ReadingValueCallback = Callable[[Reading, T], None]
DEFAULT_TIMEOUT = 10.0
ErrorText = Union[str, Dict[str, Exception]]


class NotConnected(Exception):
"""Exception to be raised if a `Device.connect` is cancelled"""

def __init__(self, *lines: str):
self.lines = list(lines)
_indent_width = " "

def __init__(self, errors: ErrorText):
"""
NotConnected holds a mapping of device/signal names to
errors.

Parameters
----------
errors: ErrorText
Mapping of device name to Exception or another NotConnected.
Alternatively a string with the signal error text.
"""

self._errors = errors

def _format_sub_errors(self, name: str, error: Exception, indent="") -> str:
if isinstance(error, NotConnected):
error_txt = ":" + error.format_error_string(indent + self._indent_width)
elif isinstance(error, Exception):
error_txt = ": " + err_str + "\n" if (err_str := str(error)) else "\n"
else:
raise RuntimeError(
f"Unexpected type `{type(error)}`, expected an Exception"
)

string = f"{indent}{name}: {type(error).__name__}" + error_txt
return string

def format_error_string(self, indent="") -> str:
if not isinstance(self._errors, dict) and not isinstance(self._errors, str):
raise RuntimeError(
f"Unexpected type `{type(self._errors)}` " "expected `str` or `dict`"
)

if isinstance(self._errors, str):
return " " + self._errors + "\n"

string = "\n"
for name, error in self._errors.items():
string += self._format_sub_errors(name, error, indent=indent)
return string

def __str__(self) -> str:
return "\n".join(self.lines)
return self.format_error_string(indent="")


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating `NotConnected` exceptions
"""Call many underlying signals, accumulating exceptions and returning them

Raises
------
`NotConnected` if cancelled
Expected kwargs should be a mapping of names to coroutine tasks to execute.
"""
ts = {k: asyncio.create_task(c) for (k, c) in coros.items()} # type: ignore
try:
done, pending = await asyncio.wait(ts.values())
except asyncio.CancelledError:
for t in ts.values():
t.cancel()
lines: List[str] = []
for k, t in ts.items():
try:
await t
except NotConnected as e:
if len(e.lines) == 1:
lines.append(f"{k}: {e.lines[0]}")
else:
lines.append(f"{k}:")
lines += [f" {line}" for line in e.lines]
raise NotConnected(*lines)
else:
# Wait for everything to foreground the exceptions
for f in list(done) + list(pending):
await f
results = await asyncio.gather(*coros.values(), return_exceptions=True)
exceptions = {}

for name, result in zip(coros, results):
if isinstance(result, Exception):
exceptions[name] = result
if not isinstance(result, NotConnected):
logging.exception(
f"device `{name}` raised unexpected exception "
f"{type(result).__name__}",
exc_info=result,
)

if exceptions:
raise NotConnected(exceptions)


def get_dtype(typ: Type) -> Optional[np.dtype]:
Expand Down
24 changes: 14 additions & 10 deletions src/ophyd_async/epics/_backend/_aioca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
import sys
from asyncio import CancelledError
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Type, Union
Expand All @@ -8,6 +8,7 @@
FORMAT_CTRL,
FORMAT_RAW,
FORMAT_TIME,
CANothing,
Subscription,
caget,
camonitor,
Expand All @@ -18,14 +19,14 @@
from epicscorelibs.ca import dbr

from ophyd_async.core import (
NotConnected,
ReadingValueCallback,
SignalBackend,
T,
get_dtype,
get_unique,
wait_for_connection,
)
from ophyd_async.core.utils import DEFAULT_TIMEOUT, NotConnected

dbr_to_dtype: Dict[Dbr, Dtype] = {
dbr.DBR_STRING: "string",
Expand Down Expand Up @@ -184,23 +185,26 @@ def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str):
self.source = f"ca://{self.read_pv}"
self.subscription: Optional[Subscription] = None

async def _store_initial_value(self, pv):
async def _store_initial_value(self, pv, timeout: float = DEFAULT_TIMEOUT):
try:
self.initial_values[pv] = await caget(pv, format=FORMAT_CTRL, timeout=None)
except CancelledError:
raise NotConnected(self.source)
self.initial_values[pv] = await caget(
pv, format=FORMAT_CTRL, timeout=timeout
)
except CANothing as exc:
logging.debug(f"signal ca://{pv} timed out")
raise NotConnected(f"ca://{pv}") from exc

async def connect(self):
async def connect(self, timeout: float = DEFAULT_TIMEOUT):
_use_pyepics_context_if_imported()
if self.read_pv != self.write_pv:
# Different, need to connect both
await wait_for_connection(
read_pv=self._store_initial_value(self.read_pv),
write_pv=self._store_initial_value(self.write_pv),
read_pv=self._store_initial_value(self.read_pv, timeout=timeout),
write_pv=self._store_initial_value(self.write_pv, timeout=timeout),
)
else:
# The same, so only need to connect one
await self._store_initial_value(self.read_pv)
await self._store_initial_value(self.read_pv, timeout=timeout)
self.converter = make_converter(self.datatype, self.initial_values)

async def put(self, value: Optional[T], wait=True, timeout=None):
Expand Down
Loading
Loading