Skip to content

Commit

Permalink
Make devices connect with a timeout (#321)
Browse files Browse the repository at this point in the history
Co-authored-by: Rose Yemelyanova <[email protected]>
  • Loading branch information
callumforrester and Rose Yemelyanova authored Oct 26, 2023
1 parent 1f9dadf commit 42cac05
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ testpaths = "docs src tests"
markers = [
"handler: marks tests that interact with the global handler object in handler.py",
]
asyncio_mode = "auto"

[tool.coverage.run]
data_file = "/tmp/blueapi.coverage"
Expand Down
25 changes: 11 additions & 14 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
)

from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop
from ophyd_async.core import Device as AsyncDevice
from ophyd_async.core import wait_for_connection
from pydantic import create_model
from pydantic.fields import FieldInfo, ModelField

from blueapi.config import EnvironmentConfig, SourceKind
from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider
from blueapi.utils import BlueapiPlanModelConfig, load_module_all
from blueapi.utils import (
BlueapiPlanModelConfig,
connect_ophyd_async_devices,
load_module_all,
)

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Expand Down Expand Up @@ -104,17 +106,12 @@ def with_config(self, config: EnvironmentConfig) -> None:
elif source.kind is SourceKind.DODAL:
self.with_dodal_module(mod)

call_in_bluesky_event_loop(self.connect_devices(self.sim))

async def connect_devices(self, sim: bool = False) -> None:
coros = {}
for device_name, device in self.devices.items():
if isinstance(device, AsyncDevice):
device.set_name(device_name)
coros[device_name] = device.connect(sim)

if len(coros) > 0:
await wait_for_connection(**coros)
call_in_bluesky_event_loop(
connect_ophyd_async_devices(
self.devices.values(),
self.sim,
)
)

def with_plan_module(self, module: ModuleType) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig
from .invalid_config_error import InvalidConfigError
from .modules import load_module_all
from .ophyd_async_connect import connect_ophyd_async_devices
from .serialization import serialize
from .thread_exception import handle_all_exceptions

Expand All @@ -13,4 +14,5 @@
"BlueapiModelConfig",
"BlueapiPlanModelConfig",
"InvalidConfigError",
"connect_ophyd_async_devices",
]
47 changes: 47 additions & 0 deletions src/blueapi/utils/ophyd_async_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio
import logging
from contextlib import suppress
from typing import Any, Dict, Iterable

from ophyd_async.core import DEFAULT_TIMEOUT
from ophyd_async.core import Device as OphydAsyncDevice
from ophyd_async.core import NotConnected


async def connect_ophyd_async_devices(
devices: Iterable[Any],
sim: bool = False,
timeout: float = DEFAULT_TIMEOUT,
) -> None:
tasks: Dict[asyncio.Task, str] = {}
for device in devices:
if isinstance(device, OphydAsyncDevice):
task = asyncio.create_task(device.connect(sim=sim))
tasks[task] = device.name
if tasks:
await _wait_for_tasks(tasks, timeout=timeout)


async def _wait_for_tasks(tasks: Dict[asyncio.Task, str], timeout: float):
done, pending = await asyncio.wait(tasks, timeout=timeout)
if pending:
msg = f"{len(pending)} Devices did not connect:"
for t in pending:
t.cancel()
with suppress(Exception):
await t
e = t.exception()
msg += f"\n {tasks[t]}: {type(e).__name__}"
lines = str(e).splitlines()
if len(lines) <= 1:
msg += f": {e}"
else:
msg += "".join(f"\n {line}" for line in lines)
logging.error(msg)
raised = [t for t in done if t.exception()]
if raised:
logging.error(f"{len(raised)} Devices raised an error:")
for t in raised:
logging.exception(f" {tasks[t]}:", exc_info=t.exception())
if pending or raised:
raise NotConnected("Not all Devices connected")

0 comments on commit 42cac05

Please sign in to comment.