From cf81fcc649af82c109eb30ecaab1daf2743cb653 Mon Sep 17 00:00:00 2001 From: a_corni Date: Thu, 13 Jul 2023 11:05:02 +0200 Subject: [PATCH] Add validation RegisterLayout,EOM,Channel,Device --- .../pulser/channels/comparison_tools.py | 201 ++++++++++++++++++ .../pulser/devices/comparison_tools.py | 119 +++++++++++ .../pulser/register/register_layout.py | 28 +++ tests/test_channels.py | 123 +++++++++++ tests/test_devices.py | 146 ++++++++++++- tests/test_register_layout.py | 38 ++++ 6 files changed, 653 insertions(+), 2 deletions(-) create mode 100644 pulser-core/pulser/channels/comparison_tools.py create mode 100644 pulser-core/pulser/devices/comparison_tools.py diff --git a/pulser-core/pulser/channels/comparison_tools.py b/pulser-core/pulser/channels/comparison_tools.py new file mode 100644 index 00000000..3db2cbff --- /dev/null +++ b/pulser-core/pulser/channels/comparison_tools.py @@ -0,0 +1,201 @@ +# Copyright 2020 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines functions for comparison of Channels.""" +from __future__ import annotations + +import numbers +from collections.abc import Callable +from dataclasses import asdict +from operator import gt, lt, ne +from typing import TYPE_CHECKING, Any + +import numpy as np + +from pulser.channels.eom import RydbergEOM + +if TYPE_CHECKING: + from pulser.channels.base_channel import Channel + from pulser.channels.eom import BaseEOM + + +def _compare_with_None( + comparison_ops: dict[str, Callable], + leftvalues: dict[str, float | None], + rightvalues: dict[str, float | None], +) -> dict[str, Any]: + """Compare dict of values using a dict of comparison operators. + + If comparison operator returns a boolean, the value returned is: + - True if right value is None and left value is defined. + - False if left value is None and right value is defined. + - None if left and right values are None. + Implemented for lt, gt, min, max. + + Args: + comparison_ops: Associate keys to compare with comparison operator + leftvalues: Dict of values on the left of the comparison operator + rightvalues: Dict of values on the right of the comparison operator + + Returns: + Dictionary having the keys of the comparison operator, associating + None if both values are None for this key, and the result of the + comparison otherwise. + """ + # Error if some keys of comparison operators are not dict of left or right + if not ( + comparison_ops.keys() <= leftvalues.keys() + and comparison_ops.keys() <= rightvalues.keys() + ): + raise ValueError( + "Keys in comparison_ops should be in left values and right values." + ) + # Compare using +inf and -inf to replace None values + return { + key: comparison_op( + *( + value + or ( + float("+inf") + if comparison_op in [lt, min] + else float("-inf") + ) + for value in (leftvalues[key], rightvalues[key]) + ) + ) + if (leftvalues[key], rightvalues[key]) != (None, None) + else None + for (key, comparison_op) in comparison_ops.items() + } + + +def _validate_obj_from_best( + obj_dict: dict, best_obj_dict: dict, comparison_ops: dict +) -> bool: + """Validates an object by comparing it with a better one. + + Attributes: + obj_dict: Dict of attributes and values of the object to compare. + best_obj_dict: Dict of attributes and values of the best object. + comparison_ops: Dict of attributes and comparison operators to + use to compare the object and the best object. + + Returns: + True if the comparison works, raises a ValueError otherwise. + """ + # If the two values are almost equal then there is no need to compare them + comparison_ops_keys = list(comparison_ops.keys()) + for key in comparison_ops_keys: + if ( + obj_dict[key] is not None + and best_obj_dict[key] is not None + and isinstance(obj_dict[key], numbers.Number) + and isinstance(best_obj_dict[key], numbers.Number) + and np.isclose(obj_dict[key], best_obj_dict[key], rtol=1e-14) + ): + comparison_ops.pop(key) + is_wrong_effective_ch = _compare_with_None( + comparison_ops, obj_dict, best_obj_dict + ) + # Validates if no True in the dictionary of the comparisons + if not (True in is_wrong_effective_ch.values()): + return True + is_wrong_effective_index = list(is_wrong_effective_ch.values()).index(True) + is_wrong_key = list(is_wrong_effective_ch.keys())[is_wrong_effective_index] + raise ValueError( + f"{is_wrong_key} cannot be" + + ( + " below " + if comparison_ops[is_wrong_key] == lt + else ( + " above " + if comparison_ops[is_wrong_key] == gt + else " different than " + ) + ) + + f"{best_obj_dict[is_wrong_key]}." + ) + + +def validate_channel_from_best( + channel: Channel, best_channel: Channel +) -> bool: + """Checks that a channel can be realized from another one. + + Attributes: + channel: The channel to check. + best_channel: The channel that should have better properties. + """ + if type(channel) != type(best_channel): + raise ValueError( + "Channels do not have the same types, " + f"{type(channel)} and {type(best_channel)}" + ) + if channel.eom_config: + if best_channel.eom_config: + validate_eom_from_best(channel.eom_config, best_channel.eom_config) + else: + raise ValueError( + "eom_config cannot be defined in channel as the best_channel" + " does not have one." + ) + best_ch_att = asdict(best_channel) + ch_att = asdict(channel) + + # Error if attributes in channel and best_channel compare to True + comparison_ops = { + "addressing": ne, + "max_abs_detuning": gt, + "max_amp": gt, + "min_retarget_interval": lt, + "fixed_retarget_t": lt, + "max_targets": gt, + "clock_period": lt, + "min_duration": lt, + "max_duration": gt, + "mod_bandwidth": gt, + } + return _validate_obj_from_best(ch_att, best_ch_att, comparison_ops) + + +def validate_eom_from_best(eom: BaseEOM, best_eom: BaseEOM) -> bool: + """Checks that an EOM config can be realized from another one. + + Attributes: + eom: The EOM config to check. + best_eom: The EOM config that should have better properties. + """ + best_eom_att = asdict(best_eom) + eom_att = asdict(eom) + + # Error if attributes in eom and best_eom compare to True + comparison_ops = {"mod_bandwidth": gt} + if isinstance(eom, RydbergEOM): + if isinstance(best_eom, RydbergEOM): + comparison_ops.update( + { + "limiting_beam": ne, + "max_limiting_amp": gt, + "intermediate_detuning": ne, + "controlled_beams": gt, + } + ) + best_eom_att["controlled_beams"] = set( + best_eom_att["controlled_beams"] + ) + eom_att["controlled_beams"] = set(eom_att["controlled_beams"]) + else: + raise ValueError( + "EOM config is RydbergEOM whereas best EOM config is not." + ) + return _validate_obj_from_best(eom_att, best_eom_att, comparison_ops) diff --git a/pulser-core/pulser/devices/comparison_tools.py b/pulser-core/pulser/devices/comparison_tools.py new file mode 100644 index 00000000..b4f1f47a --- /dev/null +++ b/pulser-core/pulser/devices/comparison_tools.py @@ -0,0 +1,119 @@ +# Copyright 2020 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines functions for comparison of Channels.""" +from __future__ import annotations + +import itertools +from dataclasses import asdict +from operator import gt, lt, ne +from typing import TYPE_CHECKING, cast + +from pulser.channels.comparison_tools import ( + _validate_obj_from_best, + validate_channel_from_best, +) +from pulser.devices import Device, VirtualDevice + +if TYPE_CHECKING: + from pulser.devices._device_datacls import BaseDevice + + +def _exist_good_configuration( + possible_configurations: dict[str, list[str]] +) -> bool: + # If one value is an empty list then no configuration can work + if any( + [ + len(possible_values) == 0 + for possible_values in possible_configurations.values() + ] + ): + return False + print(list(possible_configurations.values())) + print(itertools.product(*list(possible_configurations.values()))) + for config in itertools.product(*list(possible_configurations.values())): + # True if each value in the list is different + print(config) + set(config) + if len(set(config)) == len(config): + return True + return False + + +def validate_device_from_best( + device: BaseDevice, best_device: BaseDevice +) -> bool: + """Checks that a device can be realized from another one. + + Attributes: + device: The device to check. + best_device: The device that should have better properties. + """ + if type(device) != type(best_device): + raise ValueError( + "Devices do not have the same types, " + f"{type(device)} and {type(best_device)}" + ) + equivalent_channels: dict[str, list[str]] = { + ch_name: [] for ch_name in device.channels.keys() + } + for ch_name, channel in device.channels.items(): + for best_ch_name, best_channel in best_device.channels.items(): + try: + validate_channel_from_best(channel, best_channel) + except ValueError: + continue + equivalent_channels[ch_name].append(best_ch_name) + if not _exist_good_configuration(equivalent_channels): + raise ValueError( + "No configuration could be found where each channel of the device" + " could be realized with one channel of the best device." + ) + + if isinstance(device, Device) and device.calibrated_register_layouts: + equivalent_layouts: dict[str, list[str]] = { + str(id): [] + for id in range(len(device.calibrated_register_layouts)) + } + for id, layout in enumerate(device.calibrated_register_layouts): + for id_best, best_layout in enumerate( + cast(Device, best_device).calibrated_register_layouts + ): + if best_layout > layout: + equivalent_layouts[str(id)].append(str(id_best)) + if not _exist_good_configuration(equivalent_layouts): + raise ValueError( + "No configuration could be found where each calibrated layouts" + " of the device could be realized with a calibrated layout of" + " the best device." + ) + + best_device_att = asdict(best_device) + device_att = asdict(device) + + # Error if attributes in device and best_device compare to True + comparison_ops = { + "dimensions": gt, + "rydberg_level": ne, + "min_atom_distance": lt, + "max_atom_num": gt, + "max_radial_distance": gt, + "interaction_coeff_xy": ne, + "supports_slm_mask": gt, + "max_layout_filling": gt, + } + if isinstance(device, VirtualDevice): + comparison_ops["reusable_channels"] = gt + + return _validate_obj_from_best(device_att, best_device_att, comparison_ops) diff --git a/pulser-core/pulser/register/register_layout.py b/pulser-core/pulser/register/register_layout.py index a8aa8143..69873cd3 100644 --- a/pulser-core/pulser/register/register_layout.py +++ b/pulser-core/pulser/register/register_layout.py @@ -311,6 +311,34 @@ def __eq__(self, other: Any) -> bool: return False return self._safe_hash() == other._safe_hash() + def __gt__(self, other: RegisterLayout) -> bool: + if not isinstance(other, RegisterLayout): + raise TypeError("Right operand should be of type RegisterLayout.") + return ( + set(tuple(self._coords[i]) for i in range(self.number_of_traps)) + > set( + tuple(other._coords[i]) for i in range(other.number_of_traps) + ) + and self.dimensionality == other.dimensionality + ) + + def __ge__(self, other: Any) -> bool: + return self.__eq__(other) or self.__gt__(other) + + def __lt__(self, other: RegisterLayout) -> bool: + if not isinstance(other, RegisterLayout): + raise TypeError("Right operand should be of type RegisterLayout.") + return ( + set(tuple(self._coords[i]) for i in range(self.number_of_traps)) + < set( + tuple(other._coords[i]) for i in range(other.number_of_traps) + ) + and self.dimensionality == other.dimensionality + ) + + def __le__(self, other: Any) -> bool: + return self.__eq__(other) or self.__lt__(other) + def __hash__(self) -> int: return hash(self._safe_hash()) diff --git a/tests/test_channels.py b/tests/test_channels.py index 1a47c3a8..0dcd2ae9 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -13,12 +13,19 @@ # limitations under the License. import re +from dataclasses import replace +from operator import gt import numpy as np import pytest import pulser from pulser.channels import Microwave, Raman, Rydberg +from pulser.channels.comparison_tools import ( + _compare_with_None, + validate_channel_from_best, + validate_eom_from_best, +) from pulser.channels.eom import MODBW_TO_TR, BaseEOM, RydbergBeam, RydbergEOM from pulser.waveforms import BlackmanWaveform, ConstantWaveform @@ -267,3 +274,119 @@ def test_modulation(channel, tr, eom, side_buffer_len): side_buffer_len, side_buffer_len, ) + + +def test_compare_with_None(): + comp_ops = {"max_amp": gt, "max_duration": gt} + left_values = {"max_amp": 1, "clock_period": 1, "max_duration": None} + right_values = {"max_amp": None, "clock_period": 2} + with pytest.raises( + ValueError, + match="Keys in comparison_ops should be in left values and right", + ): + _compare_with_None(comp_ops, left_values, right_values) + right_values["max_duration"] = None + comp_ops["clock_period"] = gt + assert _compare_with_None(comp_ops, left_values, right_values) == { + "max_amp": True, + "clock_period": False, + "max_duration": None, + } + assert _compare_with_None(comp_ops, right_values, left_values) == { + "max_amp": False, + "clock_period": True, + "max_duration": None, + } + + +@pytest.mark.parametrize( + "param, value, comparison", + [ + ("mod_bandwidth", 10, None), + ("mod_bandwidth", 20, None), + ("mod_bandwidth", 40, "above"), + ("limiting_beam", RydbergBeam.RED, None), + ("limiting_beam", RydbergBeam.BLUE, "different"), + ("max_limiting_amp", 2 * np.pi * 100, None), + ("max_limiting_amp", 2 * np.pi * 150, "above"), + ("controlled_beams", (RydbergBeam.RED,), None), + ("controlled_beams", tuple(RydbergBeam), "above"), + ("intermediate_detuning", 500 * 2 * np.pi, None), + ("intermediate_detuning", 200 * 2 * np.pi, "different"), + ], +) +def test_compare_eom(param, value, comparison): + print(param, value) + best_eom_config = replace(_eom_config, controlled_beams=(RydbergBeam.RED,)) + + with pytest.raises( + ValueError, + match="EOM config is RydbergEOM whereas best EOM config is not.", + ): + validate_eom_from_best(best_eom_config, BaseEOM(20)) + + eom_config = replace(best_eom_config, **{param: value}) + if comparison is None: + assert validate_eom_from_best(eom_config, best_eom_config) + if param == "mod_bandwidth": + # Works because BaseEOM can be implemented by RydbergEOM + validate_eom_from_best(BaseEOM(value), best_eom_config) + else: + with pytest.raises( + ValueError, match=f"{param} cannot be {comparison} " + ): + validate_eom_from_best(eom_config, best_eom_config) + if param == "mod_bandwidth": + with pytest.raises( + ValueError, match=f"{param} cannot be {comparison} " + ): + validate_eom_from_best(BaseEOM(value), best_eom_config) + + +@pytest.mark.parametrize( + "param, value, comparison", + [ + ("max_abs_detuning", 300, "above"), + ("max_amp", 100, "above"), + ("mod_bandwidth", 10, "above"), + ("clock_period", 2, "below"), + ("max_duration", 2e8, "above"), + ("min_duration", 1, None), + ("max_targets", 4, "above"), + ("fixed_retarget_t", 0, None), + ("min_retarget_interval", 180, "below"), + ], +) +def test_compare_channels(mod_device, param, value, comparison): + # Error if channel types are different + with pytest.raises( + ValueError, + match=f"Channels do not have the same types, {Raman} and {Rydberg}", + ): + validate_channel_from_best(_raman_local, _eom_rydberg) + # Error if the best channel does not have an EOM + rydberg_no_eom = Rydberg.Global( + max_amp=2 * np.pi * 10, + max_abs_detuning=2 * np.pi * 5, + mod_bandwidth=10, + ) + with pytest.raises( + ValueError, match="eom_config cannot be defined in channel" + ): + validate_channel_from_best(_eom_rydberg, rydberg_no_eom) + assert validate_eom_from_best(rydberg_no_eom, _eom_rydberg) + # Error if addressing is not the same + best_channel = mod_device.channels["rydberg_local"] + with pytest.raises( + ValueError, match="addressing cannot be different than Local." + ): + validate_channel_from_best(rydberg_no_eom, best_channel) + # Changing parameters of a Local channel. + channel = replace(best_channel, **{param: value}) + if comparison is None: + assert validate_channel_from_best(channel, best_channel) + else: + with pytest.raises( + ValueError, match=f"{param} cannot be {comparison} " + ): + validate_channel_from_best(channel, best_channel) diff --git a/tests/test_devices.py b/tests/test_devices.py index 4bd3b4d2..f798e3fa 100644 --- a/tests/test_devices.py +++ b/tests/test_devices.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from dataclasses import FrozenInstanceError +from dataclasses import FrozenInstanceError, replace from unittest.mock import patch import numpy as np @@ -21,7 +21,18 @@ import pulser from pulser.channels import Microwave, Raman, Rydberg -from pulser.devices import Chadoq2, Device, VirtualDevice +from pulser.devices import ( + Chadoq2, + Device, + IroiseMVP, + MockDevice, + VirtualDevice, +) +from pulser.devices.comparison_tools import ( + _exist_good_configuration, + validate_channel_from_best, + validate_device_from_best, +) from pulser.register import Register, Register3D from pulser.register.register_layout import RegisterLayout from pulser.register.special_layouts import TriangularLatticeLayout @@ -376,3 +387,134 @@ def test_device_params(): assert set(all_params) - set(all_virtual_params) == { "pre_calibrated_layouts" } + + +def test_exist_good_config(): + good_config = { + "chA": { + "ch1", + }, + "chB": {"ch1", "ch2"}, + "chC": { + "ch3", + }, + } + another_good_config = { + "chA": {"ch1", "ch2"}, + "chB": {"ch1", "ch2"}, + "chC": { + "ch3", + }, + } + bad_config = { + "chA": { + "ch1", + }, + "chB": { + "ch1", + }, + "chC": { + "ch2", + }, + } + assert _exist_good_configuration(good_config) + assert _exist_good_configuration(another_good_config) + assert not _exist_good_configuration(bad_config) + + +@pytest.mark.parametrize( + "param, value, comparison", + [ + ("dimensions", 3, "above"), + ("rydberg_level", 69, "different"), + ("max_atom_num", 200, "above"), + ("max_radial_distance", 60, "above"), + ("min_atom_distance", 2, "below"), + ("supports_slm_mask", False, None), + ("max_layout_filling", 0.9, "above"), + ("interaction_coeff_xy", 1, "different"), + ], +) +def test_compare_channels(mod_device, param, value, comparison): + # Checking the defined devices + assert validate_device_from_best(Chadoq2, mod_device) + Iroise_eom = replace( + IroiseMVP.channels["rydberg_global"].eom_config, + intermediate_detuning=800 * 2 * np.pi, + ) + Iroise_global = replace( + IroiseMVP.channels["rydberg_global"], eom_config=Iroise_eom + ) + assert validate_channel_from_best( + Iroise_global, mod_device.channels["rydberg_global"] + ) + assert validate_device_from_best( + replace(IroiseMVP, channel_objects=(Iroise_global,), rydberg_level=70), + mod_device, + ) + # Error if channel types are different + with pytest.raises( + ValueError, + match=( + "Devices do not have the same types, " + + f"{VirtualDevice} and {Device}" + ), + ): + validate_device_from_best(MockDevice, mod_device) + # Error if the channels don't match + with pytest.raises(ValueError, match="No configuration could be found"): + validate_device_from_best(IroiseMVP, mod_device) + + best_device = Chadoq2.to_virtual() + channel_with_locals = replace( + best_device, + channel_ids=( + "rydberg_global", + "rydberg_local1", + "raman_local", + "rydberg_local2", + ), + channel_objects=( + best_device.channels["rydberg_global"], + best_device.channels["rydberg_local"], + best_device.channels["raman_local"], + replace(best_device.channels["rydberg_local"], max_amp=100), + ), + ) + assert validate_device_from_best(best_device, channel_with_locals) + assert validate_device_from_best(channel_with_locals, channel_with_locals) + with pytest.raises(ValueError, match="No configuration could be found"): + validate_device_from_best(channel_with_locals, best_device) + # Error if pre-calibrated layouts don't match + zero_layout = RegisterLayout([[0, 0]]) + assert validate_device_from_best( + Chadoq2, replace(Chadoq2, pre_calibrated_layouts=[zero_layout]) + ) + with pytest.raises(ValueError, match="No configuration could be found "): + validate_device_from_best( + replace(Chadoq2, pre_calibrated_layouts=[zero_layout]), + Chadoq2, + ) + pre_calibrated_layouts = [RegisterLayout([[0, 0], [4, 0]])] + assert validate_device_from_best( + replace(Chadoq2, pre_calibrated_layouts=[zero_layout]), + replace(Chadoq2, pre_calibrated_layouts=pre_calibrated_layouts), + ) + with pytest.raises(ValueError, match="No configuration could be found "): + validate_device_from_best( + replace( + Chadoq2, + pre_calibrated_layouts=[zero_layout] + pre_calibrated_layouts, + ), + replace(Chadoq2, pre_calibrated_layouts=pre_calibrated_layouts), + ) + + # Changing parameters of a Local channel. + device = replace(best_device, **{param: value}) + if comparison is None: + assert validate_device_from_best(device, best_device) + else: + with pytest.raises( + ValueError, match=f"{param} cannot be {comparison} " + ): + validate_device_from_best(device, best_device) diff --git a/tests/test_register_layout.py b/tests/test_register_layout.py index 38b76f21..9f35ab45 100644 --- a/tests/test_register_layout.py +++ b/tests/test_register_layout.py @@ -153,6 +153,44 @@ def test_eq(layout, layout3d): assert hash(layout1) == hash(layout2) +def test_ineq(layout, layout3d): + layout1 = RegisterLayout([[0, 0], [1, 1]]) + layout2 = RegisterLayout([[1, 0, 1], [0, 0, 0]]) + zero_layout = RegisterLayout([[0, 0]]) + zero_3Dlayout = RegisterLayout([[0, 0, 0]]) + + with pytest.raises( + TypeError, match="Right operand should be of type RegisterLayout." + ): + zero_layout < Register.from_coordinates([[1, 0], [0, 0]]) + + with pytest.raises( + TypeError, match="Right operand should be of type RegisterLayout." + ): + zero_layout > Register.from_coordinates([[1, 0], [0, 0]]) + + assert layout >= layout1 + assert not layout1 >= layout + assert layout1 <= layout + assert not layout <= layout1 + + assert layout3d >= layout2 + assert not layout2 >= layout3d + assert layout2 <= layout3d + assert not layout3d <= layout2 + + # Returns False if dimensionality issues + assert zero_layout < layout1 + assert not zero_layout < layout2 + assert layout1 > zero_layout + assert not layout2 > zero_layout + + assert zero_3Dlayout < layout2 + assert not zero_3Dlayout < layout1 + assert layout2 > zero_3Dlayout + assert not layout1 > zero_3Dlayout + + def test_traps_from_coordinates(layout): assert layout._coords_to_traps == { (0, 0): 0,