From 264f16d501f8d23094952190757dcaef63f93795 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Wed, 26 Jun 2024 09:54:45 +0100 Subject: [PATCH] cleaned up tests for lazy connection --- src/ophyd_async/core/signal.py | 2 +- tests/core/test_device.py | 43 ++++++++++++--- tests/core/test_signal.py | 98 +++++++++------------------------- 3 files changed, 62 insertions(+), 81 deletions(-) diff --git a/src/ophyd_async/core/signal.py b/src/ophyd_async/core/signal.py index f4ca362c69..8e4705fb17 100644 --- a/src/ophyd_async/core/signal.py +++ b/src/ophyd_async/core/signal.py @@ -88,7 +88,7 @@ async def connect( ) self._previous_connect_was_mock = mock - if mock and not isinstance(self._backend, MockSignalBackend): + if mock and not issubclass(type(self._backend), MockSignalBackend): # Using a soft backend, look to the initial value self._backend = MockSignalBackend(initial_backend=self._backend) diff --git a/tests/core/test_device.py b/tests/core/test_device.py index d9cc16f68c..683f50ba00 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -9,6 +9,7 @@ Device, DeviceCollector, DeviceVector, + MockSignalBackend, NotConnected, wait_for_connection, ) @@ -127,18 +128,44 @@ async def test_device_log_has_correct_name(): async def test_device_lazily_connects(RE): - async with DeviceCollector(mock=True, connect=False): - mock_motor = motor.Motor("BLxxI-MO-TABLE-01:X") + class MockSignalBackendFailingFirst(MockSignalBackend): + succeed_on_connect = False - assert not mock_motor._connect_task + async def connect(self, timeout=DEFAULT_TIMEOUT): + if self.succeed_on_connect: + self.succeed_on_connect = False + await super().connect(timeout=timeout) + else: + self.succeed_on_connect = True + raise RuntimeError("connect fail") - # When ready to connect - RE(ensure_connected(mock_motor, mock=True)) + test_motor = motor.Motor("BLxxI-MO-TABLE-01:X") + test_motor.user_setpoint._backend = MockSignalBackendFailingFirst(int) + + with pytest.raises(NotConnected, match="RuntimeError: connect fail"): + await test_motor.connect(mock=True) assert ( - mock_motor._connect_task - and mock_motor._connect_task.done() - and not mock_motor._connect_task.exception() + test_motor._connect_task + and test_motor._connect_task.done() + and test_motor._connect_task.exception() + ) + + RE(ensure_connected(test_motor, mock=True)) + + assert ( + test_motor._connect_task + and test_motor._connect_task.done() + and not test_motor._connect_task.exception() + ) + + # TODO https://github.com/bluesky/ophyd-async/issues/413 + RE(ensure_connected(test_motor, mock=True, force_reconnect=True)) + + assert ( + test_motor._connect_task + and test_motor._connect_task.done() + and test_motor._connect_task.exception() ) diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 771e4e20b4..9030d1df87 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -2,7 +2,7 @@ import logging import re import time -from unittest.mock import ANY, Mock +from unittest.mock import ANY import numpy import pytest @@ -34,25 +34,6 @@ from ophyd_async.plan_stubs import ensure_connected -class SignalRejectingAtFirst(Signal): - def __init__(self) -> None: - self.connected = False - self.first_connection = True - self.backend = soft_signal_rw(int, 0, name="mock_signal") - - async def connect( - self, mock=False, timeout=DEFAULT_TIMEOUT, force_reconnect: bool = False - ): - if self.first_connection: - self.first_connection = False - else: - await self.backend.connect(timeout) - self.connected = True - - async def _connect_task(self): - return self.backend._connect_task() - - async def test_signals_equality_raises(): s1 = epics_signal_rw(int, "pva://pv1", name="signal") s2 = epics_signal_rw(int, "pva://pv2", name="signal") @@ -155,68 +136,41 @@ async def test_rejects_reconnect_when_connects_have_diff_mock_status( async def test_signal_lazily_connects(RE): - failing_signal = SignalRejectingAtFirst() - await failing_signal.connect(mock=False) - assert failing_signal._connect_task.exception() - RE(ensure_connected(failing_signal, mock=False)) - assert ( - failing_signal._connect_task - and failing_signal._connect_task.done() - and not failing_signal._connect_task.exception() - ) + class MockSignalBackendFailingFirst(MockSignalBackend): + succeed_on_connect = False + async def connect(self, timeout=DEFAULT_TIMEOUT): + if self.succeed_on_connect: + self.succeed_on_connect = False + await super().connect(timeout=timeout) + else: + self.succeed_on_connect = True + raise RuntimeError("connect fail") -async def test_signal_lazily_connects_1(RE): - mock_signal_rw = soft_signal_rw(int, 0, name="mock_signal") + signal = SignalRW(MockSignalBackendFailingFirst(int)) + + with pytest.raises(RuntimeError, match="connect fail"): + await signal.connect(mock=False) - await mock_signal_rw.connect(mock=False) - RE(ensure_connected(mock_signal_rw, mock=True)) assert ( - mock_signal_rw._connect_task - and mock_signal_rw._connect_task.done() - and not mock_signal_rw._connect_task.exception() + signal._connect_task + and signal._connect_task.done() + and signal._connect_task.exception() ) - -async def test_signal_lazily_connects_2(RE): - failing_signal = Signal(MockSignalBackend(int)) - - cache_connect = failing_signal.connect - first_connect = True - - def fail_connect(): - if first_connect: - first_connect = False - else: - cache_connect() - - await failing_signal.connect(mock=False) - RE(ensure_connected(failing_signal, mock=False)) + RE(ensure_connected(signal, mock=False)) assert ( - failing_signal._connect_task - and failing_signal._connect_task.done() - and not failing_signal._connect_task.exception() + signal._connect_task + and signal._connect_task.done() + and not signal._connect_task.exception() ) - -async def test_signal_lazily_connects_3(RE): - mock_signal_rw = soft_signal_rw(int, 0, name="mock_signal") - cached_connect = mock_signal_rw._backend.connect - fail_at_first_connect = Mock() - fail_at_first_connect.side_effect = [ - Exception("Failure on first call"), - cached_connect, - cached_connect, - ] - mock_signal_rw._backend.connect = fail_at_first_connect - - with pytest.raises(Exception, match="Failure on first call"): - await mock_signal_rw.connect(mock=False) - RE(ensure_connected(mock_signal_rw, mock=False)) + # TODO https://github.com/bluesky/ophyd-async/issues/413 + RE(ensure_connected(signal, mock=False, force_reconnect=True)) assert ( - mock_signal_rw._connect_task - and mock_signal_rw._connect_task.done() - and not mock_signal_rw._connect_task.exception() + signal._connect_task + and signal._connect_task.done() + and signal._connect_task.exception() )