Skip to content

Commit

Permalink
(#325) Fix typing issues around mock backend
Browse files Browse the repository at this point in the history
  • Loading branch information
DominicOram authored and abbiemery committed May 28, 2024
1 parent 4da1550 commit 937cbea
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
12 changes: 6 additions & 6 deletions src/ophyd_async/core/mock_signal_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Callable, Iterable, Iterator, List
from typing import Any, Callable, Iterable
from unittest.mock import ANY, Mock

from ophyd_async.core.signal import Signal
Expand All @@ -22,7 +22,7 @@ def set_mock_value(signal: Signal[T], value: T):
backend.set_value(value)


def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):
def set_mock_put_proceeds(signal: Signal, proceeds: bool):
"""Allow or block a put with wait=True from proceeding"""
backend = _get_mock_signal_backend(signal)

Expand All @@ -33,7 +33,7 @@ def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):


@asynccontextmanager
async def mock_puts_blocked(*signals: List[Signal]):
async def mock_puts_blocked(*signals: Signal):
for signal in signals:
set_mock_put_proceeds(signal, False)
yield
Expand Down Expand Up @@ -79,15 +79,15 @@ def __next__(self):
return next_value

def __del__(self):
if self.require_all_consumed and self.index != len(self.values):
if self.require_all_consumed and self.index != len(list(self.values)):
raise AssertionError("Not all values have been consumed.")


def set_mock_values(
signal: Signal,
values: Iterable[Any],
require_all_consumed: bool = False,
) -> Iterator[Any]:
) -> _SetValuesIterator:
"""Iterator to set a signal to a sequence of values, optionally repeating the
sequence.
Expand Down Expand Up @@ -127,7 +127,7 @@ def _unset_side_effect_cm(put_mock: Mock):
put_mock.side_effect = None


def callback_on_mock_put(signal: Signal, callback: Callable[[T], None]):
def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
"""For setting a callback when a backend is put to.
Can either be used in a context, with the callback being
Expand Down
22 changes: 12 additions & 10 deletions tests/core/test_mock_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async def test_mock_signal_backend(connect_mock_mode):
# If mock is false it will be handled like a normal signal, otherwise it will
# initalize a new backend from the one in the line above
await mock_signal.connect(mock=connect_mock_mode)
assert isinstance(mock_signal._backend, MockSignalBackend)

assert await mock_signal._backend.get_value() == ""
await mock_signal._backend.put("test")
Expand Down Expand Up @@ -74,6 +75,8 @@ async def test_set_mock_put_proceeds():
mock_signal = SignalW(SoftSignalBackend(str))
await mock_signal.connect(mock=True)

assert isinstance(mock_signal._backend, MockSignalBackend)

assert mock_signal._backend.put_proceeds.is_set() is True

set_mock_put_proceeds(mock_signal, False)
Expand All @@ -95,6 +98,7 @@ async def test_set_mock_put_proceeds_timeout():
async def test_put_proceeds_timeout():
mock_signal = SignalW(SoftSignalBackend(str))
await mock_signal.connect(mock=True)
assert isinstance(mock_signal._backend, MockSignalBackend)

assert mock_signal._backend.put_proceeds.is_set() is True

Expand All @@ -115,11 +119,11 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
assert_mock_put_called_with(signal, 10)
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
async with mock_puts_blocked(signal, 10):
async with mock_puts_blocked(signal):
...
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
with callback_on_mock_put(signal, 10):
with callback_on_mock_put(signal, lambda x: _):
...
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
Expand Down Expand Up @@ -216,10 +220,8 @@ async def test_callback_on_mock_put_no_ctx():
mock_signal = SignalRW(SoftSignalBackend(float))
await mock_signal.connect(mock=True)
calls = []
(
callback_on_mock_put(
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
),
callback_on_mock_put(
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
)
await mock_signal.set(10.0)
assert calls == [
Expand Down Expand Up @@ -249,16 +251,16 @@ def some_function_without_kwargs(arg):
async def test_set_mock_values(mock_signals):
signal1, signal2 = mock_signals

await signal2.get_value() == "first_value"
assert await signal2.get_value() == "first_value"
for value_set in set_mock_values(signal1, ["second_value", "third_value"]):
assert await signal1.get_value() == value_set

iterator = set_mock_values(signal2, ["second_value", "third_value"])
await signal2.get_value() == "first_value"
assert await signal2.get_value() == "first_value"
next(iterator)
await signal2.get_value() == "second_value"
assert await signal2.get_value() == "second_value"
next(iterator)
await signal2.get_value() == "third_value"
assert await signal2.get_value() == "third_value"


async def test_set_mock_values_exhausted_passes(mock_signals):
Expand Down

0 comments on commit 937cbea

Please sign in to comment.