Skip to content

Commit

Permalink
added RuntimeEnum class (with metaclass), and added logic to `make_…
Browse files Browse the repository at this point in the history
…converter`s to support them
  • Loading branch information
evalott100 committed Jun 4, 2024
1 parent b2b8bba commit 02522cf
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 60 deletions.
90 changes: 46 additions & 44 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
soft_signal_rw,
wait_for_value,
)
from .signal_backend import SignalBackend
from .signal_backend import RuntimeEnum, SignalBackend
from .soft_signal_backend import SoftSignalBackend
from .standard_readable import ConfigSignal, HintedSignal, StandardReadable
from .utils import (
Expand All @@ -68,66 +68,68 @@
)

__all__ = [
"get_mock_put",
"callback_on_mock_put",
"mock_puts_blocked",
"set_mock_values",
"reset_mock_put_calls",
"SignalBackend",
"SoftSignalBackend",
"AsyncStatus",
"CalculatableTimeout",
"CalculateTimeout",
"Callback",
"ConfigSignal",
"DEFAULT_TIMEOUT",
"DetectorControl",
"MockSignalBackend",
"DetectorTrigger",
"DetectorWriter",
"StandardDetector",
"Device",
"DeviceCollector",
"DeviceVector",
"Signal",
"SignalR",
"SignalW",
"SignalRW",
"SignalX",
"soft_signal_r_and_setter",
"soft_signal_rw",
"observe_value",
"set_and_wait_for_value",
"set_mock_put_proceeds",
"set_mock_value",
"wait_for_value",
"AsyncStatus",
"WatchableAsyncStatus",
"DirectoryInfo",
"DirectoryProvider",
"HardwareTriggeredFlyable",
"HintedSignal",
"MockSignalBackend",
"NameProvider",
"NotConnected",
"ReadingValueCallback",
"RuntimeEnum",
"ShapeProvider",
"StaticDirectoryProvider",
"Signal",
"SignalBackend",
"SignalR",
"SignalRW",
"SignalW",
"SignalX",
"SoftSignalBackend",
"StandardDetector",
"StandardReadable",
"ConfigSignal",
"HintedSignal",
"StaticDirectoryProvider",
"T",
"TriggerInfo",
"TriggerLogic",
"HardwareTriggeredFlyable",
"CalculateTimeout",
"CalculatableTimeout",
"DEFAULT_TIMEOUT",
"Callback",
"NotConnected",
"ReadingValueCallback",
"T",
"WatchableAsyncStatus",
"assert_configuration",
"assert_emitted",
"assert_mock_put_called_with",
"assert_reading",
"assert_value",
"callback_on_mock_put",
"get_dtype",
"get_unique",
"merge_gathered_dicts",
"wait_for_connection",
"get_mock_put",
"get_signal_values",
"get_unique",
"load_device",
"load_from_yaml",
"merge_gathered_dicts",
"mock_puts_blocked",
"observe_value",
"reset_mock_put_calls",
"save_device",
"save_to_yaml",
"set_and_wait_for_value",
"set_mock_put_proceeds",
"set_mock_value",
"set_mock_values",
"set_signal_values",
"soft_signal_r_and_setter",
"soft_signal_rw",
"wait_for_connection",
"wait_for_value",
"walk_rw_signals",
"load_device",
"save_device",
"assert_reading",
"assert_value",
"assert_configuration",
"assert_emitted",
]
60 changes: 59 additions & 1 deletion src/ophyd_async/core/signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from abc import abstractmethod
from typing import Generic, Optional, Type
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Dict,
FrozenSet,
Generic,
Literal,
Optional,
Tuple,
Type,
)

from bluesky.protocols import DataKey, Reading

Expand Down Expand Up @@ -45,3 +55,51 @@ async def get_setpoint(self) -> T:
@abstractmethod
def set_callback(self, callback: Optional[ReadingValueCallback[T]]) -> None:
"""Observe changes to the current value, timestamp and severity"""


if TYPE_CHECKING:
RuntimeEnum = Literal
else:

class _RuntimeEnumMeta(type):
# Intentionally immutable class variable
__enum_classes_created: Dict[FrozenSet[str], Type["RuntimeEnum"]] = {}

