diff --git a/src/ophyd_async/core/mock_signal_utils.py b/src/ophyd_async/core/mock_signal_utils.py index 6a79709684..38e23a6c99 100644 --- a/src/ophyd_async/core/mock_signal_utils.py +++ b/src/ophyd_async/core/mock_signal_utils.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager, contextmanager -from typing import Any, Callable, Iterable, Iterator, List +from typing import Any, Callable, Iterable from unittest.mock import ANY, Mock from ophyd_async.core.signal import Signal @@ -22,7 +22,7 @@ def set_mock_value(signal: Signal[T], value: T): backend.set_value(value) -def set_mock_put_proceeds(signal: Signal[T], proceeds: bool): +def set_mock_put_proceeds(signal: Signal, proceeds: bool): """Allow or block a put with wait=True from proceeding""" backend = _get_mock_signal_backend(signal) @@ -33,7 +33,7 @@ def set_mock_put_proceeds(signal: Signal[T], proceeds: bool): @asynccontextmanager -async def mock_puts_blocked(*signals: List[Signal]): +async def mock_puts_blocked(*signals: Signal): for signal in signals: set_mock_put_proceeds(signal, False) yield @@ -79,7 +79,7 @@ def __next__(self): return next_value def __del__(self): - if self.require_all_consumed and self.index != len(self.values): + if self.require_all_consumed and self.index != len(list(self.values)): raise AssertionError("Not all values have been consumed.") @@ -87,7 +87,7 @@ def set_mock_values( signal: Signal, values: Iterable[Any], require_all_consumed: bool = False, -) -> Iterator[Any]: +) -> _SetValuesIterator: """Iterator to set a signal to a sequence of values, optionally repeating the sequence. @@ -127,7 +127,7 @@ def _unset_side_effect_cm(put_mock: Mock): put_mock.side_effect = None -def callback_on_mock_put(signal: Signal, callback: Callable[[T], None]): +def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]): """For setting a callback when a backend is put to. Can either be used in a context, with the callback being diff --git a/tests/core/test_mock_signal_backend.py b/tests/core/test_mock_signal_backend.py index 7b16431e91..2cce6dff7e 100644 --- a/tests/core/test_mock_signal_backend.py +++ b/tests/core/test_mock_signal_backend.py @@ -31,6 +31,7 @@ async def test_mock_signal_backend(connect_mock_mode): # If mock is false it will be handled like a normal signal, otherwise it will # initalize a new backend from the one in the line above await mock_signal.connect(mock=connect_mock_mode) + assert isinstance(mock_signal._backend, MockSignalBackend) assert await mock_signal._backend.get_value() == "" await mock_signal._backend.put("test") @@ -74,6 +75,8 @@ async def test_set_mock_put_proceeds(): mock_signal = SignalW(SoftSignalBackend(str)) await mock_signal.connect(mock=True) + assert isinstance(mock_signal._backend, MockSignalBackend) + assert mock_signal._backend.put_proceeds.is_set() is True set_mock_put_proceeds(mock_signal, False) @@ -95,6 +98,7 @@ async def test_set_mock_put_proceeds_timeout(): async def test_put_proceeds_timeout(): mock_signal = SignalW(SoftSignalBackend(str)) await mock_signal.connect(mock=True) + assert isinstance(mock_signal._backend, MockSignalBackend) assert mock_signal._backend.put_proceeds.is_set() is True @@ -115,11 +119,11 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend(): assert_mock_put_called_with(signal, 10) exc_msgs.append(str(exc.value)) with pytest.raises(AssertionError) as exc: - async with mock_puts_blocked(signal, 10): + async with mock_puts_blocked(signal): ... exc_msgs.append(str(exc.value)) with pytest.raises(AssertionError) as exc: - with callback_on_mock_put(signal, 10): + with callback_on_mock_put(signal, lambda x: _): ... exc_msgs.append(str(exc.value)) with pytest.raises(AssertionError) as exc: @@ -216,10 +220,8 @@ async def test_callback_on_mock_put_no_ctx(): mock_signal = SignalRW(SoftSignalBackend(float)) await mock_signal.connect(mock=True) calls = [] - ( - callback_on_mock_put( - mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args}) - ), + callback_on_mock_put( + mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args}) ) await mock_signal.set(10.0) assert calls == [ @@ -249,16 +251,16 @@ def some_function_without_kwargs(arg): async def test_set_mock_values(mock_signals): signal1, signal2 = mock_signals - await signal2.get_value() == "first_value" + assert await signal2.get_value() == "first_value" for value_set in set_mock_values(signal1, ["second_value", "third_value"]): assert await signal1.get_value() == value_set iterator = set_mock_values(signal2, ["second_value", "third_value"]) - await signal2.get_value() == "first_value" + assert await signal2.get_value() == "first_value" next(iterator) - await signal2.get_value() == "second_value" + assert await signal2.get_value() == "second_value" next(iterator) - await signal2.get_value() == "third_value" + assert await signal2.get_value() == "third_value" async def test_set_mock_values_exhausted_passes(mock_signals):