diff --git a/src/ophyd_async/plan_stubs/ensure_connected.py b/src/ophyd_async/plan_stubs/ensure_connected.py index 736ce7b780..6049595971 100644 --- a/src/ophyd_async/plan_stubs/ensure_connected.py +++ b/src/ophyd_async/plan_stubs/ensure_connected.py @@ -10,13 +10,18 @@ def ensure_connected( timeout: float = DEFAULT_TIMEOUT, force_reconnect=False, ): - yield from bps.wait_for( + (connect_task,) = yield from bps.wait_for( [ lambda: wait_for_connection( **{ - device.name: device.connect(mock, timeout, force_reconnect) + device.name: device.connect( + mock=mock, timeout=timeout, force_reconnect=force_reconnect + ) for device in devices } ) ] ) + + if connect_task and connect_task.exception() is not None: + raise connect_task.exception() diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 683f50ba00..1294e64511 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -159,8 +159,8 @@ async def connect(self, timeout=DEFAULT_TIMEOUT): and not test_motor._connect_task.exception() ) - # TODO https://github.com/bluesky/ophyd-async/issues/413 - RE(ensure_connected(test_motor, mock=True, force_reconnect=True)) + with pytest.raises(NotConnected, match="RuntimeError: connect fail"): + RE(ensure_connected(test_motor, mock=True, force_reconnect=True)) assert ( test_motor._connect_task diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 845e454b61..a60bcd63f3 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -29,7 +29,7 @@ wait_for_value, ) from ophyd_async.core.signal import _SignalCache -from ophyd_async.core.utils import DEFAULT_TIMEOUT +from ophyd_async.core.utils import DEFAULT_TIMEOUT, NotConnected from ophyd_async.epics.signal import epics_signal_r, epics_signal_rw from ophyd_async.plan_stubs import ensure_connected @@ -167,8 +167,8 @@ async def connect(self, timeout=DEFAULT_TIMEOUT): and not signal._connect_task.exception() ) - # TODO https://github.com/bluesky/ophyd-async/issues/413 - RE(ensure_connected(signal, mock=False, force_reconnect=True)) + with pytest.raises(NotConnected, match="RuntimeError: connect fail"): + RE(ensure_connected(signal, mock=False, force_reconnect=True)) assert ( signal._connect_task and signal._connect_task.done() diff --git a/tests/plan_stubs/test_ensure_connected.py b/tests/plan_stubs/test_ensure_connected.py new file mode 100644 index 0000000000..f77ad50355 --- /dev/null +++ b/tests/plan_stubs/test_ensure_connected.py @@ -0,0 +1,40 @@ +import pytest + +from ophyd_async.core import Device, NotConnected +from ophyd_async.core.mock_signal_backend import MockSignalBackend +from ophyd_async.core.signal import SignalRW +from ophyd_async.epics.signal import epics_signal_rw +from ophyd_async.plan_stubs import ensure_connected + + +def test_ensure_connected(RE): + class MyDevice(Device): + def __init__(self, prefix: str, name=""): + self.signal = epics_signal_rw(str, f"pva://{prefix}:SIGNAL") + super().__init__(name=name) + + device1 = MyDevice("PREFIX1", name="device1") + + def connect(): + yield from ensure_connected(device1, mock=False, timeout=0.1) + + with pytest.raises( + NotConnected, + match="device1: NotConnected:\n signal: NotConnected: pva://PREFIX1:SIGNAL", + ): + RE(connect()) + + assert isinstance(device1.signal._connect_task.exception(), NotConnected) + + device1.signal = SignalRW(MockSignalBackend(str)) + RE(connect()) + assert device1.signal._connect_task.exception() is None + + device2 = MyDevice("PREFIX2", name="device2") + + def connect_with_mocking(): + assert device2.signal._connect_task is None + yield from ensure_connected(device2, mock=True, timeout=0.1) + assert device2.signal._connect_task.done() + + RE(connect_with_mocking())