From 92fa26bd19e5a3b4eff662457db21e8b61dc30ea Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 14 Mar 2024 13:23:03 +0000 Subject: [PATCH] Add explicit bool type --- src/ophyd_async/epics/_backend/_aioca.py | 19 ++++++++++++++++--- src/ophyd_async/epics/_backend/_p4p.py | 8 ++++---- tests/epics/test_signals.py | 9 +++++---- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/ophyd_async/epics/_backend/_aioca.py b/src/ophyd_async/epics/_backend/_aioca.py index 13e58d029a..86a79e48b4 100644 --- a/src/ophyd_async/epics/_backend/_aioca.py +++ b/src/ophyd_async/epics/_backend/_aioca.py @@ -47,7 +47,10 @@ def _data_key_from_augmented_value( - value: AugmentedValue, *, choices: Optional[List[str]] = None + value: AugmentedValue, + *, + choices: Optional[List[str]] = None, + dtype: Optional[str] = None, ) -> DataKey: """Use the return value of get with FORMAT_CTRL to construct a DataKey describing the signal. See docstring of AugmentedValue for expected @@ -65,7 +68,7 @@ def _data_key_from_augmented_value( assert value.ok, f"Error reading {source}: {value}" scalar = value.element_count == 1 - dtype = dbr_to_dtype[value.datatype] + dtype = dtype or dbr_to_dtype[value.datatype] d = DataKey( source=source, @@ -142,6 +145,16 @@ def descriptor(self, value: AugmentedValue) -> DataKey: return _data_key_from_augmented_value(value, choices=self.choices) +@dataclass +class CaBoolConverter(CaConverter): + + def value(self, value: AugmentedValue) -> bool: + return bool(value) + + def descriptor(self, value: AugmentedValue) -> DataKey: + return _data_key_from_augmented_value(value, dtype="bool") + + class DisconnectedCaConverter(CaConverter): def __getattribute__(self, __name: str) -> Any: raise NotImplementedError("No PV has been set as connect() has not been called") @@ -179,7 +192,7 @@ def make_converter( ) if pv_choices_len != 2: raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return CaConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) + return CaBoolConverter(dbr.DBR_SHORT, dbr.DBR_SHORT) elif pv_dbr == dbr.DBR_ENUM: # This is an Enum pv_choices = get_unique( diff --git a/src/ophyd_async/epics/_backend/_p4p.py b/src/ophyd_async/epics/_backend/_p4p.py index 2f2ceeb84f..17d040eb9f 100644 --- a/src/ophyd_async/epics/_backend/_p4p.py +++ b/src/ophyd_async/epics/_backend/_p4p.py @@ -133,12 +133,12 @@ def descriptor(self, source: str, value) -> Descriptor: return dict(source=source, dtype="string", shape=[], choices=self.choices) -class PvaEnumBoolConverter(PvaConverter): +class PvaBoolConverter(PvaConverter): def value(self, value): - return value["value"]["index"] + return bool(value["value"]["index"]) def descriptor(self, source: str, value) -> Descriptor: - return dict(source=source, dtype="integer", shape=[]) + return dict(source=source, dtype="bool", shape=[]) class PvaTableConverter(PvaConverter): @@ -216,7 +216,7 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve ) if pv_choices_len != 2: raise TypeError(f"{pv} has {pv_choices_len} choices, can't map to bool") - return PvaEnumBoolConverter() + return PvaBoolConverter() elif "NTEnum" in typeid: # This is an Enum pv_choices = get_unique( diff --git a/tests/epics/test_signals.py b/tests/epics/test_signals.py index b6eac6b5ca..dab48503ef 100644 --- a/tests/epics/test_signals.py +++ b/tests/epics/test_signals.py @@ -147,9 +147,8 @@ class MyEnum(str, Enum): _metadata: Dict[str, Dict[str, Any]] = { - "enum": {}, - "string": {}, "integer": {"units": ANY}, + "bool": {"units": ANY}, "number": {"units": ANY, "precision": ANY}, } @@ -159,8 +158,10 @@ def get_internal_dtype(suffix: str) -> str: # uint32, [u]int64 backed by DBR_DOUBLE, have precision if "float" in suffix or "uint32" in suffix or "int64" in suffix: return "number" - if "int" in suffix or "bool" in suffix: + if "int" in suffix: return "integer" + if "bool" in suffix: + return "bool" if "enum" in suffix: return "enum" return "string" @@ -179,7 +180,7 @@ def get_dtype(suffix: str) -> str: d["choices"] = [e.value for e in type(value)] if protocol == "ca": - d.update(_metadata[get_internal_dtype(suffix)]) + d.update(_metadata.get(get_internal_dtype(suffix), {})) return d