Skip to content

Commit

Permalink
tested the sim mode of pvi
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Apr 5, 2024
1 parent d6bc8cb commit 59d44fe
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 30 deletions.
61 changes: 39 additions & 22 deletions src/ophyd_async/epics/pvi/pvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _strip_union(field: Union[Union[T], T]) -> T:
for arg in args:
if arg is not type(None):
return arg

return field


Expand All @@ -60,10 +59,6 @@ def _strip_device_vector(field: Union[Type[Device]]) -> Tuple[bool, Type[Device]
return False, field


def _get_common_device_typeypes(name: str, common_device: Type[Device]) -> Type[Device]:
return get_type_hints(common_device).get(name, {})


@dataclass
class PVIEntry:
"""
Expand All @@ -85,15 +80,24 @@ def is_pvi_table(self) -> bool:


def _verify_common_blocks(entry: PVIEntry, common_device: Type[Device]):
if not entry.sub_entries:
return
common_sub_devices = get_type_hints(common_device)
for sub_name, sub_device in common_sub_devices.items():
if sub_name in ("_name", "parent"):
continue
assert entry.sub_entries
if sub_name not in entry.sub_entries:
if sub_name not in entry.sub_entries and get_origin(sub_device) is not Optional:
raise RuntimeError(
f"sub device `{sub_name}:{type(sub_device)}` was not provided by pvi"
)
if isinstance(entry.sub_entries[sub_name], dict):
for sub_sub_entry in entry.sub_entries[sub_name].values(): # type: ignore
_verify_common_blocks(sub_sub_entry, sub_device) # type: ignore
else:
_verify_common_blocks(
entry.sub_entries[sub_name], sub_device # type: ignore
)


_pvi_mapping: Dict[FrozenSet[str], Callable[..., Signal]] = {
Expand Down Expand Up @@ -183,12 +187,13 @@ def initialize_device(
"""

assert entry.sub_entries
common_device_type_hints = (
get_type_hints(common_device_type) if common_device_type else None
)
for sub_name, sub_entries in entry.sub_entries.items():
sub_common_device_type = None
if common_device_type:
sub_common_device_type = _get_common_device_typeypes(
sub_name, common_device_type
)
if common_device_type_hints:
sub_common_device_type = common_device_type_hints.get(sub_name, None)
sub_common_device_type = _strip_union(sub_common_device_type)
pre_defined_device_vector, sub_common_device_type = (
_strip_device_vector(sub_common_device_type)
Expand Down Expand Up @@ -253,7 +258,6 @@ def initialize_device(


def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None):

device_t = stripped_type or type(device)
for sub_name, sub_device_t in get_type_hints(device_t).items():
if sub_name in ("_name", "parent"):
Expand All @@ -262,18 +266,30 @@ def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None):
# we'll take the first type in the union which isn't NoneType
sub_device_t = _strip_union(sub_device_t)
is_device_vector, sub_device_t = _strip_device_vector(sub_device_t)
is_signal = (origin := get_origin(sub_device_t)) and issubclass(origin, Signal)

if is_signal:
signal_type = get_args(sub_device_t)[0]
sub_device = sub_device_t(SimSignalBackend(signal_type, sub_name))
elif is_device_vector:
is_signal = (
(origin := get_origin(sub_device_t)) and issubclass(origin, Signal)
) or (issubclass(sub_device_t, Signal))

if is_device_vector and is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device_1 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device_2 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device = DeviceVector(
{
1: sub_device_1,
2: sub_device_2,
}
)
elif is_device_vector and not is_signal:
sub_device = DeviceVector(
{
1: sub_device_t(name=f"{device.name}-{sub_name}-1"),
2: sub_device_t(name=f"{device.name}-{sub_name}-2"),
}
)
elif is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device = sub_device_t(SimSignalBackend(signal_type, sub_name))
else:
sub_device = sub_device_t(name=f"{device.name}-{sub_name}")

Expand All @@ -288,20 +304,21 @@ def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None):


async def fill_pvi_entries(
device: Device, root_pv: str, timeout=DEFAULT_TIMEOUT, sim=True
device: Device, root_pv: str, timeout=DEFAULT_TIMEOUT, sim=False
):
"""
Fills a ``device`` with signals from a the ``root_pvi:PVI`` table.
If the device names match with parent devices of ``device`` then types are used.
"""
if not sim:
if sim:
# set up sim signals for the common annotations
_sim_common_blocks(device)

else:
# check the pvi table for devices and fill the device with them
parser = PVIParser(root_pv, timeout=timeout)
await parser.get_pvi_entries()
parser.root_entry.device = device
parser.initialize_device(parser.root_entry, common_device_type=type(device))

if sim:
# set up sim signals for the common annotations
_sim_common_blocks(device)
2 changes: 1 addition & 1 deletion src/ophyd_async/panda/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CommonPandABlocks(Device):
pcap: PcapBlock


