Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Signal type parsing in pvi logic #225

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 50 additions & 54 deletions src/ophyd_async/epics/pvi/pvi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from inspect import isclass
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Expand Down Expand Up @@ -44,6 +44,18 @@ def _strip_number_from_string(string: str) -> Tuple[str, Optional[int]]:
return name, number


def _split_subscript(tp: T) -> Union[Tuple[Any, Tuple[Any]], Tuple[T, None]]:
"""Split a subscripted type into the its origin and args.

If `tp` is not a subscripted type, then just return the type and None as args.

"""
if get_origin(tp) is not None:
return get_origin(tp), get_args(tp)

return tp, None


def _strip_union(field: Union[Union[T], T]) -> T:
if get_origin(field) is Union:
args = get_args(field)
Expand Down Expand Up @@ -115,86 +127,70 @@ def _parse_type(
):
if common_device_type:
# pre-defined type
device_type = _strip_union(common_device_type)
is_device_vector, device_type = _strip_device_vector(device_type)

if ((origin := get_origin(device_type)) and issubclass(origin, Signal)) or (
isclass(device_type) and issubclass(device_type, Signal)
):
# if device_type is of the form `Signal` or `Signal[type]`
is_signal = True
signal_dtype = get_args(device_type)[0]
else:
is_signal = False
signal_dtype = None
device_cls = _strip_union(common_device_type)
is_device_vector, device_cls = _strip_device_vector(device_cls)
device_cls, device_args = _split_subscript(device_cls)
assert issubclass(device_cls, Device)

is_signal = issubclass(device_cls, Signal)
signal_dtype = device_args[0] if device_args is not None else None

elif is_pvi_table:
# is a block, we can make it a DeviceVector if it ends in a number
is_device_vector = number_suffix is not None
is_signal = False
signal_dtype = None
device_type = Device
device_cls = Device
else:
# is a signal, signals aren't stored in DeviceVectors unless
# they're defined as such in the common_device_type
is_device_vector = False
is_signal = True
signal_dtype = None
device_type = Signal
device_cls = Signal

return is_device_vector, is_signal, signal_dtype, device_type
return is_device_vector, is_signal, signal_dtype, device_cls


def _sim_common_blocks(device: Device, stripped_type: Optional[Type] = None):
device_t = stripped_type or type(device)
for sub_name, sub_device_t in get_type_hints(device_t).items():
if sub_name in ("_name", "parent"):
continue
sub_devices = (
(field, field_type)
for field, field_type in get_type_hints(device_t).items()
if field not in ("_name", "parent")
)

for device_name, device_cls in sub_devices:
device_cls = _strip_union(device_cls)
is_device_vector, device_cls = _strip_device_vector(device_cls)
device_cls, device_args = _split_subscript(device_cls)
assert issubclass(device_cls, Device)

# we'll take the first type in the union which isn't NoneType
sub_device_t = _strip_union(sub_device_t)
is_device_vector, sub_device_t = _strip_device_vector(sub_device_t)
is_signal = (
(origin := get_origin(sub_device_t)) and issubclass(origin, Signal)
) or (issubclass(sub_device_t, Signal))
is_signal = issubclass(device_cls, Signal)
signal_dtype = device_args[0] if device_args is not None else None

# TODO: worth coming back to all this code once 3.9 is gone and we can use
# match statments: https://github.com/bluesky/ophyd-async/issues/180
if is_device_vector:
if is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device_1 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device_2 = sub_device_t(SimSignalBackend(signal_type, sub_name))
sub_device = DeviceVector(
{
1: sub_device_1,
2: sub_device_2,
}
)
sub_device_1 = device_cls(SimSignalBackend(signal_dtype, device_name))
sub_device_2 = device_cls(SimSignalBackend(signal_dtype, device_name))
sub_device = DeviceVector({1: sub_device_1, 2: sub_device_2})
else:
sub_device = DeviceVector(
{
1: sub_device_t(),
2: sub_device_t(),
}
)
sub_device = DeviceVector({1: device_cls(), 2: device_cls()})

for sub_device_in_vector in sub_device.values():
_sim_common_blocks(sub_device_in_vector, stripped_type=device_cls)

for value in sub_device.values():
value.parent = sub_device

elif is_signal:
signal_type = args[0] if (args := get_args(sub_device_t)) else None
sub_device = sub_device_t(SimSignalBackend(signal_type, sub_name))
else:
sub_device = sub_device_t()

if not is_signal:
if is_device_vector:
for sub_device_in_vector in sub_device.values():
_sim_common_blocks(sub_device_in_vector, stripped_type=sub_device_t)
if is_signal:
sub_device = device_cls(SimSignalBackend(signal_dtype, device_name))
else:
_sim_common_blocks(sub_device, stripped_type=sub_device_t)
sub_device = device_cls()

_sim_common_blocks(sub_device, stripped_type=device_cls)

setattr(device, sub_name, sub_device)
setattr(device, device_name, sub_device)
sub_device.parent = device


Expand Down
Loading