diff --git a/src/ophyd_async/epics/pvi.py b/src/ophyd_async/epics/pvi.py index a71880ca1f..29f8961281 100644 --- a/src/ophyd_async/epics/pvi.py +++ b/src/ophyd_async/epics/pvi.py @@ -1,7 +1,23 @@ -from typing import Callable, Dict, FrozenSet, Optional, Type, TypedDict, TypeVar +import re +from dataclasses import dataclass +from typing import ( + Callable, + Dict, + FrozenSet, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, +) +from ophyd_async.core import Device, DeviceVector, SimSignalBackend from ophyd_async.core.signal import Signal -from ophyd_async.core.signal_backend import SignalBackend from ophyd_async.core.utils import DEFAULT_TIMEOUT from ophyd_async.epics._backend._p4p import PvaSignalBackend from ophyd_async.epics.signal.signal import ( @@ -12,59 +28,253 @@ ) T = TypeVar("T") +Access = FrozenSet[ + Literal["r"] | Literal["w"] | Literal["rw"] | Literal["x"] | Literal["d"] +] + + +def _strip_number_from_string(string: str) -> Tuple[str, Optional[int]]: + match = re.match(r"(.*?)(\d*)$", string) + assert match + + name = match.group(1) + number = match.group(2) or None + if number: + number = int(number) + return name, number + + +@dataclass +class PVIEntry: + """ + A dataclass to represent a single entry in the PVI table. + This could either be a signal or a sub-table. + """ + + name: Optional[str] + access: Access + values: List[str] + # `sub_entries` if the signal is a PVI table + # If a sub device is a device vector then it will be represented by a further dict + sub_entries: Optional[Dict[str, Union[Dict[int, "PVIEntry"], "PVIEntry"]]] = None + device: Optional[Device] = None + + @property + def is_pvi_table(self) -> bool: + return len(self.values) == 1 and self.values[0].endswith(":PVI") _pvi_mapping: Dict[FrozenSet[str], Callable[..., Signal]] = { - frozenset({"r", "w"}): lambda dtype, read_pv, write_pv: epics_signal_rw( - dtype, read_pv, write_pv + frozenset({"r", "w"}): lambda read_pv, write_pv: epics_signal_rw( + None, "pva://" + read_pv, "pva://" + write_pv ), - frozenset({"rw"}): lambda dtype, read_pv, write_pv: epics_signal_rw( - dtype, read_pv, write_pv + frozenset({"rw"}): lambda read_write_pv: epics_signal_rw( + None, "pva://" + read_write_pv, write_pv="pva://" + read_write_pv ), - frozenset({"r"}): lambda dtype, read_pv, _: epics_signal_r(dtype, read_pv), - frozenset({"w"}): lambda dtype, _, write_pv: epics_signal_w(dtype, write_pv), - frozenset({"x"}): lambda _, __, write_pv: epics_signal_x(write_pv), + frozenset({"r"}): lambda read_pv: epics_signal_r(None, "pva://" + read_pv), + frozenset({"w"}): lambda write_pv: epics_signal_w(None, "pva://" + write_pv), + frozenset({"x"}): lambda write_pv: epics_signal_x("pva://" + write_pv), } -class PVIEntry(TypedDict, total=False): - d: str - r: str - rw: str - w: str - x: str +class PVIParser: + def __init__( + self, + root_pv: str, + timeout=DEFAULT_TIMEOUT, + ): + self.root_entry = PVIEntry( + name=None, access=frozenset({"d"}), values=[root_pv], sub_entries={} + ) + self.timeout = timeout + + async def get_pvi_entries(self, entry: Optional[PVIEntry] = None): + """Creates signals from a top level PVI table""" + if not entry: + entry = self.root_entry + + assert entry.is_pvi_table + + pvi_table_signal_backend: PvaSignalBackend = PvaSignalBackend( + None, entry.values[0], entry.values[0] + ) + await pvi_table_signal_backend.connect( + timeout=self.timeout + ) # create table signal backend + + pva_table = await pvi_table_signal_backend.get_value() + assert "pvi" in pva_table + + entry.sub_entries = {} + + for sub_name, pva_enties in pva_table["pvi"].items(): + sub_entry = PVIEntry( + name=sub_name, + access=frozenset(pva_enties.keys()), + values=list(pva_enties.values()), + sub_entries={}, + ) + + if sub_entry.is_pvi_table: + sub_split_name, sub_split_number = _strip_number_from_string(sub_name) + if not sub_split_number: + sub_split_number = 1 + await self.get_pvi_entries(entry=sub_entry) + entry.sub_entries[sub_split_name] = entry.sub_entries.get( + sub_split_name, {} + ) + entry.sub_entries[sub_split_name][ + sub_split_number + ] = sub_entry # type: ignore + else: + sub_entry.device = _pvi_mapping[sub_entry.access](*sub_entry.values) + entry.sub_entries[sub_name] = sub_entry -async def pvi_get( - read_pv: str, timeout: float = DEFAULT_TIMEOUT -) -> Dict[str, PVIEntry]: - """Makes a PvaSignalBackend purely to connect to PVI information. + def _get_common_device_types( + self, name: str, common_device: Type[Device] + ) -> Optional[Type[Device]]: + return get_type_hints(common_device).get(name, {}) - This backend is simply thrown away at the end of this method. This is useful - because the backend handles a CancelledError exception that gets thrown on - timeout, and therefore can be used for error reporting.""" - backend: SignalBackend = PvaSignalBackend(None, read_pv, read_pv) - await backend.connect(timeout=timeout) - d: Dict[str, Dict[str, Dict[str, str]]] = await backend.get_value() - pv_info = d.get("pvi") or {} - result = {} + def initialize_device( + self, + entry: PVIEntry, + common_device: Optional[Type[Device]] = None, + ): + """Recursively iterates through the tree of PVI entries and creates devices. - for attr_name, attr_info in pv_info.items(): - result[attr_name] = PVIEntry(**attr_info) # type: ignore + Args: + entry The current PVI entry + common_device The common device type for the current entry + if it exists, else None + Returns: + The initialised device containing it's signals, all typed. + """ - return result + assert entry.sub_entries + for sub_name, sub_entries in entry.sub_entries.items(): + sub_common_device = ( + self._get_common_device_types(sub_name, common_device) + if common_device + else None + ) + if isinstance(sub_entries, dict) and ( + len(sub_entries) != 1 or (get_origin(sub_common_device) == DeviceVector) + ): -def make_signal(signal_pvi: PVIEntry, dtype: Optional[Type[T]] = None) -> Signal[T]: - """Make a signal. + sub_device: Union[DeviceVector, Device] = DeviceVector() + for sub_split_number, sub_entry in sub_entries.items(): + if not sub_entry.device: # If the entry is't a signal + if ( + sub_common_device + and get_origin(sub_common_device) == DeviceVector + ): + sub_common_device = get_args(sub_common_device)[0] + sub_entry.device = ( + sub_common_device() if sub_common_device else Device() + ) + self.initialize_device( + sub_entry, common_device=sub_common_device + ) + assert isinstance(sub_device, DeviceVector) + sub_device[sub_split_number] = sub_entry.device + else: + if isinstance(sub_entries, dict): + sub_device = sub_common_device() if sub_common_device else Device() + assert list(sub_entries) == [1] + sub_entries[1].device = sub_device + self.initialize_device( + sub_entries[1], common_device=sub_common_device + ) + else: + assert sub_entries.device + sub_device = sub_entries.device - This assumes datatype is None so it can be used to create dynamic signals. + setattr(entry.device, sub_name, sub_device) + + if common_device: + common_sub_devices = get_type_hints(common_device) + for sub_name, sub_device in common_sub_devices.items(): + if sub_name in ("_name", "parent"): + continue + if sub_name not in entry.sub_entries: + raise RuntimeError( + f"sub device `{sub_name}:{type(sub_device)}` was not provided" + " by pvi" + ) + + +def _strip_union(field: Union[Union[T], T]) -> T: + if get_origin(field) is Union: + args = get_args(field) + for arg in args: + if arg is not type(None): + return arg + + return field + + +def _strip_device_vector( + field: Union[DeviceVector[Device], Device] +) -> Tuple[bool, Device]: + if get_origin(field) is DeviceVector: + return True, get_args(field)[0] + return False, field + + +def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None): + + device_t = stripped_type or type(device) + for sub_name, sub_device_t in get_type_hints(device_t).items(): + if sub_name in ("_name", "parent"): + continue + + # we'll take the first type in the union which isn't NoneType + sub_device_t = _strip_union(sub_device_t) + is_device_vector, sub_device_t = _strip_device_vector(sub_device_t) + is_signal = (origin := get_origin(sub_device_t)) and issubclass(origin, Signal) + + if is_signal: + signal_type = get_args(sub_device_t)[0] + print("DEBUG: SIGNAL TYPE", signal_type) + print("DEBUG: SIGNAL ARGS", get_args(sub_device_t)) + sub_device = sub_device_t(SimSignalBackend(signal_type, sub_name)) + elif is_device_vector: + sub_device = DeviceVector( + { + 1: sub_device_t(name=f"{device.name}-{sub_name}-1"), + 2: sub_device_t(name=f"{device.name}-{sub_name}-2"), + } + ) + else: + sub_device = sub_device_t(name=f"{device.name}-{sub_name}") + + if not is_signal: + if is_device_vector: + for sub_device_in_vector in sub_device.values(): + _sim_common_blocks(sub_device_in_vector, stripped_type=sub_device_t) + else: + _sim_common_blocks(sub_device, stripped_type=sub_device_t) + + setattr(device, sub_name, sub_device) + + +async def fill_pvi_entries( + device: Device, root_pv: str, timeout=DEFAULT_TIMEOUT, sim=True +): """ - operations = frozenset(signal_pvi.keys()) - pvs = [signal_pvi[i] for i in operations] # type: ignore - signal_factory = _pvi_mapping[operations] + Fills a `device` with signals from a the `root_pvi:PVI` table. - write_pv = "pva://" + pvs[0] - read_pv = write_pv if len(pvs) < 2 else "pva://" + pvs[1] + If the device names match with parent devices of `device` then types are used. + """ + if not sim: + # check the pvi table for devices and fill the device with them + parser = PVIParser(root_pv, timeout=timeout) + await parser.get_pvi_entries() + parser.root_entry.device = device + parser.initialize_device(parser.root_entry, common_device=type(device)) - return signal_factory(dtype, read_pv, write_pv) + if sim: + # set up sim signals for the common annotations + _sim_common_blocks(device) diff --git a/src/ophyd_async/panda/__init__.py b/src/ophyd_async/panda/__init__.py index 4bae59ff1e..375aee8d29 100644 --- a/src/ophyd_async/panda/__init__.py +++ b/src/ophyd_async/panda/__init__.py @@ -1,4 +1,4 @@ -from .panda import PandA, PcapBlock, PulseBlock, PVIEntry, SeqBlock, SeqTable +from .panda import PandA, PcapBlock, PulseBlock, SeqBlock from .panda_controller import PandaPcapController from .table import ( SeqTable, @@ -13,7 +13,6 @@ "PandA", "PcapBlock", "PulseBlock", - "PVIEntry", "seq_table_from_arrays", "seq_table_from_rows", "SeqBlock", diff --git a/src/ophyd_async/panda/panda.py b/src/ophyd_async/panda/panda.py index 7971b667c2..fd1d4e1b1f 100644 --- a/src/ophyd_async/panda/panda.py +++ b/src/ophyd_async/panda/panda.py @@ -1,20 +1,7 @@ from __future__ import annotations -import re -from typing import Dict, Optional, Tuple, cast, get_args, get_origin, get_type_hints - -from ophyd_async.core import ( - DEFAULT_TIMEOUT, - Device, - DeviceVector, - Signal, - SignalBackend, - SignalR, - SignalRW, - SignalX, - SimSignalBackend, -) -from ophyd_async.epics.pvi import PVIEntry, make_signal, pvi_get +from ophyd_async.core import DEFAULT_TIMEOUT, Device, DeviceVector, SignalR, SignalRW +from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.panda.table import SeqTable @@ -33,41 +20,6 @@ class PcapBlock(Device): arm: SignalRW[bool] -def _block_name_number(block_name: str) -> Tuple[str, Optional[int]]: - """Maps a panda block name to a block and number. - - There are exceptions to this rule; some blocks like pcap do not contain numbers. - Other blocks may contain numbers and letters, but no numbers at the end. - - Such block names will only return the block name, and not a number. - - If this function returns both a block name and number, it should be instantiated - into a device vector.""" - m = re.match("^([0-9a-z_-]*)([0-9]+)$", block_name) - if m is not None: - name, num = m.groups() - return name, int(num or 1) # just to pass type checks. - - return block_name, None - - -def _remove_inconsistent_blocks(pvi_info: Optional[Dict[str, PVIEntry]]) -> None: - """Remove blocks from pvi information. - - This is needed because some pandas have 'pcap' and 'pcap1' blocks, which are - inconsistent with the assumption that pandas should only have a 'pcap' block, - for example. - - """ - if pvi_info is None: - return - pvi_keys = set(pvi_info.keys()) - for k in pvi_keys: - kn = re.sub(r"\d*$", "", k) - if kn and k != kn and kn in pvi_keys: - del pvi_info[k] - - class CommonPandABlocks(Device): pulse: DeviceVector[PulseBlock] seq: DeviceVector[SeqBlock] @@ -76,107 +28,8 @@ class CommonPandABlocks(Device): class PandA(CommonPandABlocks, Device): def __init__(self, prefix: str, name: str = "") -> None: - Device.__init__(self, name) self._prefix = prefix - - def verify_block(self, name: str, num: Optional[int]): - """Given a block name and number, return information about a block.""" - anno = get_type_hints(self, globalns=globals()).get(name) - - block: Device = Device() - - if anno: - type_args = get_args(anno) - block = type_args[0]() if type_args else anno() - - if not type_args: - assert num is None, f"Only expected one {name} block, got {num}" - - return block - - async def _make_block( - self, - name: str, - num: Optional[int], - block_pv: str, - sim: bool = False, - timeout: float = DEFAULT_TIMEOUT, - ): - """Makes a block given a block name containing relevant signals. - - Loops through the signals in the block (found using type hints), if not in - sim mode then does a pvi call, and identifies this signal from the pvi call. - """ - block = self.verify_block(name, num) - - field_annos = get_type_hints(block, globalns=globals()) - block_pvi = await pvi_get(block_pv, timeout=timeout) if not sim else None - - # finds which fields this class actually has, e.g. delay, width... - for sig_name, sig_type in field_annos.items(): - origin = get_origin(sig_type) - args = get_args(sig_type) - - # if not in sim mode, - if block_pvi: - # try to get this block in the pvi. - entry: Optional[PVIEntry] = block_pvi.get(sig_name) - if entry is None: - raise Exception( - f"{self.__class__.__name__} has a {name} block containing a/" - + f"an {sig_name} signal which has not been retrieved by PVI." - ) - - signal: Signal = make_signal(entry, args[0] if len(args) > 0 else None) - - else: - backend: SignalBackend = SimSignalBackend( - args[0] if len(args) > 0 else None, block_pv - ) - signal = SignalX(backend) if not origin else origin(backend) - - setattr(block, sig_name, signal) - - # checks for any extra pvi information not contained in this class - if block_pvi: - for attr, attr_pvi in block_pvi.items(): - if not hasattr(block, attr): - # makes any extra signals - setattr(block, attr, make_signal(attr_pvi)) - - return block - - async def _make_untyped_block( - self, block_pv: str, timeout: float = DEFAULT_TIMEOUT - ): - """Populates a block using PVI information. - - This block is not typed as part of the PandA interface but needs to be - included dynamically anyway. - """ - block = Device() - block_pvi: Dict[str, PVIEntry] = await pvi_get(block_pv, timeout=timeout) - - for signal_name, signal_pvi in block_pvi.items(): - setattr(block, signal_name, make_signal(signal_pvi)) - - return block - - # TODO redo to set_panda_block? confusing name - def set_attribute(self, name: str, num: Optional[int], block: Device): - """Set a block on the panda. - - Need to be able to set device vectors on the panda as well, e.g. if num is not - None, need to be able to make a new device vector and start populating it... - """ - anno = get_type_hints(self, globalns=globals()).get(name) - - # if it's an annotated device vector, or it isn't but we've got a number then - # make a DeviceVector on the class - if get_origin(anno) == DeviceVector or (not anno and num is not None): - self.__dict__.setdefault(name, DeviceVector())[num] = block - else: - setattr(self, name, block) + super().__init__(name) async def connect( self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT @@ -189,55 +42,7 @@ async def connect( If there's no pvi information, that's because we're in sim mode. In that case, makes all required blocks. """ - pvi_info = ( - await pvi_get(self._prefix + "PVI", timeout=timeout) if not sim else None - ) - _remove_inconsistent_blocks(pvi_info) - - hints = { - attr_name: attr_type - for attr_name, attr_type in get_type_hints(self, globalns=globals()).items() - if not attr_name.startswith("_") - } - - # create all the blocks pvi says it should have, - if pvi_info: - pvi_info = cast(Dict[str, PVIEntry], pvi_info) - for block_name, block_pvi in pvi_info.items(): - name, num = _block_name_number(block_name) - - if name in hints: - block = await self._make_block( - name, num, block_pvi["d"], timeout=timeout - ) - else: - block = await self._make_untyped_block( - block_pvi["d"], timeout=timeout - ) - - self.set_attribute(name, num, block) - - # then check if the ones defined in this class are in the pvi info - # make them if there is no pvi info, i.e. sim mode. - for block_name in hints.keys(): - if pvi_info is not None: - pvi_name = block_name - - if get_origin(hints[block_name]) == DeviceVector: - pvi_name += "1" - - entry: Optional[PVIEntry] = pvi_info.get(pvi_name) - - assert entry, f"Expected PandA to contain {block_name} block." - assert list(entry) == [ - "d" - ], f"Expected PandA to only contain blocks, got {entry}" - else: - num = 1 if get_origin(hints[block_name]) == DeviceVector else None - block = await self._make_block( - block_name, num, "sim://", sim=sim, timeout=timeout - ) - self.set_attribute(block_name, num, block) - - self.set_name(self.name) - await Device.connect(self, sim) + + await fill_pvi_entries(self, self._prefix + "PVI", timeout=timeout, sim=sim) + + await super().connect(sim) diff --git a/tests/epics/test_pvi.py b/tests/epics/test_pvi.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/panda/test_panda.py b/tests/panda/test_panda.py index 8fcf6836bd..3b485ab15b 100644 --- a/tests/panda/test_panda.py +++ b/tests/panda/test_panda.py @@ -8,8 +8,8 @@ from ophyd_async.core import DeviceCollector from ophyd_async.core.utils import NotConnected -from ophyd_async.panda import PandA, PVIEntry, SeqTable, SeqTrigger -from ophyd_async.panda.panda import _remove_inconsistent_blocks +from ophyd_async.epics.pvi import PVIEntry +from ophyd_async.panda import PandA, SeqTable, SeqTrigger class DummyDict: @@ -55,21 +55,6 @@ def test_panda_name_set(): assert panda.name == "panda" -async def test_inconsistent_blocks(): - dummy_pvi = { - "pcap": {}, - "pcap1": {}, - "pulse1": {}, - "pulse2": {}, - "sfp3_sync_out1": {}, - "sfp3_sync_out": {}, - } - - _remove_inconsistent_blocks(dummy_pvi) - assert "sfp3_sync_out1" not in dummy_pvi - assert "pcap1" not in dummy_pvi - - async def test_panda_children_connected(sim_panda: PandA): # try to set and retrieve from simulated values... table = SeqTable( @@ -108,8 +93,12 @@ async def test_panda_children_connected(sim_panda: PandA): async def test_panda_with_missing_blocks(pva): panda = PandA("PANDAQSRVI:") - with pytest.raises(AssertionError): + with pytest.raises(RuntimeError) as exc: await panda.connect() + assert ( + exc.value.args[0] + == "sub device `pcap:` was not provided by pvi" + ) async def test_panda_with_extra_blocks_and_signals(pva): diff --git a/tests/panda/test_panda_utils.py b/tests/panda/test_panda_utils.py index d9386d9812..7e1b8469b6 100644 --- a/tests/panda/test_panda_utils.py +++ b/tests/panda/test_panda_utils.py @@ -3,7 +3,7 @@ import pytest from bluesky import RunEngine -from ophyd_async.core import SignalRW, save_device +from ophyd_async.core import save_device from ophyd_async.core.device import DeviceCollector from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.panda import PandA @@ -14,7 +14,7 @@ async def sim_panda(): async with DeviceCollector(sim=True): sim_panda = PandA("PANDA") - sim_panda.phase_1_signal_units: SignalRW = epics_signal_rw(int, "") + sim_panda.phase_1_signal_units = epics_signal_rw(int, "") assert sim_panda.name == "sim_panda" yield sim_panda @@ -27,11 +27,15 @@ async def test_save_panda(mock_save_to_yaml, sim_panda, RE: RunEngine): [ {"phase_1_signal_units": 0}, { - "pcap.arm": 0.0, + "pcap.arm": False, "pulse.1.delay": 0.0, "pulse.1.width": 0.0, + "pulse.2.delay": 0.0, + "pulse.2.width": 0.0, "seq.1.table": {}, "seq.1.active": False, + "seq.2.table": {}, + "seq.2.active": False, }, ], "path",