def __str__(cls):
if hasattr(cls, "_choices"):
return f"RuntimeEnum{list(cls._choices.keys())}"
return "RuntimeEnum"

@property
def choices(cls) -> Tuple[str]:
return tuple(cls._choices.keys())

def __getitem__(cls, choices):
if isinstance(choices, str):
choices = (choices,)
else:
if not isinstance(choices, tuple) or not all(
isinstance(c, str) for c in choices
):
raise TypeError(
f"Choices must be a str or a tuple of str, not {choices}."
)
if len(set(choices)) != len(choices):
raise TypeError("Duplicate elements in runtime enum choices.")

choices_frozenset = frozenset(choices)

# If the enum has already been created, return it (ignoring order)
if choices_frozenset in _RuntimeEnumMeta.__enum_classes_created:
return _RuntimeEnumMeta.__enum_classes_created[choices_frozenset]

# Create a new enum subclass
class _RuntimeEnum(cls):
_choices = {choice: choice for choice in choices}

_RuntimeEnumMeta.__enum_classes_created[choices_frozenset] = _RuntimeEnum
return _RuntimeEnum

class RuntimeEnum(metaclass=_RuntimeEnumMeta):
def __init__(self):
raise RuntimeError("RuntimeEnum cannot be instantiated")
35 changes: 24 additions & 11 deletions src/ophyd_async/core/soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from collections import abc
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Optional, Type, Union, cast, get_origin
from typing import Any, Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin

import numpy as np
from bluesky.protocols import DataKey, Dtype, Reading

from .signal_backend import SignalBackend
from .signal_backend import RuntimeEnum, SignalBackend
from .utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype

