diff --git a/pulser-core/pulser/json/abstract_repr/serializer.py b/pulser-core/pulser/json/abstract_repr/serializer.py index 239b5e10f..6376f7c6c 100644 --- a/pulser-core/pulser/json/abstract_repr/serializer.py +++ b/pulser-core/pulser/json/abstract_repr/serializer.py @@ -70,12 +70,8 @@ def abstract_repr(name: str, *args: Any, **kwargs: Any) -> dict[str, Any]: res.update(signature.extra) # Starts with extra info ({} if undefined) # With PulseSignature.all_pos_args(), we safeguard against the opposite # case where an expected keyword argument is given as a positional argument - res.update( - { - arg_name: arg_val - for arg_name, arg_val in zip(signature.all_pos_args(), args) - } - ) + res.update(dict(zip(signature.all_pos_args(), args))) + # Account for keyword arguments given as pos args max_pos_args = len(signature.pos) + len( set(signature.keyword) - set(kwargs) diff --git a/pulser-core/pulser/parametrized/paramobj.py b/pulser-core/pulser/parametrized/paramobj.py index 439d1b3be..60675d05d 100644 --- a/pulser-core/pulser/parametrized/paramobj.py +++ b/pulser-core/pulser/parametrized/paramobj.py @@ -24,6 +24,7 @@ import numpy as np +import pulser.parametrized from pulser.json.abstract_repr.serializer import abstract_repr from pulser.json.abstract_repr.signatures import ( BINARY_OPERATORS, @@ -161,6 +162,16 @@ def __init__(self, cls: Callable, *args: Any, **kwargs: Any) -> None: self._instance = None self._vars_state: dict[str, int] = {} + @property + def _default_kwargs(self) -> dict[str, Any]: + """The default values for the object's keyword arguments.""" + cls_signature = inspect.signature(self.cls).parameters + return { + param: cls_signature[param].default + for param in cls_signature + if cls_signature[param].default != cls_signature[param].empty + } + @property def variables(self) -> dict[str, Variable]: """Returns all involved variables.""" @@ -241,19 +252,18 @@ def _to_abstract_repr(self) -> dict[str, Any]: # classmethod cls_name = self.args[0].__name__ name = f"{cls_name}.{op_name}" - if cls_name == "Pulse": - signature = ( - "amplitude", - "detuning", - "phase", - "post_phase_shift", - ) - all_args = { - **dict(zip(signature, self.args[1:])), - **self.kwargs, - } - if "post_phase_shift" not in all_args: - all_args["post_phase_shift"] = 0.0 + signature = SIGNATURES[ + "Pulse" if cls_name == "Pulse" else name + ] + # No existing classmethod has *args in its signature + assert ( + signature.var_pos is None + ), "Unexpected signature with VAR_POSITIONAL arguments." + all_args = { + **self._default_kwargs, + **dict(zip(signature.all_pos_args(), self.args[1:])), + **self.kwargs, + } if name == "Pulse.ConstantAmplitude": all_args["amplitude"] = abstract_repr( "ConstantWaveform", 0, all_args["amplitude"] @@ -265,13 +275,48 @@ def _to_abstract_repr(self) -> dict[str, Any]: ) return abstract_repr("Pulse", **all_args) else: - return abstract_repr(name, *self.args[1:], **self.kwargs) + return abstract_repr(name, **all_args) raise NotImplementedError( "Instance or static method serialization is not supported." ) elif op_name in SIGNATURES: - return abstract_repr(op_name, *self.args, **self.kwargs) + signature = SIGNATURES[op_name] + filtered_defaults = { + key: value + for key, value in self._default_kwargs.items() + if key in signature.keyword + } + full_kwargs = {**filtered_defaults, **self.kwargs} + if signature.var_pos is not None: + # No args can be given with a keyword + return abstract_repr(op_name, *self.args, **full_kwargs) + + all_args = { + **full_kwargs, + **dict(zip(signature.all_pos_args(), self.args)), + } + if op_name == "InterpolatedWaveform" and all_args["times"] is None: + if isinstance( + all_args["values"], + pulser.parametrized.Variable, # Avoids circular import + ): + num_values = all_args["values"].size + else: + try: + num_values = len(all_args["values"]) + except TypeError: + raise AbstractReprError( + "An InterpolatedWaveform with 'values' of unknown " + "length and unspecified 'times' can't be " + "serialized to the abstract representation. To " + "keep the same argument for 'values', provide " + "compatible 'times' explicitly." + ) + + all_args["times"] = np.linspace(0, 1, num=num_values) + + return abstract_repr(op_name, **all_args) elif op_name in UNARY_OPERATORS: return dict(expression=op_name, lhs=self.args[0]) diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index e281f13a2..50331a0b9 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -490,18 +490,18 @@ def test_paramobj_serialization(self, sequence): ) s = json.dumps( - Pulse.ConstantDetuning(wf, 0.0, var, post_phase_shift=1.0), + Pulse.ConstantDetuning(wf, 0.0, var), cls=AbstractReprEncoder, ) assert json.loads(s) == dict( amplitude=ser_wf, detuning={"kind": "constant", "duration": 0, "value": 0.0}, phase=ser_var, - post_phase_shift=1.0, + post_phase_shift=0.0, # The default is added ) s = json.dumps( - Pulse.ConstantPulse(var, 2.0, 0.0, 1.0, 1.0), + Pulse.ConstantPulse(var, 2.0, 0.0, 1.0, post_phase_shift=1.0), cls=AbstractReprEncoder, ) assert json.loads(s) == dict( @@ -518,6 +518,58 @@ def test_paramobj_serialization(self, sequence): ): method_call._to_abstract_repr() + # Check the defaults are added when not specified + s = json.dumps( + KaiserWaveform.from_max_val(1.0, var), cls=AbstractReprEncoder + ) + assert json.loads(s) == dict( + kind="kaiser_max", + max_val=1.0, + area=ser_var, + beta=14.0, # The default beta parameter + ) + + s = json.dumps(KaiserWaveform(var, var, var), cls=AbstractReprEncoder) + assert json.loads(s) == dict( + kind="kaiser", + duration=ser_var, + area=ser_var, + beta=ser_var, # The given beta parameter + ) + + s = json.dumps( + InterpolatedWaveform(var, [1, 2, -3]), cls=AbstractReprEncoder + ) + assert json.loads(s) == dict( + kind="interpolated", + duration=ser_var, + values=[1, 2, -3], + times=[0.0, 0.5, 1.0], + ) + + list_var = sequence.declare_variable("list_var", size=3) + ser_list_var = {"variable": "list_var"} + s = json.dumps( + InterpolatedWaveform(var, list_var), cls=AbstractReprEncoder + ) + assert json.loads(s) == dict( + kind="interpolated", + duration=ser_var, + values=ser_list_var, + times=[0.0, 0.5, 1.0], + ) + + err_msg = ( + "An InterpolatedWaveform with 'values' of unknown length " + "and unspecified 'times' can't be serialized to the abstract" + " representation." + ) + with pytest.raises(AbstractReprError, match=err_msg): + json.dumps( + InterpolatedWaveform(1000, np.cos(list_var)), + cls=AbstractReprEncoder, + ) + with pytest.raises( AbstractReprError, match="No abstract representation for 'Foo'" ):