Skip to content

Commit

Permalink
Fix ParamObj serialization with missing keyword arguments (#491)
Browse files Browse the repository at this point in the history
* Write UTs for incoming features

* ParamObj serialization with missing kwargs

* Handle InterpolatedWaveform in ParamObj
  • Loading branch information
HGSilveri authored Mar 30, 2023
1 parent fa57b6a commit b123ac8
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 24 deletions.
8 changes: 2 additions & 6 deletions pulser-core/pulser/json/abstract_repr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ def abstract_repr(name: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
res.update(signature.extra) # Starts with extra info ({} if undefined)
# With PulseSignature.all_pos_args(), we safeguard against the opposite
# case where an expected keyword argument is given as a positional argument
res.update(
{
arg_name: arg_val
for arg_name, arg_val in zip(signature.all_pos_args(), args)
}
)
res.update(dict(zip(signature.all_pos_args(), args)))

# Account for keyword arguments given as pos args
max_pos_args = len(signature.pos) + len(
set(signature.keyword) - set(kwargs)
Expand Down
75 changes: 60 additions & 15 deletions pulser-core/pulser/parametrized/paramobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import numpy as np

import pulser.parametrized
from pulser.json.abstract_repr.serializer import abstract_repr
from pulser.json.abstract_repr.signatures import (
BINARY_OPERATORS,
Expand Down Expand Up @@ -161,6 +162,16 @@ def __init__(self, cls: Callable, *args: Any, **kwargs: Any) -> None:
self._instance = None
self._vars_state: dict[str, int] = {}

@property
def _default_kwargs(self) -> dict[str, Any]:
"""The default values for the object's keyword arguments."""
cls_signature = inspect.signature(self.cls).parameters
return {
param: cls_signature[param].default
for param in cls_signature
if cls_signature[param].default != cls_signature[param].empty
}

@property
def variables(self) -> dict[str, Variable]:
"""Returns all involved variables."""
Expand Down Expand Up @@ -241,19 +252,18 @@ def _to_abstract_repr(self) -> dict[str, Any]:
# classmethod
cls_name = self.args[0].__name__
name = f"{cls_name}.{op_name}"
if cls_name == "Pulse":
signature = (
"amplitude",
"detuning",
"phase",
"post_phase_shift",
)
all_args = {
**dict(zip(signature, self.args[1:])),
**self.kwargs,
}
if "post_phase_shift" not in all_args:
all_args["post_phase_shift"] = 0.0
signature = SIGNATURES[
"Pulse" if cls_name == "Pulse" else name
]
# No existing classmethod has *args in its signature
assert (
signature.var_pos is None
), "Unexpected signature with VAR_POSITIONAL arguments."
all_args = {
**self._default_kwargs,
**dict(zip(signature.all_pos_args(), self.args[1:])),
**self.kwargs,
}
if name == "Pulse.ConstantAmplitude":
all_args["amplitude"] = abstract_repr(
"ConstantWaveform", 0, all_args["amplitude"]
Expand All @@ -265,13 +275,48 @@ def _to_abstract_repr(self) -> dict[str, Any]:
)
return abstract_repr("Pulse", **all_args)
else:
return abstract_repr(name, *self.args[1:], **self.kwargs)
return abstract_repr(name, **all_args)

raise NotImplementedError(
"Instance or static method serialization is not supported."
)
elif op_name in SIGNATURES:
return abstract_repr(op_name, *self.args, **self.kwargs)
signature = SIGNATURES[op_name]
filtered_defaults = {
key: value
for key, value in self._default_kwargs.items()
if key in signature.keyword
}
full_kwargs = {**filtered_defaults, **self.kwargs}
if signature.var_pos is not None:
# No args can be given with a keyword
return abstract_repr(op_name, *self.args, **full_kwargs)

all_args = {
**full_kwargs,
**dict(zip(signature.all_pos_args(), self.args)),
}
if op_name == "InterpolatedWaveform" and all_args["times"] is None:
if isinstance(
all_args["values"],
pulser.parametrized.Variable, # Avoids circular import
):
num_values = all_args["values"].size
else:
try:
num_values = len(all_args["values"])
except TypeError:
raise AbstractReprError(
"An InterpolatedWaveform with 'values' of unknown "
"length and unspecified 'times' can't be "
"serialized to the abstract representation. To "
"keep the same argument for 'values', provide "
"compatible 'times' explicitly."
)

all_args["times"] = np.linspace(0, 1, num=num_values)

return abstract_repr(op_name, **all_args)

elif op_name in UNARY_OPERATORS:
return dict(expression=op_name, lhs=self.args[0])
Expand Down
58 changes: 55 additions & 3 deletions tests/test_abstract_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,18 +490,18 @@ def test_paramobj_serialization(self, sequence):
)

s = json.dumps(
Pulse.ConstantDetuning(wf, 0.0, var, post_phase_shift=1.0),
Pulse.ConstantDetuning(wf, 0.0, var),
cls=AbstractReprEncoder,
)
assert json.loads(s) == dict(
amplitude=ser_wf,
detuning={"kind": "constant", "duration": 0, "value": 0.0},
phase=ser_var,
post_phase_shift=1.0,
post_phase_shift=0.0, # The default is added
)

s = json.dumps(
Pulse.ConstantPulse(var, 2.0, 0.0, 1.0, 1.0),
Pulse.ConstantPulse(var, 2.0, 0.0, 1.0, post_phase_shift=1.0),
cls=AbstractReprEncoder,
)
assert json.loads(s) == dict(
Expand All @@ -518,6 +518,58 @@ def test_paramobj_serialization(self, sequence):
):
method_call._to_abstract_repr()

# Check the defaults are added when not specified
s = json.dumps(
KaiserWaveform.from_max_val(1.0, var), cls=AbstractReprEncoder
)
assert json.loads(s) == dict(
kind="kaiser_max",
max_val=1.0,
area=ser_var,
beta=14.0, # The default beta parameter
)

s = json.dumps(KaiserWaveform(var, var, var), cls=AbstractReprEncoder)
assert json.loads(s) == dict(
kind="kaiser",
duration=ser_var,
area=ser_var,
beta=ser_var, # The given beta parameter
)

s = json.dumps(
InterpolatedWaveform(var, [1, 2, -3]), cls=AbstractReprEncoder
)
assert json.loads(s) == dict(
kind="interpolated",
duration=ser_var,
values=[1, 2, -3],
times=[0.0, 0.5, 1.0],
)

list_var = sequence.declare_variable("list_var", size=3)
ser_list_var = {"variable": "list_var"}
s = json.dumps(
InterpolatedWaveform(var, list_var), cls=AbstractReprEncoder
)
assert json.loads(s) == dict(
kind="interpolated",
duration=ser_var,
values=ser_list_var,
times=[0.0, 0.5, 1.0],
)

err_msg = (
"An InterpolatedWaveform with 'values' of unknown length "
"and unspecified 'times' can't be serialized to the abstract"
" representation."
)
with pytest.raises(AbstractReprError, match=err_msg):
json.dumps(
InterpolatedWaveform(1000, np.cos(list_var)),
cls=AbstractReprEncoder,
)

with pytest.raises(
AbstractReprError, match="No abstract representation for 'Foo'"
):
Expand Down

0 comments on commit b123ac8

Please sign in to comment.