primitive_dtypes: Dict[type, Dtype] = {
Expand Down Expand Up @@ -68,31 +68,44 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
return cast(T, datatype(shape=0)) # type: ignore


@dataclass
class SoftEnumConverter(SoftConverter):
enum_class: Type[Enum]
choices: Tuple[str, ...]

def write_value(self, value: Union[Enum, str]) -> Enum:
def __init__(self, datatype: Union[RuntimeEnum, Enum]):
if issubclass(datatype, Enum):
self.choices = tuple(v.value for v in datatype)
else:
self.choices = datatype.choices

def write_value(self, value: Union[Enum, str]) -> str:
if isinstance(value, Enum):
return value.value
else: # Runtime enum
return value
else:
return self.enum_class(value)

def get_datakey(self, source: str, value) -> DataKey:
choices = [e.value for e in self.enum_class]
return {"source": source, "dtype": "string", "shape": [], "choices": choices} # type: ignore
return {
"source": source,
"dtype": "string",
"shape": [],
"choices": self.choices,
} # type: ignore

def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
if datatype is None:
return cast(T, None)

return cast(T, list(datatype.__members__.values())[0]) # type: ignore
if issubclass(datatype, Enum):
return cast(T, list(datatype.__members__.values())[0]) # type: ignore
return cast(T, datatype.choices[0]) # type: ignore


def make_converter(datatype):
is_array = get_dtype(datatype) is not None
is_sequence = get_origin(datatype) == abc.Sequence
is_enum = issubclass(datatype, Enum) if inspect.isclass(datatype) else False
is_enum = inspect.isclass(datatype) and (
issubclass(datatype, Enum) or issubclass(datatype, RuntimeEnum)
)

if is_array or is_sequence:
return SoftArrayConverter()
Expand Down
8 changes: 6 additions & 2 deletions src/ophyd_async/epics/_backend/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum
from typing import Dict, Optional, Tuple, Type

from ophyd_async.core.signal_backend import RuntimeEnum


def get_supported_values(
pv: str,
Expand All @@ -10,9 +12,11 @@ def get_supported_values(
if not datatype:
return {x: x or "_" for x in pv_choices}

if not issubclass(datatype, str):
if issubclass(datatype, RuntimeEnum):
pv_choices = datatype.choices
elif not issubclass(datatype, str):
raise TypeError(f"{pv} is type Enum but doesn't inherit from String")
if issubclass(datatype, Enum):
elif issubclass(datatype, Enum):
choices = tuple(v.value for v in datatype)
if set(choices) != set(pv_choices):
raise TypeError(
Expand Down
90 changes: 90 additions & 0 deletions tests/core/test_runtime_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
from epicscorelibs.ca import dbr
from p4p import Value as P4PValue
from p4p.nt import NTEnum

from ophyd_async.core import RuntimeEnum
from ophyd_async.epics._backend._aioca import make_converter as aioca_make_converter
from ophyd_async.epics._backend._p4p import make_converter as p4p_make_converter
from ophyd_async.epics.signal.signal import epics_signal_rw


async def test_runtime_enum_behaviour():
rt_enum = RuntimeEnum["A", "B"]

with pytest.raises(RuntimeError) as exc:
rt_enum()
assert str(exc.value) == "RuntimeEnum cannot be instantiated"

assert issubclass(rt_enum, RuntimeEnum)
assert issubclass(rt_enum, RuntimeEnum["A", "B"])
assert issubclass(rt_enum, RuntimeEnum["B", "A"])

rt_enum_reversed_args = RuntimeEnum["B", "A"]
assert rt_enum == rt_enum_reversed_args
assert str(rt_enum) == "RuntimeEnum['A', 'B']"
assert str(RuntimeEnum) == "RuntimeEnum"

# The order of the choices is not important
assert str(rt_enum_reversed_args) == "RuntimeEnum['A', 'B']"

with pytest.raises(TypeError) as exc:
RuntimeEnum["A", "B", "A"]
assert str(exc.value) == "Duplicate elements in runtime enum choices."


async def test_ca_runtime_enum_converter():
class EpicsValue:
def __init__(self):
self.name = "test"
self.ok = (True,)
self.errorcode = 0
self.datatype = dbr.DBR_ENUM
self.element_count = 1
self.severity = 0
self.status = 0
self.raw_stamp = (0,)
self.timestamp = 0
self.datetime = 0
self.enums = ["A", "B", "C"] # More than the runtime enum

epics_value = EpicsValue()
rt_enum = RuntimeEnum["A", "B"]
converter = aioca_make_converter(
rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value}
)
assert converter.choices == {"A": "A", "B": "B"}


async def test_pva_runtime_enum_converter():
enum_type = NTEnum.buildType()
epics_value = P4PValue(
enum_type,
{
"value.choices": ["A", "B", "C"],
},
)
rt_enum = RuntimeEnum["A", "B"]
converter = p4p_make_converter(
rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value}
)
assert converter.choices == ("A", "B")


async def test_runtime_enum_signal():
signal_rw_pva = epics_signal_rw(
RuntimeEnum["A1", "B1"], "ca://RW_PV", name="signal"
)
signal_rw_ca = epics_signal_rw(RuntimeEnum["A2", "B2"], "ca://RW_PV", name="signal")
await signal_rw_pva.connect(mock=True)
await signal_rw_ca.connect(mock=True)
await signal_rw_pva.get_value() == "A1"
await signal_rw_ca.get_value() == "A2"
await signal_rw_pva.set("B1")
await signal_rw_ca.set("B2")
await signal_rw_pva.get_value() == "B1"
await signal_rw_ca.get_value() == "B2"

# Will accept string values even if they're not in the runtime enum
await signal_rw_pva.set("C1")
await signal_rw_ca.set("C2")
2 changes: 1 addition & 1 deletion tests/core/test_soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def string_d(value):


def enum_d(value):
return {"dtype": "string", "shape": [], "choices": ["Aaa", "Bbb", "Ccc"]}
return {"dtype": "string", "shape": [], "choices": ("Aaa", "Bbb", "Ccc")}


def waveform_d(value):
Expand Down
2 changes: 1 addition & 1 deletion tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ async def test_read_sensor(mock_sensor: demo.Sensor):
] == demo.EnergyMode.low
desc = (await mock_sensor.describe_configuration())["mock_sensor-mode"]
assert desc["dtype"] == "string"
assert desc["choices"] == ["Low Energy", "High Energy"] # type: ignore
assert desc["choices"] == ("Low Energy", "High Energy") # type: ignore
set_mock_value(mock_sensor.mode, demo.EnergyMode.high)
assert (await mock_sensor.read_configuration())["mock_sensor-mode"][
"value"
Expand Down

0 comments on commit 02522cf

Please sign in to comment.