From caa568bb474006ed00598e0244c738b389214f0d Mon Sep 17 00:00:00 2001 From: a_corni Date: Tue, 10 Dec 2024 14:37:46 +0100 Subject: [PATCH] Fix tests --- tests/test_dmm.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/test_dmm.py b/tests/test_dmm.py index 80a8c9cb0..436515ad6 100644 --- a/tests/test_dmm.py +++ b/tests/test_dmm.py @@ -65,7 +65,7 @@ def slm_map( ) -> DetuningMap: return layout.define_detuning_map(slm_dict) - @pytest.mark.parametrize("bad_key", [{"1": 1.0}, {4: 1.0}]) + @pytest.mark.parametrize("bad_key", [{1: 1.0}, {"4": 1.0}]) def test_define_detuning_map( self, layout: RegisterLayout, @@ -74,6 +74,13 @@ def test_define_detuning_map( bad_key: dict, ): for reg in (layout, map_reg): + if type(list(bad_key.keys())[0]) == int: + with pytest.raises( + ValueError, + match="'trap_coordinates' must be an array or list", + ): + reg.define_detuning_map(bad_key) # type: ignore + continue with pytest.raises( ValueError, match=re.escape( @@ -93,7 +100,7 @@ def test_define_detuning_map( def test_qubit_weight_map(self, register): # Purposefully unsorted - qid_weight_map = {1: 1.0, 0: 0.1, 3: 0.4} + qid_weight_map = {"1": 1.0, "0": 0.1, "3": 0.4} sorted_qids = sorted(qid_weight_map) det_map = register.define_detuning_map(qid_weight_map) qubits = register.qubits @@ -106,7 +113,7 @@ def test_qubit_weight_map(self, register): # We recover the original qid_weight_map (and undefined qids show as 0) assert det_map.get_qubit_weight_map(qubits) == { **qid_weight_map, - 2: 0.0, + "2": 0.0, } tri_layout = TriangularLatticeLayout(100, spacing=5) @@ -174,8 +181,11 @@ def test_detuning_map_bad_init( ): DetuningMap([(0, 0), (1, 0)], [0]) - bad_weights = {0: -1.0, 1: 1.0, 2: 1.0} for reg in (layout, map_reg, register): + if reg == register: + bad_weights = {"0": -1.0, "1": 1.0, "2": 1.0} + else: + bad_weights = {0: -1.0, 1: 1.0, 2: 1.0} with pytest.raises( ValueError, match="All weights must be between 0 and 1." ): @@ -189,11 +199,19 @@ def test_init( det_dict: dict[int, float], slm_dict: dict[int, float], ): + for reg in (layout, map_reg, register): for detuning_map_dict in (det_dict, slm_dict): + if reg == register: + reg_det_map_dict = { + str(id): weight + for (id, weight) in detuning_map_dict.items() + } + else: + reg_det_map_dict = detuning_map_dict.copy() detuning_map = cast( DetuningMap, - reg.define_detuning_map(detuning_map_dict), # type: ignore + reg.define_detuning_map(reg_det_map_dict), # type: ignore ) assert np.all( [