class PandA(CommonPandABlocks, Device):
class PandA(CommonPandABlocks):
def __init__(self, prefix: str, name: str = "") -> None:
self._prefix = prefix
super().__init__(name)
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from bluesky.run_engine import RunEngine, TransitionError

RECORD = str(Path(__file__).parent / "panda" / "db" / "panda.db")
PANDA_RECORD = str(Path(__file__).parent / "panda" / "db" / "panda.db")
INCOMPLETE_BLOCK_RECORD = str(
Path(__file__).parent / "panda" / "db" / "incomplete_block_panda.db"
)
Expand Down Expand Up @@ -39,7 +39,7 @@ def clean_event_loop():


@pytest.fixture(scope="module", params=["pva"])
def pva():
def panda_pva():
processes = [
subprocess.Popen(
[
Expand All @@ -49,7 +49,7 @@ def pva():
"-m",
macros,
"-d",
RECORD,
PANDA_RECORD,
],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
Expand Down
84 changes: 84 additions & 0 deletions tests/epics/test_pvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Optional

import pytest

from ophyd_async.core import (
DEFAULT_TIMEOUT,
Device,
DeviceCollector,
DeviceVector,
SignalRW,
SignalX,
)
from ophyd_async.epics.pvi import fill_pvi_entries


class Block1(Device):
device_vector_signal_x: DeviceVector[SignalX]
device_vector_signal_rw: DeviceVector[SignalRW[float]]
signal_x: SignalX
signal_rw: SignalRW[int]


class Block2(Device):
device_vector: DeviceVector[Block1]
device: Block1
signal_x: SignalX
signal_rw: SignalRW[int]


class Block3(Device):
device_vector: Optional[DeviceVector[Block2]]
device: Block2
signal_device: Block1
signal_x: SignalX
signal_rw: SignalRW[int]


@pytest.fixture
def pvi_test_device_t():
"""A fixture since pytest discourages init in test case classes"""

class TestDevice(Block3, Device):
def __init__(self, prefix: str, name: str = ""):
self._prefix = prefix
super().__init__(name)

async def connect(
self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT
) -> None:
await fill_pvi_entries(self, self._prefix + "PVI", timeout=timeout, sim=sim)

await super().connect(sim)

yield TestDevice


async def test_fill_pvi_entries_sim_mode(pvi_test_device_t):
async with DeviceCollector(sim=True):
test_device = pvi_test_device_t("PREFIX:")

# device vectors are typed
assert isinstance(test_device.device_vector[1], Block2)
assert isinstance(test_device.device_vector[2], Block2)

# elements of device vectors are typed recursively
assert test_device.device_vector[1].signal_rw._backend.datatype is int
assert isinstance(test_device.device_vector[1].device, Block1)
assert test_device.device_vector[1].device.signal_rw._backend.datatype is int
assert (
test_device.device_vector[1].device.device_vector_signal_rw[1]._backend.datatype
is float
)

# top level blocks are typed
assert isinstance(test_device.signal_device, Block1)
assert isinstance(test_device.device, Block2)

# elements of top level blocks are typed recursively
assert test_device.device.signal_rw._backend.datatype is int
assert isinstance(test_device.device.device, Block1)
assert test_device.device.device.signal_rw._backend.datatype is int

# top level signals are typed
assert test_device.signal_rw._backend.datatype is int
11 changes: 7 additions & 4 deletions tests/panda/test_panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def test_panda_children_connected(sim_panda: PandA):
assert readback_seq == table


async def test_panda_with_missing_blocks(pva):
async def test_panda_with_missing_blocks(panda_pva):
panda = PandA("PANDAQSRVI:")
with pytest.raises(RuntimeError) as exc:
await panda.connect()
Expand All @@ -108,7 +108,7 @@ async def test_panda_with_missing_blocks(pva):
)


async def test_panda_with_extra_blocks_and_signals(pva):
async def test_panda_with_extra_blocks_and_signals(panda_pva):
panda = PandA("PANDAQSRV:")
await panda.connect()

Expand All @@ -118,7 +118,7 @@ async def test_panda_with_extra_blocks_and_signals(pva):
assert panda.pcap.newsignal # type: ignore


async def test_panda_gets_types_from_common_class(pva):
async def test_panda_gets_types_from_common_class(panda_pva):
panda = PandA("PANDAQSRV:")
await panda.connect()

Expand All @@ -133,11 +133,14 @@ async def test_panda_gets_types_from_common_class(pva):
# predefined signals get set up with the correct datatype
assert panda.pcap.active._backend.datatype is bool

# works with custom datatypes
assert panda.seq[1].table._backend.datatype is SeqTable

# others are given the None datatype
assert panda.pcap.newsignal._backend.datatype is None


async def test_panda_block_missing_signals(pva):
async def test_panda_block_missing_signals(panda_pva):
panda = PandA("PANDAQSRVIB:")

with pytest.raises(Exception) as exc:
Expand Down

0 comments on commit 59d44fe

Please sign in to comment.