Skip to content

Commit

Permalink
Changes for compatibility with upcoming export format (#353)
Browse files Browse the repository at this point in the history
* Adding new operators to OpSupport

* Removing `__len__` from `Variable` to allow for numpy ufuncs

* Adding missing operations

* Finished new operations + unit tests

* Forbiding variable channels and measurements + mypy

* Adding the new operations to the supported list

* Support for serialization of the new ops

* Changing implementation of `round`

* Changing implementation of `floordiv`

* Increasing coverage of serialization support for ops
  • Loading branch information
HGSilveri authored Apr 1, 2022
1 parent 3ac1f66 commit c34dffb
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 63 deletions.
17 changes: 15 additions & 2 deletions pulser/json/supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,31 @@
"sub",
"mul",
"truediv",
"floordiv",
"pow",
"mod",
)

SUPPORTED_NUMPY = (
"array",
"round_",
"ceil",
"floor",
"sqrt",
"exp",
"log2",
"log",
"sin",
"cos",
"tan",
)

SUPPORTS_SUBMODULE = ("Pulse", "BlackmanWaveform", "KaiserWaveform")

SUPPORTED_MODULES = {
"builtins": SUPPORTED_BUILTINS,
"_operator": SUPPORTED_OPERATORS,
"operator": SUPPORTED_OPERATORS,
"numpy": ("array",),
"numpy": SUPPORTED_NUMPY,
"pulser.register.register": ("Register",),
"pulser.register.register3d": ("Register3D",),
"pulser.register.register_layout": ("RegisterLayout",),
Expand Down
79 changes: 63 additions & 16 deletions pulser/parametrized/paramobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import warnings
from collections.abc import Callable
from itertools import chain
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Union, cast

import numpy as np

from pulser.json.utils import obj_to_dict
from pulser.parametrized import Parametrized
Expand All @@ -32,53 +34,97 @@
class OpSupport:
"""Methods for supporting operators on parametrized objects."""

# Unary operators
def __neg__(self) -> ParamObj:
return ParamObj(operator.neg, self)

def __abs__(self) -> ParamObj:
return ParamObj(operator.abs, self)

def __ceil__(self) -> ParamObj:
return ParamObj(np.ceil, self)

def __floor__(self) -> ParamObj:
return ParamObj(np.floor, self)

def __round__(self, n: int = 0) -> ParamObj:
return cast(ParamObj, (self * 10**n).rint() / 10**n)

def rint(self) -> ParamObj:
"""Rounds the value to the nearest int."""
# Defined because np.round looks for 'rint'
return ParamObj(np.round, self)

def sqrt(self) -> ParamObj:
"""Calculates the square root of the object."""
return ParamObj(np.sqrt, self)

def exp(self) -> ParamObj:
"""Calculates the exponential of the object."""
return ParamObj(np.exp, self)

def log2(self) -> ParamObj:
"""Calculates the base-2 logarithm of the object."""
return ParamObj(np.log2, self)

def log(self) -> ParamObj:
"""Calculates the natural logarithm of the object."""
return ParamObj(np.log, self)

def sin(self) -> ParamObj:
"""Calculates the trigonometric sine of the object."""
return ParamObj(np.sin, self)

def cos(self) -> ParamObj:
"""Calculates the trigonometric cosine of the object."""
return ParamObj(np.cos, self)

def tan(self) -> ParamObj:
"""Calculates the trigonometric tangent of the object."""
return ParamObj(np.tan, self)

# Binary operators
def __add__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__add__, self, other)
return ParamObj(operator.add, self, other)

def __radd__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__add__, other, self)
return ParamObj(operator.add, other, self)

def __sub__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__sub__, self, other)
return ParamObj(operator.sub, self, other)

def __rsub__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__sub__, other, self)
return ParamObj(operator.sub, other, self)

def __mul__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__mul__, self, other)
return ParamObj(operator.mul, self, other)

def __rmul__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__mul__, other, self)
return ParamObj(operator.mul, other, self)

def __truediv__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__truediv__, self, other)
return ParamObj(operator.truediv, self, other)

def __rtruediv__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__truediv__, other, self)
return ParamObj(operator.truediv, other, self)

def __floordiv__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__floordiv__, self, other)
return (self / other).__floor__()

def __rfloordiv__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__floordiv__, other, self)
return (other / self).__floor__()

def __pow__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__pow__, self, other)
return ParamObj(operator.pow, self, other)

def __rpow__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__pow__, other, self)
return ParamObj(operator.pow, other, self)

def __mod__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__mod__, self, other)
return ParamObj(operator.mod, self, other)

def __rmod__(self, other: Union[int, float]) -> ParamObj:
return ParamObj(operator.__mod__, other, self)
return ParamObj(operator.mod, other, self)


class ParamObj(Parametrized, OpSupport):
Expand Down Expand Up @@ -135,8 +181,9 @@ def build(self) -> Any:

def _to_dict(self) -> dict[str, Any]:
def class_to_dict(cls: Callable) -> dict[str, Any]:
module = "numpy" if isinstance(cls, np.ufunc) else cls.__module__
return obj_to_dict(
self, _build=False, _name=cls.__name__, _module=cls.__module__
self, _build=False, _name=cls.__name__, _module=module
)

args = list(self.args)
Expand Down
5 changes: 1 addition & 4 deletions pulser/parametrized/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def _to_dict(self) -> dict[str, Any]:
def __str__(self) -> str:
return self.name

def __len__(self) -> int:
return self.size

def __getitem__(self, key: Union[int, slice]) -> VariableItem:
if not isinstance(key, (int, slice)):
raise TypeError(f"Invalid key type {type(key)} for '{self.name}'.")
Expand All @@ -116,7 +113,7 @@ def __getitem__(self, key: Union[int, slice]) -> VariableItem:
return VariableItem(self, key)

