From 87499636ddd8e6e24c2cea5afe07f10a40c2aa2d Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Mon, 4 Mar 2024 17:59:57 +0000 Subject: [PATCH 1/8] Restructure load_unitconv and update tests Allow loading of other UnitConv types without poly or pchip .csv datafiles Allow users to set limits on the UnitConv object added to fields by default Update existing tests and add new tests to cover restructured load_unitconv --- src/pytac/load_csv.py | 112 +++++++++++++++++++++------------- tests/conftest.py | 16 +++++ tests/data/dummy/unitconv.csv | 3 + tests/test_load.py | 62 ++++++++++++++++++- tests/test_machine.py | 9 +++ 5 files changed, 158 insertions(+), 44 deletions(-) create mode 100644 tests/data/dummy/unitconv.csv diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 6c7eba09..93a54bb7 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -13,19 +13,17 @@ import contextlib import copy import csv +import logging from pathlib import Path from typing import Dict, Iterator import pytac from pytac import data_source, element, utils from pytac.device import EpicsDevice, SimpleDevice -from pytac.exceptions import ControlSystemException +from pytac.exceptions import ControlSystemException, UnitsException from pytac.lattice import EpicsLattice, Lattice from pytac.units import NullUnitConv, PchipUnitConv, PolyUnitConv, UnitConv -# Create a default unit conversion object that returns the input unchanged. -DEFAULT_UC = NullUnitConv() - ELEMENTS_FILENAME = "elements.csv" EPICS_DEVICES_FILENAME = "epics_devices.csv" SIMPLE_DEVICES_FILENAME = "simple_devices.csv" @@ -86,6 +84,52 @@ def load_pchip_unitconv(filepath: Path) -> Dict[int, PchipUnitConv]: return unitconvs +def resolve_unitconv( + uc_params: Dict, unitconvs: Dict, polyconv_file: Path, pchipconv_file: Path +): + """Create a unit conversion object based on the dictionary of parameters passed. + + Args: + uc_params (Dict): A dictionary of parameters specifying the unit conversion + object's properties. + unitconvs (Dict): A dictionary of all loaded unit conversion objects. + polyconv_file (Path): The path to the .csv file from which all PolyUnitConv + objects are loaded. + pchipconv_file (Path): The path to the .csv file from which all PchipUnitConv + objects are loaded. + Returns: + UnitConv: The unit conversion object as specified by uc_params. + + Raises: + UnitsException: if the "uc_id" given in uc_params isn't in the unitconvs Dict. + """ + error_msg = ( + f"Unable to resolve {uc_params['uc_type']} unit conversion with ID " + f"{uc_params['uc_id']}, " + ) + if uc_params["uc_type"] == "null": + uc = NullUnitConv(uc_params["eng_units"], uc_params["phys_units"]) + else: + # Each element needs its own UnitConv object as it may have different limits. + try: + uc = copy.copy(unitconvs[int(uc_params["uc_id"])]) + except KeyError: + if uc_params["uc_type"] == "poly" and not polyconv_file.exists(): + raise UnitsException(error_msg + f"{polyconv_file} not found.") + elif uc_params["uc_type"] == "pchip" and not pchipconv_file.exists(): + raise UnitsException(error_msg + f"{pchipconv_file} not found.") + else: + raise UnitsException(error_msg + "unrecognised UnitConv type.") + uc.phys_units = uc_params["phys_units"] + uc.eng_units = uc_params["eng_units"] + lower, upper = [ + float(lim) if lim != "" else None + for lim in [uc_params["lower_lim"], uc_params["upper_lim"]] + ] + uc.set_conversion_limits(lower, upper) + return uc + + def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: """Load the unit conversion objects from a file. @@ -95,52 +139,34 @@ def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: """ unitconvs: Dict[int, UnitConv] = {} # Assemble datasets from the polynomial file - unitconvs.update(load_poly_unitconv(mode_dir / POLY_FILENAME)) + polyconv_file = mode_dir / POLY_FILENAME + if polyconv_file.exists(): + unitconvs.update(load_poly_unitconv(polyconv_file)) + else: + logging.warning(f"{polyconv_file} not found, unable to load PolyUnitConvs.") # Assemble datasets from the pchip file - unitconvs.update(load_pchip_unitconv(mode_dir / PCHIP_FILENAME)) + pchipconv_file = mode_dir / PCHIP_FILENAME + if pchipconv_file.exists(): + unitconvs.update(load_pchip_unitconv(pchipconv_file)) + else: + logging.warning(f"{pchipconv_file} not found, unable to load PchipUnitConvs.") # Add the unitconv objects to the elements with csv_loader(mode_dir / UNITCONV_FILENAME) as csv_reader: for item in csv_reader: + uc = resolve_unitconv(item, unitconvs, polyconv_file, pchipconv_file) # Special case for element 0: the lattice itself. if int(item["el_id"]) == 0: - if item["uc_type"] != "null": - # Each element needs its own unitconv object as - # it may for example have different limit. - uc = copy.copy(unitconvs[int(item["uc_id"])]) - uc.phys_units = item["phys_units"] - uc.eng_units = item["eng_units"] - upper, lower = ( - float(lim) if lim != "" else None - for lim in [item["upper_lim"], item["lower_lim"]] - ) - uc.set_conversion_limits(lower, upper) - else: - uc = NullUnitConv(item["eng_units"], item["phys_units"]) lattice.set_unitconv(item["field"], uc) else: element = lattice[int(item["el_id"]) - 1] # For certain magnet types, we need an additional rigidity # conversion factor as well as the raw conversion. - if item["uc_type"] == "null": - uc = NullUnitConv(item["eng_units"], item["phys_units"]) - else: - # Each element needs its own unitconv object as - # it may for example have different limit. - uc = copy.copy(unitconvs[int(item["uc_id"])]) - if any( - element.is_in_family(f) - for f in ("HSTR", "VSTR", "Quadrupole", "Sextupole", "Bend") - ): - energy = lattice.get_value("energy", units=pytac.PHYS) - uc.set_post_eng_to_phys(utils.get_div_rigidity(energy)) - uc.set_pre_phys_to_eng(utils.get_mult_rigidity(energy)) - uc.phys_units = item["phys_units"] - uc.eng_units = item["eng_units"] - upper, lower = ( - float(lim) if lim != "" else None - for lim in [item["upper_lim"], item["lower_lim"]] - ) - uc.set_conversion_limits(lower, upper) + # TODO: This should probably be moved into the .csv files somewhere. + rigidity_families = {"hstr", "vstr", "quadrupole", "sextupole", "bend"} + if item["uc_type"] != "null" and element._families & rigidity_families: + energy = lattice.get_value("energy", units=pytac.PHYS) + uc.set_post_eng_to_phys(utils.get_div_rigidity(energy)) + uc.set_pre_phys_to_eng(utils.get_mult_rigidity(energy)) element.set_unitconv(item["field"], uc) @@ -173,9 +199,8 @@ def load(mode, control_system=None, directory=None, symmetry=None): control_system = cothread_cs.CothreadControlSystem() except ImportError: raise ControlSystemException( - "Please install cothread to load a " - "lattice using the default control " - "system (found in cothread_cs.py)." + "Please install cothread to load a lattice using the default control system" + " (found in cothread_cs.py)." ) if directory is None: directory = Path(__file__).resolve().parent / "data" @@ -199,7 +224,8 @@ def load(mode, control_system=None, directory=None, symmetry=None): d = EpicsDevice(name, control_system, pve, get_pv, set_pv) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - target.add_device(item["field"], d, DEFAULT_UC) + # Create with a default UnitConv object that returns the input unchanged. + target.add_device(item["field"], d, NullUnitConv()) # Add basic devices to the lattice. positions = [] for elem in lat: diff --git a/tests/conftest.py b/tests/conftest.py index 1438b8eb..2db45bd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import sys import types +from pathlib import Path from unittest import mock import pytest @@ -178,3 +179,18 @@ def simple_epics_lattice(simple_epics_element, mock_cs, unit_uc): lat.add_device("x", x_device, unit_uc) lat.add_device("y", y_device, unit_uc) return lat + + +@pytest.fixture +def mode_dir(): + return Path(__file__).resolve().parent / "data/dummy" + + +@pytest.fixture +def polyconv_file(mode_dir): + return mode_dir / load_csv.POLY_FILENAME + + +@pytest.fixture +def pchipconv_file(mode_dir): + return mode_dir / load_csv.PCHIP_FILENAME diff --git a/tests/data/dummy/unitconv.csv b/tests/data/dummy/unitconv.csv new file mode 100644 index 00000000..7b91c79b --- /dev/null +++ b/tests/data/dummy/unitconv.csv @@ -0,0 +1,3 @@ +el_id,field,uc_type,uc_id,phys_units,eng_units,lower_lim,upper_lim +2,b1,null,1,m^-2,A,0,200 +4,b2,null,2,m^-3,A,-100,100 diff --git a/tests/test_load.py b/tests/test_load.py index da7208e3..b7d91381 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,9 +1,10 @@ from unittest.mock import patch import pytest +from testfixtures import LogCapture import pytac -from pytac.load_csv import load +from pytac.load_csv import load, load_unitconv, resolve_unitconv @pytest.fixture @@ -70,3 +71,62 @@ def test_families_loaded(lattice): ["drift", "sext", "quad", "ds", "qf", "qs", "sd"] ) assert lattice.get_elements("quad")[0].families == set(["quad", "qf", "qs"]) + + +def test_load_unitconv_warns_if_pchip_or_poly_data_file_not_found( + lattice, mode_dir, polyconv_file, pchipconv_file +): + with LogCapture() as log: + load_unitconv(mode_dir, lattice) + log.check( + ( + "root", + "WARNING", + f"{polyconv_file} not found, unable to load PolyUnitConvs.", + ), + ( + "root", + "WARNING", + f"{pchipconv_file} not found, unable to load PchipUnitConvs.", + ), + ) + + +def test_resolve_unitconv_raises_UnitsException_if_pchip_or_poly_data_file_not_found( + polyconv_file, pchipconv_file +): + uc_params = { + "uc_type": "poly", + "uc_id": 1, + "phys_units": "m^-2", + "eng_units": "A", + "lower_lim": 0, + "upper_lim": 200, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) + uc_params = { + "uc_type": "pchip", + "uc_id": 2, + "phys_units": "m^-3", + "eng_units": "A", + "lower_lim": -100, + "upper_lim": 100, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) + + +def test_resolve_unitconv_raises_UnitsException_if_unrecognised_UnitConv_type( + polyconv_file, pchipconv_file +): + uc_params = { + "uc_type": "unrecognised", + "uc_id": 0, + "phys_units": "", + "eng_units": "", + "lower_lim": 0, + "upper_lim": 0, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) diff --git a/tests/test_machine.py b/tests/test_machine.py index 933306e4..99fd0753 100644 --- a/tests/test_machine.py +++ b/tests/test_machine.py @@ -171,6 +171,15 @@ def test_bpm_unitconv(lattice, field): assert uc.phys_to_eng(2) == 2000 +def test_hstr_unitconv(vmx_ring): + # From MML: hw2physics('HTRIM', 'Monitor', 2.5, [1]) + htrim = vmx_ring.get_elements("HTRIM")[0] + # This test depends on the lattice having an energy of 3000Mev. + uc = htrim._data_source_manager._uc["x_kick"] + numpy.testing.assert_allclose(uc.eng_to_phys(2.5), 0.0001925) + numpy.testing.assert_allclose(uc.phys_to_eng(0.0001925), 2.5) + + def test_quad_unitconv(vmx_ring): # From MML: hw2physics('Q1D', 'Monitor', 70, [1]) q1d = vmx_ring.get_elements("Q1D") From fc63f4d4faddeaf4515ac521e033ff1a792c7de5 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Mon, 4 Mar 2024 18:27:36 +0000 Subject: [PATCH 2/8] Move utils directory into src/pytac/data --- {utils => src/pytac/data/utils}/load_mml.m | 0 {utils => src/pytac/data/utils}/load_unitconv.m | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {utils => src/pytac/data/utils}/load_mml.m (100%) rename {utils => src/pytac/data/utils}/load_unitconv.m (100%) diff --git a/utils/load_mml.m b/src/pytac/data/utils/load_mml.m similarity index 100% rename from utils/load_mml.m rename to src/pytac/data/utils/load_mml.m diff --git a/utils/load_unitconv.m b/src/pytac/data/utils/load_unitconv.m similarity index 100% rename from utils/load_unitconv.m rename to src/pytac/data/utils/load_unitconv.m From b01cd93b0ed28308e5f39df90f4e4d494fcb8507 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Mon, 4 Mar 2024 18:36:47 +0000 Subject: [PATCH 3/8] Update relative filepaths in matlab utils scripts --- src/pytac/data/utils/load_mml.m | 2 +- src/pytac/data/utils/load_unitconv.m | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/pytac/data/utils/load_mml.m b/src/pytac/data/utils/load_mml.m index 2394be94..e701c2fc 100644 --- a/src/pytac/data/utils/load_mml.m +++ b/src/pytac/data/utils/load_mml.m @@ -18,7 +18,7 @@ function load_mml(ringmode) dir = fileparts(mfilename('fullpath')); cd(dir); - datadir = fullfile(dir, '..', 'pytac', 'data', ringmode); + datadir = fullfile(dir, '..', ringmode); if ~exist(datadir, 'dir') fprintf('Data directory %s does not exist. Please create it.\n', datadir); fprintf('Script will exit.\n'); diff --git a/src/pytac/data/utils/load_unitconv.m b/src/pytac/data/utils/load_unitconv.m index d40addcb..d201e0b0 100644 --- a/src/pytac/data/utils/load_unitconv.m +++ b/src/pytac/data/utils/load_unitconv.m @@ -1,9 +1,17 @@ function load_unitconv(ringmode, renamedIndexes) dir = fileparts(mfilename('fullpath')); cd(dir); -units_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'unitconv.csv'); -poly_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'uc_poly_data.csv'); -pchip_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'uc_pchip_data.csv'); +datadir = fullfile(dir, '..', ringmode); +if ~exist(datadir, 'dir') + fprintf('Data directory %s does not exist. Please create it.\n', datadir); + fprintf('Script will exit.\n'); + return; +end + +% Open the CSV files that store the Pytac data. +units_file = fullfile(datadir, 'unitconv.csv'); +poly_file = fullfile(datadir, 'uc_poly_data.csv'); +pchip_file = fullfile(datadir, 'uc_pchip_data.csv'); fprintf('Loading unit conversions...\n'); From 5179766fff8388541dd9b1866eac9ff53c9c6dba Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Mon, 4 Mar 2024 19:03:16 +0000 Subject: [PATCH 4/8] Correct CI liniting failures --- src/pytac/data_source.py | 1 + src/pytac/device.py | 1 + src/pytac/element.py | 1 + src/pytac/lattice.py | 1 + src/pytac/load_csv.py | 1 + src/pytac/units.py | 1 + src/pytac/utils.py | 1 + tests/conftest.py | 3 +-- tests/test_cothread_cs.py | 1 + tests/test_machine.py | 1 + 10 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/pytac/data_source.py b/src/pytac/data_source.py index 53949d73..ec75fa74 100644 --- a/src/pytac/data_source.py +++ b/src/pytac/data_source.py @@ -1,4 +1,5 @@ """Module containing pytac data source classes.""" + import pytac from pytac.exceptions import DataSourceException, FieldException diff --git a/src/pytac/device.py b/src/pytac/device.py index f385f39b..c139f0f8 100644 --- a/src/pytac/device.py +++ b/src/pytac/device.py @@ -5,6 +5,7 @@ DLS is a sextupole magnet that contains also horizontal and vertical corrector magnets and a skew quadrupole. """ + from typing import List, Union import pytac diff --git a/src/pytac/element.py b/src/pytac/element.py index a3632298..b82c0440 100644 --- a/src/pytac/element.py +++ b/src/pytac/element.py @@ -1,4 +1,5 @@ """Module containing the element class.""" + import pytac from pytac.data_source import DataSource, DataSourceManager from pytac.exceptions import DataSourceException, FieldException diff --git a/src/pytac/lattice.py b/src/pytac/lattice.py index 86e0c949..033ca142 100644 --- a/src/pytac/lattice.py +++ b/src/pytac/lattice.py @@ -1,6 +1,7 @@ """Representation of a lattice object which contains all the elements of the machine. """ + import logging from typing import List, Optional diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 93a54bb7..f47022fb 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -9,6 +9,7 @@ * uc_poly_data.csv * uc_pchip_data.csv """ + import collections import contextlib import copy diff --git a/src/pytac/units.py b/src/pytac/units.py index fc8ac1fe..93f9e6f3 100644 --- a/src/pytac/units.py +++ b/src/pytac/units.py @@ -1,4 +1,5 @@ """Classes for use in unit conversion.""" + import numpy from scipy.interpolate import PchipInterpolator diff --git a/src/pytac/utils.py b/src/pytac/utils.py index bccab0ba..251f83ba 100644 --- a/src/pytac/utils.py +++ b/src/pytac/utils.py @@ -1,4 +1,5 @@ """Utility functions.""" + import math import scipy.constants diff --git a/tests/conftest.py b/tests/conftest.py index 2db45bd2..d5097151 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import sys import types -from pathlib import Path from unittest import mock import pytest @@ -183,7 +182,7 @@ def simple_epics_lattice(simple_epics_element, mock_cs, unit_uc): @pytest.fixture def mode_dir(): - return Path(__file__).resolve().parent / "data/dummy" + return CURRENT_DIR_PATH / "data/dummy" @pytest.fixture diff --git a/tests/test_cothread_cs.py b/tests/test_cothread_cs.py index 999a8a69..ca8b6ac7 100644 --- a/tests/test_cothread_cs.py +++ b/tests/test_cothread_cs.py @@ -4,6 +4,7 @@ See pytest_sessionstart() in conftest.py for more. """ + import pytest from constants import RB_PV, SP_PV from cothread.catools import ca_nothing, caget, caput diff --git a/tests/test_machine.py b/tests/test_machine.py index 99fd0753..22727e29 100644 --- a/tests/test_machine.py +++ b/tests/test_machine.py @@ -2,6 +2,7 @@ files in the data directory. These are more like integration tests, and allows us to check that the pytac setup is working correctly. """ + import re from unittest import mock From c32f154035876b75342e39f911e5dd2513ca0695 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Mon, 4 Mar 2024 19:54:14 +0000 Subject: [PATCH 5/8] Stop using pytest-lazy-fixture As it does not work with pytest>8.0.0 --- pyproject.toml | 1 - tests/test_data_source.py | 61 ++++++++++--------------------- tests/test_machine.py | 75 ++++++++++++++++----------------------- 3 files changed, 48 insertions(+), 89 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3de08a88..9ce04cf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dev = [ "pydata-sphinx-theme>=0.12", "pytest", "pytest-cov", - "pytest-lazy-fixture", "sphinx-autobuild", "sphinx-copybutton", "sphinx-design", diff --git a/tests/test_data_source.py b/tests/test_data_source.py index e73fb4df..45532e1c 100644 --- a/tests/test_data_source.py +++ b/tests/test_data_source.py @@ -1,82 +1,57 @@ import pytest from constants import DUMMY_VALUE_2 -from pytest_lazyfixture import lazy_fixture import pytac @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_device(simple_object, y_device): +def test_get_device(simple_object, y_device, request): + simple_object = request.getfixturevalue(simple_object) assert simple_object.get_device("y") == y_device @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_unitconv(simple_object, unit_uc): +def test_get_unitconv(simple_object, unit_uc, request): + simple_object = request.getfixturevalue(simple_object) assert simple_object.get_unitconv("x") == unit_uc @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_fields(simple_object): +def test_get_fields(simple_object, request): + simple_object = request.getfixturevalue(simple_object) fields = simple_object.get_fields()[pytac.LIVE] assert set(fields) == {"x", "y"} @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_set_value(simple_object): +def test_set_value(simple_object, request): + simple_object = request.getfixturevalue(simple_object) simple_object.set_value("x", DUMMY_VALUE_2, pytac.ENG, pytac.LIVE) simple_object.get_device("x").set_value.assert_called_with(DUMMY_VALUE_2, True) @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_value_sim(simple_object): +def test_get_value_sim(simple_object, request): + simple_object = request.getfixturevalue(simple_object) assert ( simple_object.get_value("x", pytac.RB, pytac.PHYS, pytac.SIM) == DUMMY_VALUE_2 ) @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_unit_conversion(simple_object, double_uc): +def test_unit_conversion(simple_object, double_uc, request): + simple_object = request.getfixturevalue(simple_object) simple_object.set_value("y", DUMMY_VALUE_2, pytac.PHYS, pytac.LIVE) simple_object.get_device("y").set_value.assert_called_with(DUMMY_VALUE_2 / 2, True) diff --git a/tests/test_machine.py b/tests/test_machine.py index 22727e29..14b7c001 100644 --- a/tests/test_machine.py +++ b/tests/test_machine.py @@ -8,7 +8,6 @@ import numpy import pytest -from pytest_lazyfixture import lazy_fixture import pytac @@ -28,22 +27,18 @@ def test_load_lattice_using_default_dir(): @pytest.mark.parametrize( "lattice, name, n_elements, length", - [ - (lazy_fixture("vmx_ring"), "VMX", 2142, 561.571), - (lazy_fixture("diad_ring"), "DIAD", 2144, 561.571), - ], + [("vmx_ring", "VMX", 2142, 561.571), ("diad_ring", "DIAD", 2144, 561.571)], ) -def test_load_lattice(lattice, name, n_elements, length): +def test_load_lattice(lattice, name, n_elements, length, request): + lattice = request.getfixturevalue(lattice) assert len(lattice) == n_elements assert lattice.name == name assert (lattice.get_length() - length) < EPS -@pytest.mark.parametrize( - "lattice, n_bpms", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 173)], -) -def test_get_pv_names(lattice, n_bpms): +@pytest.mark.parametrize("lattice, n_bpms", [("vmx_ring", 173), ("diad_ring", 173)]) +def test_get_pv_names(lattice, n_bpms, request): + lattice = request.getfixturevalue(lattice) bpm_x_pvs = lattice.get_element_pv_names("BPM", "x", handle="readback") assert len(bpm_x_pvs) == n_bpms for pv in bpm_x_pvs: @@ -56,11 +51,9 @@ def test_get_pv_names(lattice, n_bpms): assert re.match("SR.*HBPM.*SLOW:DISABLED", pv) -@pytest.mark.parametrize( - "lattice, n_bpms", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 173)], -) -def test_load_bpms(lattice, n_bpms): +@pytest.mark.parametrize("lattice, n_bpms", [("vmx_ring", 173), ("diad_ring", 173)]) +def test_load_bpms(lattice, n_bpms, request): + lattice = request.getfixturevalue(lattice) bpms = lattice.get_elements("BPM") bpm_fields = { "x", @@ -81,20 +74,16 @@ def test_load_bpms(lattice, n_bpms): assert bpms[-1].cell == 24 -@pytest.mark.parametrize( - "lattice, n_drifts", - [(lazy_fixture("vmx_ring"), 1308), (lazy_fixture("diad_ring"), 1311)], -) -def test_load_drift_elements(lattice, n_drifts): +@pytest.mark.parametrize("lattice, n_drifts", [("vmx_ring", 1308), ("diad_ring", 1311)]) +def test_load_drift_elements(lattice, n_drifts, request): + lattice = request.getfixturevalue(lattice) drifts = lattice.get_elements("DRIFT") assert len(drifts) == n_drifts -@pytest.mark.parametrize( - "lattice, n_quads", - [(lazy_fixture("vmx_ring"), 248), (lazy_fixture("diad_ring"), 248)], -) -def test_load_quadrupoles(lattice, n_quads): +@pytest.mark.parametrize("lattice, n_quads", [("vmx_ring", 248), ("diad_ring", 248)]) +def test_load_quadrupoles(lattice, n_quads, request): + lattice = request.getfixturevalue(lattice) quads = lattice.get_elements("Quadrupole") assert len(quads) == n_quads for quad in quads: @@ -105,10 +94,10 @@ def test_load_quadrupoles(lattice, n_quads): @pytest.mark.parametrize( - "lattice, n_q1b, n_q1d", - [(lazy_fixture("vmx_ring"), 34, 12), (lazy_fixture("diad_ring"), 34, 12)], + "lattice, n_q1b, n_q1d", [("vmx_ring", 34, 12), ("diad_ring", 34, 12)] ) -def test_load_quad_family(lattice, n_q1b, n_q1d): +def test_load_quad_family(lattice, n_q1b, n_q1d, request): + lattice = request.getfixturevalue(lattice) q1b = lattice.get_elements("Q1B") assert len(q1b) == n_q1b q1d = lattice.get_elements("Q1D") @@ -116,10 +105,10 @@ def test_load_quad_family(lattice, n_q1b, n_q1d): @pytest.mark.parametrize( - "lattice, n_correctors", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 172)], + "lattice, n_correctors", [("vmx_ring", 173), ("diad_ring", 172)] ) -def test_load_correctors(lattice, n_correctors): +def test_load_correctors(lattice, n_correctors, request): + lattice = request.getfixturevalue(lattice) hcm = lattice.get_elements("HSTR") vcm = lattice.get_elements("VSTR") assert len(hcm) == n_correctors @@ -136,11 +125,9 @@ def test_load_correctors(lattice, n_correctors): ) -@pytest.mark.parametrize( - "lattice, n_squads", - [(lazy_fixture("vmx_ring"), 98), (lazy_fixture("diad_ring"), 98)], -) -def test_load_squads(lattice, n_squads): +@pytest.mark.parametrize("lattice, n_squads", [("vmx_ring", 98), ("diad_ring", 98)]) +def test_load_squads(lattice, n_squads, request): + lattice = request.getfixturevalue(lattice) squads = lattice.get_elements("SQUAD") assert len(squads) == n_squads for squad in squads: @@ -150,21 +137,19 @@ def test_load_squads(lattice, n_squads): assert re.match("SR.*SQ.*:SETI", device.sp_pv) -@pytest.mark.parametrize( - "lattice", (lazy_fixture("diad_ring"), lazy_fixture("vmx_ring")) -) -def test_cell(lattice): +@pytest.mark.parametrize("lattice", ["diad_ring", "vmx_ring"]) +def test_cell(lattice, request): + lattice = request.getfixturevalue(lattice) # there are squads in every cell sq = lattice.get_elements("SQUAD") assert sq[0].cell == 1 assert sq[-1].cell == 24 -@pytest.mark.parametrize( - "lattice", (lazy_fixture("diad_ring"), lazy_fixture("vmx_ring")) -) +@pytest.mark.parametrize("lattice", ["diad_ring", "vmx_ring"]) @pytest.mark.parametrize("field", ("x", "y")) -def test_bpm_unitconv(lattice, field): +def test_bpm_unitconv(lattice, field, request): + lattice = request.getfixturevalue(lattice) bpm = lattice.get_elements("BPM")[0] uc = bpm._data_source_manager._uc[field] From 0d95c13e98741673621b500e27f859e246fccd16 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Fri, 7 Jun 2024 15:59:16 +0100 Subject: [PATCH 6/8] Properly parse boolean values from .csv files --- src/pytac/data/DIAD/simple_devices.csv | 2 +- src/pytac/data/DIADSP/simple_devices.csv | 2 +- src/pytac/data/DIADTHz/simple_devices.csv | 2 +- src/pytac/data/I04/simple_devices.csv | 2 +- src/pytac/data/I04SP/simple_devices.csv | 2 +- src/pytac/data/I04THz/simple_devices.csv | 2 +- .../data/SRI0913_MOGA/simple_devices.csv | 2 +- src/pytac/data/VMX/simple_devices.csv | 2 +- src/pytac/data/VMXSP/simple_devices.csv | 2 +- src/pytac/data/VMXTHz/simple_devices.csv | 2 +- src/pytac/data/utils/load_mml.m | 2 +- src/pytac/load_csv.py | 20 ++++++++++++++----- 12 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/pytac/data/DIAD/simple_devices.csv b/src/pytac/data/DIAD/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIAD/simple_devices.csv +++ b/src/pytac/data/DIAD/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/DIADSP/simple_devices.csv b/src/pytac/data/DIADSP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIADSP/simple_devices.csv +++ b/src/pytac/data/DIADSP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/DIADTHz/simple_devices.csv b/src/pytac/data/DIADTHz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIADTHz/simple_devices.csv +++ b/src/pytac/data/DIADTHz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04/simple_devices.csv b/src/pytac/data/I04/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04/simple_devices.csv +++ b/src/pytac/data/I04/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04SP/simple_devices.csv b/src/pytac/data/I04SP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04SP/simple_devices.csv +++ b/src/pytac/data/I04SP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04THz/simple_devices.csv b/src/pytac/data/I04THz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04THz/simple_devices.csv +++ b/src/pytac/data/I04THz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/SRI0913_MOGA/simple_devices.csv b/src/pytac/data/SRI0913_MOGA/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/SRI0913_MOGA/simple_devices.csv +++ b/src/pytac/data/SRI0913_MOGA/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMX/simple_devices.csv b/src/pytac/data/VMX/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMX/simple_devices.csv +++ b/src/pytac/data/VMX/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMXSP/simple_devices.csv b/src/pytac/data/VMXSP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMXSP/simple_devices.csv +++ b/src/pytac/data/VMXSP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMXTHz/simple_devices.csv b/src/pytac/data/VMXTHz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMXTHz/simple_devices.csv +++ b/src/pytac/data/VMXTHz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/utils/load_mml.m b/src/pytac/data/utils/load_mml.m index e701c2fc..0d5514a0 100644 --- a/src/pytac/data/utils/load_mml.m +++ b/src/pytac/data/utils/load_mml.m @@ -43,7 +43,7 @@ function load_mml(ringmode) ao = getao(); % Hard-coded beam energy value. - fprintf(f_simple_devices, '0,energy,3e9,true\n'); + fprintf(f_simple_devices, '0,energy,3e9,True\n'); % The individual BPM PVs are not stored in middlelayer. BPMS = get_bpm_pvs(ao); diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index f47022fb..42f844d1 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -10,6 +10,7 @@ * uc_pchip_data.csv """ +import ast import collections import contextlib import copy @@ -221,8 +222,7 @@ def load(mode, control_system=None, directory=None, symmetry=None): index = int(item["el_id"]) get_pv = item["get_pv"] if item["get_pv"] else None set_pv = item["set_pv"] if item["set_pv"] else None - pve = True - d = EpicsDevice(name, control_system, pve, get_pv, set_pv) + d = EpicsDevice(name, control_system, rb_pv=get_pv, sp_pv=set_pv) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] # Create with a default UnitConv object that returns the input unchanged. @@ -231,7 +231,9 @@ def load(mode, control_system=None, directory=None, symmetry=None): positions = [] for elem in lat: positions.append(elem.s) - lat.add_device("s_position", SimpleDevice(positions, readonly=True), True) + lat.add_device( + "s_position", SimpleDevice(positions, readonly=True), NullUnitConv() + ) simple_devices_file = mode_dir / SIMPLE_DEVICES_FILENAME if simple_devices_file.exists(): with csv_loader(simple_devices_file) as csv_reader: @@ -239,10 +241,18 @@ def load(mode, control_system=None, directory=None, symmetry=None): index = int(item["el_id"]) field = item["field"] value = float(item["value"]) - readonly = item["readonly"].lower() == "true" + try: + readonly = ast.literal_eval(item["readonly"]) + assert isinstance(readonly, bool) + except (ValueError, AssertionError): + raise ValueError( + f"Unable to evaluate {item['readonly']} as a boolean." + ) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - target.add_device(field, SimpleDevice(value, readonly=readonly), True) + target.add_device( + field, SimpleDevice(value, readonly=readonly), NullUnitConv() + ) with csv_loader(mode_dir / FAMILIES_FILENAME) as csv_reader: for item in csv_reader: lat[int(item["el_id"]) - 1].add_to_family(item["family"]) From 96910589e9d17f3ec8e7f4533d8defd45803a436 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Fri, 21 Jun 2024 18:06:35 +0100 Subject: [PATCH 7/8] Switch to f-strings --- src/pytac/element.py | 10 +++++----- src/pytac/load_csv.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pytac/element.py b/src/pytac/element.py index b82c0440..160ce8f1 100644 --- a/src/pytac/element.py +++ b/src/pytac/element.py @@ -90,13 +90,13 @@ def __str__(self): """ repn = "" return repn __repr__ = __str__ diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 42f844d1..47f6a64c 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -88,7 +88,7 @@ def load_pchip_unitconv(filepath: Path) -> Dict[int, PchipUnitConv]: def resolve_unitconv( uc_params: Dict, unitconvs: Dict, polyconv_file: Path, pchipconv_file: Path -): +) -> UnitConv: """Create a unit conversion object based on the dictionary of parameters passed. Args: @@ -172,7 +172,7 @@ def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: element.set_unitconv(item["field"], uc) -def load(mode, control_system=None, directory=None, symmetry=None): +def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLattice: """Load the elements of a lattice from a directory. Args: From 2258498d8b9c7dfc038fad958184722dfa0280a3 Mon Sep 17 00:00:00 2001 From: T-Nicholls Date: Fri, 21 Jun 2024 18:46:20 +0100 Subject: [PATCH 8/8] Resolve mypy errors --- src/pytac/load_csv.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 47f6a64c..2bb8b4c9 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -218,18 +218,20 @@ def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLatti lat.add_element(e) with csv_loader(mode_dir / EPICS_DEVICES_FILENAME) as csv_reader: for item in csv_reader: - name = item["name"] index = int(item["el_id"]) get_pv = item["get_pv"] if item["get_pv"] else None set_pv = item["set_pv"] if item["set_pv"] else None - d = EpicsDevice(name, control_system, rb_pv=get_pv, sp_pv=set_pv) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - # Create with a default UnitConv object that returns the input unchanged. - target.add_device(item["field"], d, NullUnitConv()) + # Create with a default UnitConv that returns the input unchanged. + target.add_device( # type: ignore[attr-defined] + item["field"], + EpicsDevice(item["name"], control_system, rb_pv=get_pv, sp_pv=set_pv), + NullUnitConv(), + ) # Add basic devices to the lattice. positions = [] - for elem in lat: + for elem in lat: # type: ignore[attr-defined] positions.append(elem.s) lat.add_device( "s_position", SimpleDevice(positions, readonly=True), NullUnitConv() @@ -239,8 +241,6 @@ def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLatti with csv_loader(simple_devices_file) as csv_reader: for item in csv_reader: index = int(item["el_id"]) - field = item["field"] - value = float(item["value"]) try: readonly = ast.literal_eval(item["readonly"]) assert isinstance(readonly, bool) @@ -250,8 +250,11 @@ def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLatti ) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - target.add_device( - field, SimpleDevice(value, readonly=readonly), NullUnitConv() + # Create with a default UnitConv that returns the input unchanged. + target.add_device( # type: ignore[attr-defined] + item["field"], + SimpleDevice(float(item["value"]), readonly=readonly), + NullUnitConv(), ) with csv_loader(mode_dir / FAMILIES_FILENAME) as csv_reader: for item in csv_reader: