From 83894fe5213a13de861a5794aa4dcbcdf6351101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrique=20Silv=C3=A9rio?= <29920212+HGSilveri@users.noreply.github.com> Date: Fri, 1 Mar 2024 16:40:57 +0100 Subject: [PATCH] FIX: Deserialization of sized variable items (#653) * FIX: Deserialization of sized variable items * Bump version to 0.17.2 --- VERSION.txt | 2 +- pulser-core/pulser/parametrized/variable.py | 52 +++++++++++++++++---- tests/test_abstract_repr.py | 36 +++++++++++++- tests/test_parametrized.py | 18 ++++++- 4 files changed, 96 insertions(+), 12 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index 7cca7711..c3d16c16 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.17.1 +0.17.2 diff --git a/pulser-core/pulser/parametrized/variable.py b/pulser-core/pulser/parametrized/variable.py index 971790ce..63b08b66 100644 --- a/pulser-core/pulser/parametrized/variable.py +++ b/pulser-core/pulser/parametrized/variable.py @@ -15,7 +15,7 @@ from __future__ import annotations -import collections.abc # To use collections.abc.Sequence +import collections.abc as abc # To use collections.abc.Sequence import dataclasses from typing import Any, Iterator, Optional, Union, cast @@ -99,15 +99,37 @@ def _to_abstract_repr(self) -> dict[str, str]: def __str__(self) -> str: return self.name - def __getitem__(self, key: Union[int, slice]) -> VariableItem: - if not isinstance(key, (int, slice)): + def __getitem__( + self, key: Union[int, slice, abc.Sequence[int]] + ) -> VariableItem: + if not isinstance(key, (int, slice, abc.Sequence)): raise TypeError(f"Invalid key type {type(key)} for '{self.name}'.") - if isinstance(key, int): - if not -self.size <= key < self.size: - raise IndexError(f"{key} outside of range for '{self.name}'.") + bad_ind = None + if isinstance(key, int) and not -self.size <= key < self.size: + bad_ind = key + elif isinstance(key, abc.Sequence): + for ind_ in key: + if not isinstance(ind_, int): + raise TypeError( + f"Invalid index type {type(ind_)} for variable " + f"'{self.name}'." + ) + if not -self.size <= ind_ < self.size: + bad_ind = ind_ + break + else: + key = list(key) + if bad_ind is not None: + raise IndexError( + f"Index {bad_ind} out of bounds for variable '{self.name}' " + f"with size {self.size}." + ) return VariableItem(self, key) + # NOTE: __len__ cannot be defined because it makes numpy.ufuncs convert a + # Variable into an array of VariableItem's + def __iter__(self) -> Iterator[VariableItem]: for i in range(self.size): yield self[i] @@ -118,7 +140,7 @@ class VariableItem(Parametrized, OpSupport): """Stores access to items of a variable with multiple values.""" var: Variable - key: Union[int, slice] + key: Union[int, slice, abc.Sequence[int]] @property def variables(self) -> dict[str, Variable]: @@ -127,7 +149,10 @@ def variables(self) -> dict[str, Variable]: def build(self) -> Union[ArrayLike, float, int]: """Return the variable's item(s) values.""" - return cast(collections.abc.Sequence, self.var.build())[self.key] + built_var = cast(abc.Sequence, self.var.build()) + if isinstance(self.key, abc.Sequence): + return [built_var[k] for k in self.key] + return built_var[self.key] def _to_dict(self) -> dict[str, Any]: return obj_to_dict( @@ -135,7 +160,11 @@ def _to_dict(self) -> dict[str, Any]: ) def _to_abstract_repr(self) -> dict[str, Any]: - indices = list(range(self.var.size))[self.key] + indices: int | list[int] + if isinstance(self.key, abc.Sequence): + indices = list(self.key) + else: + indices = list(range(self.var.size))[self.key] return {"expression": "index", "lhs": self.var, "rhs": indices} def __str__(self) -> str: @@ -148,3 +177,8 @@ def __str__(self) -> str: else: key_str = str(self.key) return f"{str(self.var)}[{key_str}]" + + def __len__(self) -> int: + if isinstance(self.key, int): + raise TypeError(f"len() of unsized variable item '{self!s}'.") + return len(np.arange(self.var.size)[self.key]) diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index 4b02eaf9..b8c9173e 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -700,6 +700,25 @@ def test_paramobj_serialization(self, sequence): times=[0.0, 0.5, 1.0], ) + ser_inv_list_var_items = { + "expression": "index", + "lhs": {"variable": "list_var"}, + "rhs": [2, 1, 0], + } + s = json.dumps( + InterpolatedWaveform(var, list_var[::-1]), cls=AbstractReprEncoder + ) + assert json.loads(s) == dict( + kind="interpolated", + duration=ser_var, + values=ser_inv_list_var_items, + times=[0.0, 0.5, 1.0], + ) + assert s == json.dumps( + InterpolatedWaveform(var, list_var[[2, 1, 0]]), + cls=AbstractReprEncoder, + ) + err_msg = ( "An InterpolatedWaveform with 'values' of unknown length " "and unspecified 'times' can't be serialized to the abstract" @@ -2023,6 +2042,16 @@ def test_deserialize_parametrized_waveform(self, wf_obj): {"expression": "cos", "lhs": var1}, {"expression": "tan", "lhs": {"variable": "var1"}}, {"expression": "index", "lhs": {"variable": "var1"}, "rhs": 0}, + { + "expression": "index", + "lhs": {"variable": "var2"}, + "rhs": [1, 2], + }, + { + "expression": "index", + "lhs": {"variable": "var2"}, + "rhs": [4, 2, 0], + }, {"expression": "add", "lhs": var1, "rhs": 0.5}, {"expression": "sub", "lhs": {"variable": "var1"}, "rhs": 0.5}, {"expression": "mul", "lhs": {"variable": "var1"}, "rhs": 0.5}, @@ -2058,6 +2087,7 @@ def test_deserialize_param(self, json_param): ], variables={ "var1": {"type": "float", "value": [1.5]}, + "var2": {"type": "int", "value": [0, 1, 2, 3, 4]}, }, ) # Note: If built, some of these sequences will be invalid @@ -2072,6 +2102,7 @@ def test_deserialize_param(self, json_param): _check_roundtrip(s) seq = Sequence.from_abstract_repr(json.dumps(s)) seq_var1 = seq._variables["var1"] + seq_var2 = seq._variables["var2"] # init + declare channels + 1 operation offset = 1 + len(s["channels"]) @@ -2111,7 +2142,10 @@ def test_deserialize_param(self, json_param): assert param == np.tan(seq_var1) if expression == "index": - assert param == seq_var1[rhs] + if json_param["lhs"] == {"variable": "var1"}: + assert param == seq_var1[rhs] + else: + assert param == seq_var2[rhs] if expression == "add": assert param == seq_var1[0] + rhs if expression == "sub": diff --git a/tests/test_parametrized.py b/tests/test_parametrized.py index da9a5819..b94a7de0 100644 --- a/tests/test_parametrized.py +++ b/tests/test_parametrized.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import re from dataclasses import FrozenInstanceError import numpy as np @@ -87,26 +88,41 @@ def test_var(a, b): assert np.all(var_.build() == np.array([1, 2])) with pytest.raises(TypeError, match="Invalid key type"): - b[[0, 1]] + b[{0, 1}] + with pytest.raises(TypeError, match="Invalid index type"): + b[[0.0, -1.0]] with pytest.raises(IndexError): b[2] + with pytest.raises(IndexError): + b[[-3, 1]] def test_varitem(a, b, d): a0 = a[0] b1 = b[1] b01 = b[100::-1] + b01_2 = b[[-1, -2]] + b01_3 = b[(1, 0)] d0 = d[0] assert b01.variables == {"b": b} assert str(a0) == "a[0]" assert str(b1) == "b[1]" assert str(b01) == "b[100::-1]" + assert str(b01_2) == "b[[-1, -2]]" + assert str(b01_3) == "b[[1, 0]]" assert str(d0) == "d[0]" assert b1.build() == 1 assert np.all(b01.build() == np.array([1, -1])) assert d0.build() == 0.5 with pytest.raises(FrozenInstanceError): b1.key = 0 + np.testing.assert_equal(b01.build(), b01_2.build()) + np.testing.assert_equal(b01_2.build(), b01_3.build()) + with pytest.raises( + TypeError, match=re.escape("len() of unsized variable item 'b[1]'") + ): + len(b1) + assert len(b01) == len(b01_2) == len(b01_3) == b.size == 2 def test_paramobj(bwf, t, a, b):