def __iter__(self) -> Iterator[VariableItem]:
for i in range(len(self)):
for i in range(self.size):
yield self[i]


Expand Down
48 changes: 22 additions & 26 deletions pulser/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def declare_variable(
def add(
self,
pulse: Union[Pulse, Parametrized],
channel: Union[str, Parametrized],
channel: str,
protocol: PROTOCOLS = "min-delay",
) -> None:
"""Adds a pulse to a channel.
Expand All @@ -673,9 +673,6 @@ def add(
that idles the channel until the end of the other channels'
latest pulse.
"""
pulse = cast(Pulse, pulse)
channel = cast(str, channel)

self._validate_channel(channel)

valid_protocols = get_args(PROTOCOLS)
Expand Down Expand Up @@ -817,7 +814,7 @@ def add(
def target(
self,
qubits: Union[QubitId, Iterable[QubitId], Parametrized],
channel: Union[str, Parametrized],
channel: str,
) -> None:
"""Changes the target qubit of a 'Local' channel.
Expand All @@ -828,32 +825,24 @@ def target(
channel (str): The channel's name provided when declared. Must be
a channel with 'Local' addressing.
"""
qubits = cast(QubitId, qubits)
channel = cast(str, channel)

self._target(qubits, channel)

@_store
def delay(
self,
duration: Union[int, Parametrized],
channel: Union[str, Parametrized],
channel: str,
) -> None:
"""Idles a given channel for a specific duration.
Args:
duration (int): Time to delay (in multiples of 4 ns).
channel (str): The channel's name provided when declared.
"""
duration = cast(int, duration)
channel = cast(str, channel)

self._delay(duration, channel)

@_store
def measure(
self, basis: Union[str, Parametrized] = "ground-rydberg"
) -> None:
def measure(self, basis: str = "ground-rydberg") -> None:
"""Measures in a valid basis.
Note:
Expand Down Expand Up @@ -893,7 +882,7 @@ def phase_shift(
self,
phi: Union[float, Parametrized],
*targets: Union[QubitId, Parametrized],
basis: Union[str, Parametrized] = "digital",
basis: str = "digital",
) -> None:
r"""Shifts the phase of a qubit's reference by 'phi', for a given basis.
Expand All @@ -909,14 +898,10 @@ def phase_shift(
the phase shift to. Must correspond to the basis of a declared
channel.
"""
phi = cast(float, phi)
basis = cast(str, basis)
targets = cast(Tuple[QubitId], targets)

self._phase_shift(phi, *targets, basis=basis)

@_store
def align(self, *channels: Union[str, Parametrized]) -> None:
def align(self, *channels: str) -> None:
"""Aligns multiple channels in time.
Introduces delays that align the provided channels with the one that
Expand All @@ -942,7 +927,6 @@ def align(self, *channels: Union[str, Parametrized]) -> None:
if self.is_parametrized():
return

channels = cast(Tuple[str], channels)
last_ts = {
id: self.get_duration(id, include_fall_time=True)
for id in channels
Expand Down Expand Up @@ -1187,7 +1171,9 @@ def draw(
plt.show()

def _target(
self, qubits: Union[Iterable[QubitId], QubitId], channel: str
self,
qubits: Union[Iterable[QubitId], QubitId, Parametrized],
channel: str,
) -> None:
self._validate_channel(channel)
channel_obj = self._channels[channel]
Expand Down Expand Up @@ -1263,11 +1249,12 @@ def _target(
self._last_target[channel] = tf
self._add_to_schedule(channel, _TimeSlot("target", ti, tf, qubits_set))

def _delay(self, duration: int, channel: str) -> None:
def _delay(self, duration: Union[int, Parametrized], channel: str) -> None:
self._validate_channel(channel)
if self.is_parametrized():
return

duration = cast(int, duration)
last = self._last(channel)
ti = last.tf
tf = ti + self._channels[channel].validate_duration(duration)
Expand All @@ -1276,7 +1263,10 @@ def _delay(self, duration: int, channel: str) -> None:
)

def _phase_shift(
self, phi: float, *targets: QubitId, basis: str = "digital"
self,
phi: Union[float, Parametrized],
*targets: Union[QubitId, Parametrized],
basis: str,
) -> None:
if basis not in self._phase_ref:
raise ValueError("No declared channel targets the given 'basis'.")
Expand All @@ -1293,7 +1283,8 @@ def _phase_shift(
"All given targets have to be qubit ids declared"
" in this sequence's register."
)

phi = cast(float, phi)
targets = cast(Tuple[QubitId], targets)
if phi % (2 * np.pi) == 0:
return

Expand Down Expand Up @@ -1383,6 +1374,11 @@ def _last(self, channel: str) -> _TimeSlot:
raise ValueError("The chosen channel has no target.")

def _validate_channel(self, channel: str) -> None:
if isinstance(channel, Parametrized):
raise NotImplementedError(
"Using parametrized objects or variables to refer to channels "
"is not supported."
)
if channel not in self._channels:
raise ValueError("Use the name of a declared channel.")

Expand Down
2 changes: 1 addition & 1 deletion pulser/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
# Sets the config as well as builds the hamiltonian
self.set_config(config) if config else self.set_config(SimConfig())
if hasattr(self._seq, "_measurement"):
self._meas_basis = cast(str, self._seq._measurement)
self._meas_basis = self._seq._measurement
else:
if self.basis_name in {"digital", "all"}:
self._meas_basis = "digital"
Expand Down
Loading

0 comments on commit c34dffb

Please sign in to comment.