diff --git a/.flake8 b/.flake8 index 32b17d17..5f8d6eef 100644 --- a/.flake8 +++ b/.flake8 @@ -15,4 +15,5 @@ per-file-ignores = tests/*: D100, D101, D102, D103 __init__.py: F401 pulser-core/pulser/backends.py: F401 + pulser-core/pulser/math/__init__.py: D103 setup.py: D100 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb0e77ad..47ced781 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,7 @@ jobs: fail-fast: false matrix: python-version: ["3.8", "3.12"] + with-torch: ["with-torch", "no-torch"] steps: - name: Check out Pulser uses: actions/checkout@v4 @@ -67,8 +68,13 @@ jobs: with: python-version: ${{ matrix.python-version }} extra-packages: pytest + with-torch: ${{ matrix.with-torch }} - name: Run the unit tests & generate coverage report + if: ${{ matrix.with-torch == 'with-torch' }} run: pytest --cov --cov-fail-under=100 + - name: Run the unit tests without torch installed + if: ${{ matrix.with-torch != 'with-torch' }} + run: pytest --cov - name: Test validation with legacy jsonschema run: | pip install jsonschema==4.17.3 diff --git a/.github/workflows/pulser-setup/action.yml b/.github/workflows/pulser-setup/action.yml index ba4677ba..94c171dc 100644 --- a/.github/workflows/pulser-setup/action.yml +++ b/.github/workflows/pulser-setup/action.yml @@ -9,6 +9,10 @@ inputs: description: Extra packages to install (give to grep) required: false default: "" + with-torch: + description: Whether to include pytorch + required: false + default: "with-torch" runs: using: "composite" steps: @@ -17,11 +21,18 @@ runs: with: python-version: ${{ inputs.python-version }} cache: "pip" - - name: Install Pulser + - name: Install Pulser (with torch) + if: ${{ inputs.with-torch == 'with-torch' }} shell: bash run: | python -m pip install --upgrade pip make dev-install + - name: Install Pulser (without torch) + if: ${{ inputs.with-torch != 'with-torch' }} + shell: bash + run: | + python -m pip install --upgrade pip + make dev-install-no-torch - name: Install extra packages from the dev requirements if: "${{ inputs.extra-packages != '' }}" shell: bash diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a293f022..cf79e920 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,6 +18,7 @@ jobs: # Python 3.8 and 3.9 does not run on macos-latest (14) # Uses macos-13 for 3.8 and 3.9 and macos-latest for >=3.10 os: [ubuntu-latest, macos-13, macos-latest, windows-latest] + with-torch: ["with-torch", "no-torch"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] exclude: - os: macos-latest @@ -38,5 +39,6 @@ jobs: with: python-version: ${{ matrix.python-version }} extra-packages: pytest + with-torch: ${{ matrix.with-torch }} - name: Run the unit tests & generate coverage report - run: pytest --cov --cov-fail-under=100 + run: pytest --cov diff --git a/Makefile b/Makefile index fa2aad32..74f2dde9 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,15 @@ .PHONY: dev-install dev-install: dev-install-core dev-install-simulation dev-install-pasqal +.PHONY: dev-install-no-torch +dev-install-no-torch: dev-install-core-no-torch dev-install-simulation dev-install-pasqal + .PHONY: dev-install-core dev-install-core: + pip install -e ./pulser-core[torch] + +.PHONY: dev-install-core-no-torch +dev-install-core-no-torch: pip install -e ./pulser-core .PHONY: dev-install-simulation diff --git a/README.md b/README.md index 65c67774..848ec72c 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,24 @@ If you wish to install only the core ``pulser`` features, you can instead run: pip install pulser-core ``` +### Including PyTorch + +To include PyTorch in your installation, append the ``[torch]`` suffix to the commands outlined above, i.e. + +```bash +pip install pulser[torch] +``` + +for the standard ``pulser`` distribution with PyTorch, **or** + +```bash +pip install pulser-core[torch] +``` + +for just the core features plus PyTorch support. + +### Development install + If you wish to **install the development version of Pulser from source** instead, do the following from within this repository after cloning it: ```bash diff --git a/pulser-core/pulser/channels/base_channel.py b/pulser-core/pulser/channels/base_channel.py index ea07f679..456c1157 100644 --- a/pulser-core/pulser/channels/base_channel.py +++ b/pulser-core/pulser/channels/base_channel.py @@ -23,8 +23,8 @@ import numpy as np from numpy.typing import ArrayLike -from scipy.fft import fft, fftfreq, ifft +import pulser.math as pm from pulser.channels.eom import MODBW_TO_TR, BaseEOM from pulser.json.utils import get_dataclass_defaults, obj_to_dict from pulser.pulse import Pulse @@ -420,22 +420,24 @@ def validate_pulse(self, pulse: Pulse) -> None: f"'pulse' must be of type Pulse, not of type {type(pulse)}." ) - if self.max_amp is not None and np.any( - pulse.amplitude.samples > self.max_amp - ): + amp_samples_np = pulse.amplitude.samples.as_array(detach=True) + if self.max_amp is not None and np.any(amp_samples_np > self.max_amp): raise ValueError( "The pulse's amplitude goes over the maximum " "value allowed for the chosen channel." ) if self.max_abs_detuning is not None and np.any( - np.round(np.abs(pulse.detuning.samples), decimals=6) + np.round( + np.abs(pulse.detuning.samples.as_array(detach=True)), + decimals=6, + ) > self.max_abs_detuning ): raise ValueError( "The pulse's detuning values go out of the range " "allowed for the chosen channel." ) - avg_amp = np.average(pulse.amplitude.samples) + avg_amp = np.average(amp_samples_np) if 0 < avg_amp < self.min_avg_amp: raise ValueError( "The pulse's average amplitude is below the chosen " @@ -453,10 +455,10 @@ def _modulation_padding(self) -> int: def modulate( self, - input_samples: np.ndarray, + input_samples: ArrayLike, keep_ends: bool = False, eom: bool = False, - ) -> np.ndarray: + ) -> pm.AbstractArray: """Modulates the input according to the channel's modulation bandwidth. Args: @@ -482,17 +484,17 @@ def modulate( " 'Channel.modulate()' returns the 'input_samples' unchanged.", stacklevel=2, ) - return input_samples + return pm.AbstractArray(input_samples) else: mod_bandwidth = self.mod_bandwidth mod_padding = self._modulation_padding if keep_ends: - samples = np.pad( + samples = pm.pad( input_samples, mod_padding + self.rise_time, mode="edge" ) else: - samples = np.pad(input_samples, mod_padding) + samples = pm.pad(input_samples, mod_padding) mod_samples = self.apply_modulation(samples, mod_bandwidth) if keep_ends: # Cut off the extra ends @@ -501,8 +503,8 @@ def modulate( @staticmethod def apply_modulation( - input_samples: np.ndarray, mod_bandwidth: float - ) -> np.ndarray: + input_samples: ArrayLike, mod_bandwidth: float + ) -> pm.AbstractArray: """Applies the modulation transfer fuction to the input samples. Note: @@ -516,10 +518,11 @@ def apply_modulation( """ # The cutoff frequency (fc) and the modulation transfer function # are defined in https://tinyurl.com/bdeumc8k + input_samples = pm.AbstractArray(input_samples) fc = mod_bandwidth * 1e-3 / np.sqrt(np.log(2)) - freqs = fftfreq(input_samples.size) - modulation = np.exp(-(freqs**2) / fc**2) - return cast(np.ndarray, ifft(fft(input_samples) * modulation).real) + freqs = pm.fftfreq(input_samples.size) + modulation = pm.exp(-(freqs**2) / fc**2) + return pm.ifft(pm.fft(input_samples) * modulation).real def calc_modulation_buffer( self, @@ -553,8 +556,11 @@ def calc_modulation_buffer( f"The channel {self} doesn't have a modulation bandwidth." ) tr = self.rise_time - samples = np.pad(input_samples, tr) - diffs = np.abs(samples - mod_samples) <= max_allowed_diff + samples = pm.pad(input_samples, tr) + diffs = ( + abs(samples - mod_samples).as_array(detach=True) + <= max_allowed_diff + ) try: # Finds the last index in the start buffer that's below the max # allowed diff. Considers that the waveform could start at the next diff --git a/pulser-core/pulser/channels/dmm.py b/pulser-core/pulser/channels/dmm.py index 2af8faa5..50720d78 100644 --- a/pulser-core/pulser/channels/dmm.py +++ b/pulser-core/pulser/channels/dmm.py @@ -19,6 +19,7 @@ import numpy as np +import pulser.math as pm from pulser.channels.base_channel import Channel from pulser.json.utils import get_dataclass_defaults from pulser.pulse import Pulse @@ -112,7 +113,9 @@ def validate_pulse( (defaults to a detuning map with weight 1.0). """ super().validate_pulse(pulse) - round_detuning = np.round(pulse.detuning.samples, decimals=6) + round_detuning = pm.round(pulse.detuning.samples, 6).as_array( + detach=True + ) # Check that detuning is negative if np.any(round_detuning > 0): raise ValueError("The detuning in a DMM must not be positive.") diff --git a/pulser-core/pulser/channels/eom.py b/pulser-core/pulser/channels/eom.py index 6abba783..0db609ff 100644 --- a/pulser-core/pulser/channels/eom.py +++ b/pulser-core/pulser/channels/eom.py @@ -21,6 +21,7 @@ import numpy as np +import pulser.math as pm from pulser.json.utils import get_dataclass_defaults, obj_to_dict # Conversion factor from modulation bandwith to rise time @@ -210,30 +211,30 @@ def _switching_beams_combos(self) -> list[tuple[RydbergBeam, ...]]: @overload def calculate_detuning_off( self, - amp_on: float, - detuning_on: float, + amp_on: float | pm.TensorLike, + detuning_on: float | pm.TensorLike, optimal_detuning_off: float, return_switching_beams: Literal[False], - ) -> float: + ) -> pm.AbstractArray: pass @overload def calculate_detuning_off( self, - amp_on: float, - detuning_on: float, + amp_on: float | pm.TensorLike, + detuning_on: float | pm.TensorLike, optimal_detuning_off: float, return_switching_beams: Literal[True], - ) -> tuple[float, tuple[RydbergBeam, ...]]: + ) -> tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]: pass def calculate_detuning_off( self, - amp_on: float, - detuning_on: float, + amp_on: float | pm.TensorLike, + detuning_on: float | pm.TensorLike, optimal_detuning_off: float, return_switching_beams: bool = False, - ) -> float | tuple[float, tuple[RydbergBeam, ...]]: + ) -> pm.AbstractArray | tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]: """Calculates the detuning when the amplitude is off in EOM mode. Args: @@ -246,17 +247,19 @@ def calculate_detuning_off( on and off. """ off_options = self.detuning_off_options(amp_on, detuning_on) - closest_option = np.abs(off_options - optimal_detuning_off).argmin() - best_det_off = cast(float, off_options[closest_option]) + closest_option = np.abs( + off_options.as_array(detach=True) - optimal_detuning_off + ).argmin() + best_det_off = off_options[closest_option] if not return_switching_beams: return best_det_off return best_det_off, self._switching_beams_combos[closest_option] def detuning_off_options( self, - rabi_frequency: float, - detuning_on: float, - ) -> np.ndarray: + rabi_frequency: float | pm.TensorLike, + detuning_on: float | pm.TensorLike, + ) -> pm.AbstractArray: """Calculates the possible detuning values when the amplitude is off. Args: @@ -267,11 +270,14 @@ def detuning_off_options( Returns: The possible detuning values when in between pulses. """ + rabi_frequency = pm.AbstractArray(rabi_frequency) # detuning = offset + lightshift # offset takes into account the lightshift when both beams are on # which is not zero when the Rabi freq of both beams is not equal - offset = detuning_on - self._lightshift(rabi_frequency, *RydbergBeam) + offset = pm.AbstractArray(detuning_on) - self._lightshift( + rabi_frequency, *RydbergBeam + ) all_beams: set[RydbergBeam] = set(RydbergBeam) lightshifts = [] for beams_off in self._switching_beams_combos: @@ -280,11 +286,11 @@ def detuning_off_options( lightshifts.append(self._lightshift(rabi_frequency, *beams_on)) # We sum the offset to all lightshifts to get the effective detuning - return np.array(lightshifts) + offset + return pm.flatten(pm.vstack(lightshifts)) + offset def _lightshift( - self, rabi_frequency: float, *beams_on: RydbergBeam - ) -> float: + self, rabi_frequency: pm.AbstractArray, *beams_on: RydbergBeam + ) -> pm.AbstractArray: # lightshift = (rabi_blue**2 - rabi_red**2) / 4 * int_detuning rabi_freqs = self._rabi_freq_per_beam(rabi_frequency) bias = { @@ -292,13 +298,14 @@ def _lightshift( RydbergBeam.BLUE: self.blue_shift_coeff, } # beam off -> beam_rabi_freq = 0 - return sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on) / ( - 4 * self.intermediate_detuning + return pm.AbstractArray( + sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on) + / (4 * self.intermediate_detuning) ) def _rabi_freq_per_beam( - self, rabi_frequency: float - ) -> dict[RydbergBeam, float]: + self, rabi_frequency: pm.AbstractArray + ) -> dict[RydbergBeam, pm.AbstractArray]: shift_factor = np.sqrt( self.red_shift_coeff / self.blue_shift_coeff if self.limiting_beam == RydbergBeam.RED @@ -315,14 +322,14 @@ def _rabi_freq_per_beam( if rabi_frequency <= limit_rabi_freq: base_amp_squared = 2 * rabi_frequency * self.intermediate_detuning return { - self.limiting_beam: np.sqrt(base_amp_squared / shift_factor), - ~self.limiting_beam: np.sqrt(base_amp_squared * shift_factor), + self.limiting_beam: pm.sqrt(base_amp_squared / shift_factor), + ~self.limiting_beam: pm.sqrt(base_amp_squared * shift_factor), } # The limiting beam is at its maximum amplitude while the other # has the necessary amplitude to reach the desired effective rabi freq return { - self.limiting_beam: self.max_limiting_amp, + self.limiting_beam: pm.AbstractArray(self.max_limiting_amp), ~self.limiting_beam: 2 * self.intermediate_detuning * rabi_frequency diff --git a/pulser-core/pulser/devices/_device_datacls.py b/pulser-core/pulser/devices/_device_datacls.py index 472f373f..c4e26051 100644 --- a/pulser-core/pulser/devices/_device_datacls.py +++ b/pulser-core/pulser/devices/_device_datacls.py @@ -17,12 +17,14 @@ import json from abc import ABC, abstractmethod from collections import Counter +from collections.abc import Mapping from dataclasses import dataclass, field, fields from typing import Any, Literal, cast, get_args import numpy as np -from scipy.spatial.distance import pdist, squareform +from scipy.spatial.distance import squareform +import pulser.math as pm from pulser.channels.base_channel import Channel, States, get_states_from_bases from pulser.channels.dmm import DMM from pulser.devices.interaction_coefficients import c6_dict @@ -386,7 +388,7 @@ def validate_layout_filling( f"{max_qubits} qubits." ) - def _validate_atom_number(self, coords: list[np.ndarray]) -> None: + def _validate_atom_number(self, coords: list[pm.AbstractArray]) -> None: max_atom_num = cast(int, self.max_atom_num) if len(coords) > max_atom_num: raise ValueError( @@ -397,7 +399,7 @@ def _validate_atom_number(self, coords: list[np.ndarray]) -> None: ) def _validate_atom_distance( - self, ids: list[QubitId], coords: list[np.ndarray], kind: str + self, ids: list[QubitId], coords: list[pm.AbstractArray], kind: str ) -> None: def invalid_dists(dists: np.ndarray) -> np.ndarray: cond1 = dists - self.min_atom_distance < -( @@ -409,9 +411,11 @@ def invalid_dists(dists: np.ndarray) -> np.ndarray: return cast(np.ndarray, np.logical_or(cond1, cond2)) if len(coords) > 1: - distances = pdist(coords) # Pairwise distance between atoms - if np.any(invalid_dists(distances)): - sq_dists = squareform(distances) + distances = pm.pdist( + pm.vstack(coords) + ) # Pairwise distance between atoms + if np.any(invalid_dists(distances.as_array(detach=True))): + sq_dists = squareform(distances.as_array(detach=True)) mask = np.triu(np.ones(len(coords), dtype=bool), k=1) bad_pairs = np.argwhere( np.logical_and(invalid_dists(sq_dists), mask) @@ -425,9 +429,12 @@ def invalid_dists(dists: np.ndarray) -> np.ndarray: ) def _validate_radial_distance( - self, ids: list[QubitId], coords: list[np.ndarray], kind: str + self, ids: list[QubitId], coords: list[pm.AbstractArray], kind: str ) -> None: - too_far = np.linalg.norm(coords, axis=1) > self.max_radial_distance + too_far = ( + np.linalg.norm(pm.vstack(coords).as_array(detach=True), axis=1) + > self.max_radial_distance + ) if np.any(too_far): raise ValueError( f"All {kind} must be at most {self.max_radial_distance} μm " @@ -452,10 +459,14 @@ def _params(self, init_only: bool = False) -> dict[str, Any]: } def _validate_coords( - self, coords_dict: dict[QubitId, np.ndarray], kind: str = "atoms" + self, + coords_dict: ( + Mapping[QubitId, pm.AbstractArray] | Mapping[int, np.ndarray] + ), + kind: Literal["atoms", "traps"] = "atoms", ) -> None: ids = list(coords_dict.keys()) - coords = list(coords_dict.values()) + coords = list(map(pm.AbstractArray, coords_dict.values())) if kind == "atoms" and not ( "max_atom_num" in self._optional_parameters and self.max_atom_num is None diff --git a/pulser-core/pulser/json/supported.py b/pulser-core/pulser/json/supported.py index 597fbcb2..5a0c04a9 100644 --- a/pulser-core/pulser/json/supported.py +++ b/pulser-core/pulser/json/supported.py @@ -62,6 +62,8 @@ "_operator": SUPPORTED_OPERATORS, "operator": SUPPORTED_OPERATORS, "numpy": SUPPORTED_NUMPY, + "pulser.math": SUPPORTED_NUMPY, # Numpy funcs replicated in pulser.math + "pulser.math.abstract_array": ("AbstractArray",), "pulser.register.register": ("Register",), "pulser.register.register3d": ("Register3D",), "pulser.register.register_layout": ("RegisterLayout",), diff --git a/pulser-core/pulser/math/__init__.py b/pulser-core/pulser/math/__init__.py new file mode 100644 index 00000000..d33d4aa3 --- /dev/null +++ b/pulser-core/pulser/math/__init__.py @@ -0,0 +1,242 @@ +# Copyright 2024 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom implementation of math and array functions.""" +from __future__ import annotations + +from collections.abc import Sequence +from typing import cast, Protocol, TypeVar + +import numpy as np +import scipy.fft + +from pulser.math.abstract_array import ( + AbstractArray as AbstractArray, + AbstractArrayLike, +) + +try: + import torch +except ImportError: # pragma: no cover + pass + + +T = TypeVar("T", covariant=True) + + +class TensorLike(Protocol[T]): + """A type hint to signal that a parameter behaves like a torch Tensor.""" + + def detach(self: T) -> T: ... # noqa: D102 + + def __array__(self) -> np.ndarray: ... + + +# Custom function definitions + + +def exp(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.exp(a.as_tensor())) + return AbstractArray(np.exp(a.as_array())) + + +def sqrt(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.sqrt(a.as_tensor())) + return AbstractArray(np.sqrt(a.as_array())) + + +def log2(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.log2(a.as_tensor())) + return AbstractArray(np.log2(a.as_array())) + + +def log(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.log(a.as_tensor())) + return AbstractArray(np.log(a.as_array())) + + +def sin(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.sin(a.as_tensor())) + return AbstractArray(np.sin(a.as_array())) + + +def cos(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.cos(a.as_tensor())) + return AbstractArray(np.cos(a.as_array())) + + +def tan(a: AbstractArrayLike, /) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.tan(a.as_tensor())) + return AbstractArray(np.tan(a.as_array())) + + +def pad( + a: AbstractArrayLike, + pad_width: tuple | int, + mode: str = "constant", + constant_values: tuple | int | float = 0, +) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + t = cast(torch.Tensor, a._array) + if isinstance(pad_width, (int, float)): + pad_width = (pad_width, pad_width) + if mode == "constant": + if isinstance(constant_values, (int, float)): + out = torch.nn.functional.pad( + t, pad_width, "constant", constant_values + ) + else: + out = torch.nn.functional.pad( + t, (pad_width[0], 0), "constant", constant_values[0] + ) + out = torch.nn.functional.pad( + out, (0, pad_width[1]), "constant", constant_values[1] + ) + elif mode == "edge": + out = torch.nn.functional.pad( + t, (pad_width[0], 0), "constant", float(t[0]) + ) + out = torch.nn.functional.pad( + out, (0, pad_width[1]), "constant", float(t[-1]) + ) + return AbstractArray(out) + + arr = cast(np.ndarray, a._array) + kwargs = ( + dict(constant_values=constant_values) if mode == "constant" else {} + ) + return AbstractArray( + np.pad(arr, pad_width, mode, **kwargs), # type: ignore[call-overload] + ) + + +def fft(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.fft.fft(a.as_tensor())) + return AbstractArray(scipy.fft.fft(a.as_array())) + + +def ifft(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.fft.ifft(a.as_tensor())) + return AbstractArray(scipy.fft.ifft(a.as_array())) + + +def fftfreq(n: int) -> AbstractArray: + return AbstractArray(scipy.fft.fftfreq(n)) + + +def round(a: AbstractArrayLike, decimals: int = 0) -> AbstractArray: + return AbstractArray(a).__round__(decimals) + + +def ceil(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.ceil(a.as_tensor())) + return AbstractArray(np.ceil(a.as_array())) + + +def floor(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.floor(a.as_tensor())) + return AbstractArray(np.floor(a.as_array())) + + +def mean(a: AbstractArrayLike, axis: int | None = None) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.mean(a.as_tensor(), dim=axis)) + return AbstractArray(np.mean(a.as_array(), axis=axis)) + + +def sum(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.sum(a.as_tensor())) + return AbstractArray(np.sum(a.as_array())) + + +def cumsum(a: AbstractArrayLike, axis: int = 0) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.cumsum(a.as_tensor(), dim=axis)) + return AbstractArray(np.cumsum(a.as_array(), axis=axis)) + + +def diff(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.diff(a.as_tensor())) + return AbstractArray(np.diff(a.as_array())) + + +def dot(a: AbstractArrayLike, b: AbstractArrayLike) -> AbstractArray: + a, b = map(AbstractArray, (a, b)) + if a.is_tensor or b.is_tensor: + return AbstractArray(torch.dot(a.as_tensor(), b.as_tensor())) + return AbstractArray(np.dot(a.as_array(), b.as_array())) + + +def pdist(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.nn.functional.pdist(a.as_tensor())) + return AbstractArray(scipy.spatial.distance.pdist(a.as_array())) + + +def concatenate(arrs: Sequence[AbstractArrayLike]) -> AbstractArray: + abst_arrs = tuple(map(AbstractArray, arrs)) + if any(a.is_tensor for a in abst_arrs): + return AbstractArray(torch.cat([a.as_tensor() for a in abst_arrs])) + return AbstractArray(np.concatenate([a.as_array() for a in abst_arrs])) + + +def vstack(arrs: Sequence[AbstractArrayLike]) -> AbstractArray: + abst_arrs = tuple(map(AbstractArray, arrs)) + if any(a.is_tensor for a in abst_arrs): + return AbstractArray(torch.vstack([a.as_tensor() for a in abst_arrs])) + return AbstractArray(np.vstack([a.as_array() for a in abst_arrs])) + + +def hstack(arrs: Sequence[AbstractArrayLike]) -> AbstractArray: + abst_arrs = tuple(map(AbstractArray, arrs)) + if any(a.is_tensor for a in abst_arrs): + return AbstractArray(torch.hstack([a.as_tensor() for a in abst_arrs])) + return AbstractArray(np.hstack([a.as_array() for a in abst_arrs])) + + +def flatten(a: AbstractArrayLike) -> AbstractArray: + a = AbstractArray(a) + if a.is_tensor: + return AbstractArray(torch.flatten(a.as_tensor())) + return AbstractArray(a.as_array().flatten()) diff --git a/pulser-core/pulser/math/abstract_array.py b/pulser-core/pulser/math/abstract_array.py new file mode 100644 index 00000000..c74805a6 --- /dev/null +++ b/pulser-core/pulser/math/abstract_array.py @@ -0,0 +1,312 @@ +# Copyright 2024 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the AbstractArray class.""" +from __future__ import annotations + +import functools +import importlib.util +import operator +from typing import Any, Generator, Union, cast + +import numpy as np +from numpy.typing import ArrayLike, DTypeLike + +from pulser.json.utils import obj_to_dict + +try: + import torch +except ImportError: # pragma: no cover + pass + + +class AbstractArray: + """An abstract array containing an array or tensor. + + Args: + array: The array to store. + dtype: The data type of the array. + force_array: Forces the array to be at least 1D. + """ + + def __init__( + self, + array: AbstractArrayLike, + dtype: DTypeLike = None, + force_array: bool = False, + ): + """Initializes a new AbstractArray.""" + self._array: np.ndarray | torch.Tensor + if isinstance(array, AbstractArray): + self._array = array._array + elif self.has_torch() and isinstance(array, torch.Tensor): + self._array = torch.as_tensor( + array, + dtype=dtype, # type: ignore[arg-type] + ) + else: + self._array = np.asarray(array, dtype=dtype) + + if force_array and self._array.ndim == 0: + self._array = self._array[None] + + @staticmethod + @functools.lru_cache + def has_torch() -> bool: + """Checks whether torch is installed.""" + return importlib.util.find_spec("torch") is not None + + @functools.cached_property + def is_tensor(self) -> bool: + """Whether the stored array is a tensor.""" + return self.has_torch() and isinstance(self._array, torch.Tensor) + + def astype(self, dtype: DTypeLike) -> AbstractArray: + """Casts the data type of the array contents.""" + if self.is_tensor: + return AbstractArray( + cast(torch.Tensor, self._array).to( + dtype=dtype # type: ignore[arg-type] + ) + ) + return AbstractArray(cast(np.ndarray, self._array).astype(dtype)) + + def as_tensor(self) -> torch.Tensor: + """Converts the stored array to a torch Tensor.""" + if not self.has_torch(): + raise RuntimeError("`torch` is not installed.") + return torch.as_tensor(self._array) + + def as_array(self, *, detach: bool = False) -> np.ndarray: + """Converts the stored array to a Numpy array. + + Args: + detach: Whether to detach before converting. + """ + if detach and self.is_tensor: + return cast(torch.Tensor, self._array).detach().numpy() + return np.asarray(self._array) + + def tolist(self) -> list: + """Converts the stored array to a Python list.""" + return self._array.tolist() + + def copy(self) -> AbstractArray: + """Makes a copy itself.""" + return AbstractArray( + cast(torch.Tensor, self._array).clone() + if self.is_tensor + else cast(np.ndarray, self._array).copy() + ) + + @property + def size(self) -> int: + """The number of elements in the array.""" + return int(np.prod(self._array.shape)) + + @property + def ndim(self) -> int: + """The number of dimensions in the array.""" + return self._array.ndim + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the array.""" + return self._array.shape + + @property + def real(self) -> AbstractArray: + """The real part of each element in the array.""" + return AbstractArray(self._array.real) + + @property + def dtype(self) -> Any: + """The data type of the array elements.""" + return self._array.dtype + + def detach(self) -> AbstractArray: + """Detaches the data from the computational graph. + + Analogous to torch.Tensor.detach(). + """ + if self.is_tensor: + return AbstractArray(cast(torch.Tensor, self._array).detach()) + return self + + def __array__(self, dtype: Any = None) -> np.ndarray: + return self._array.__array__(dtype) + + def __repr__(self) -> str: + return str(self._array.__repr__()) + + def __int__(self) -> int: + return int(self._array) + + def __float__(self) -> float: + return float(self._array) + + def __bool__(self) -> bool: + return bool(self._array) + + # Unary operators + def __neg__(self) -> AbstractArray: + return AbstractArray(-self._array) + + def __abs__(self) -> AbstractArray: + return AbstractArray(cast(ArrayLike, abs(self._array))) + + def __round__(self, decimals: int = 0, /) -> AbstractArray: + return AbstractArray( + torch.round(cast(torch.Tensor, self._array), decimals=decimals) + if self.is_tensor + else np.round(cast(np.ndarray, self._array), decimals=decimals) + ) + + def _binary_operands( + self, other: AbstractArrayLike + ) -> tuple[np.ndarray, np.ndarray] | tuple[torch.Tensor, torch.Tensor]: + other = AbstractArray(other) + if self.is_tensor or other.is_tensor: + return self.as_tensor(), other.as_tensor() + return self.as_array(), other.as_array() + + # Comparison operators + + def __lt__(self, other: AbstractArrayLike) -> AbstractArray: + return AbstractArray(operator.lt(*self._binary_operands(other))) + + def __le__(self, other: AbstractArrayLike) -> AbstractArray: + return AbstractArray(operator.le(*self._binary_operands(other))) + + def __gt__(self, other: AbstractArrayLike) -> AbstractArray: + return AbstractArray(operator.gt(*self._binary_operands(other))) + + def __ge__(self, other: AbstractArrayLike) -> AbstractArray: + return AbstractArray(operator.ge(*self._binary_operands(other))) + + def __eq__(self, other: Any) -> AbstractArray: # type: ignore[override] + return AbstractArray(operator.eq(*self._binary_operands(other))) + + def __ne__(self, other: Any) -> AbstractArray: # type: ignore[override] + return AbstractArray(operator.ne(*self._binary_operands(other))) + + # Binary operators + def __add__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.add(*self._binary_operands(other))) + + def __radd__(self, other: ArrayLike, /) -> AbstractArray: + return self.__add__(other) + + def __mul__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.mul(*self._binary_operands(other))) + + def __rmul__(self, other: ArrayLike, /) -> AbstractArray: + return self.__mul__(other) + + def __sub__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.sub(*self._binary_operands(other))) + + def __rsub__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray(operator.sub(*self._binary_operands(other)[::-1])) + + def __truediv__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.truediv(*self._binary_operands(other))) + + def __rtruediv__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray( + operator.truediv(*self._binary_operands(other)[::-1]) + ) + + def __floordiv__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.floordiv(*self._binary_operands(other))) + + def __rfloordiv__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray( + operator.floordiv(*self._binary_operands(other)[::-1]) + ) + + def __pow__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.pow(*self._binary_operands(other))) + + def __rpow__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray(operator.pow(*self._binary_operands(other)[::-1])) + + def __mod__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.mod(*self._binary_operands(other))) + + def __rmod__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray(operator.mod(*self._binary_operands(other)[::-1])) + + def __matmul__(self, other: AbstractArrayLike, /) -> AbstractArray: + return AbstractArray(operator.matmul(*self._binary_operands(other))) + + def __rmatmul__(self, other: ArrayLike, /) -> AbstractArray: + return AbstractArray( + operator.matmul(*self._binary_operands(other)[::-1]) + ) + + def _process_indices(self, indices: Any) -> Any: + try: + return indices.tolist() + except Exception: + return indices + + def __getitem__(self, indices: Any) -> AbstractArray: + return AbstractArray(self._array[self._process_indices(indices)]) + + def __setitem__(self, indices: Any, values: AbstractArrayLike) -> None: + array, values = self._binary_operands(values) + try: + array[ + self._process_indices(indices) + ] = values # type: ignore[assignment] + except RuntimeError as e: + if ( + self.is_tensor + and cast(torch.Tensor, self._array).requires_grad + ): + raise RuntimeError( + "Failed to modify a tensor that requires grad in place." + ) from e + else: # pragma: no cover + raise e + self._array = array + del self.is_tensor # Clears cache + + def __iter__(self) -> Generator[AbstractArray, None, None]: + for i in range(self.__len__()): + yield self.__getitem__(i) + + def __len__(self) -> int: + return len(self._array) + + def _to_dict(self) -> dict[str, Any]: + try: + return obj_to_dict(self, self.as_array()) + except RuntimeError as e: + raise NotImplementedError( + "A tensor that requires grad can't be serialized without" + " losing the computational graph information." + ) from e + + def _to_abstract_repr(self) -> Any: + try: + return self.as_array().tolist() + except RuntimeError as e: + raise NotImplementedError( + "A tensor that requires grad can't be serialized without" + " losing the computational graph information." + ) from e + + +AbstractArrayLike = Union[AbstractArray, ArrayLike] diff --git a/pulser-core/pulser/parametrized/paramobj.py b/pulser-core/pulser/parametrized/paramobj.py index 0815fd00..a3b70387 100644 --- a/pulser-core/pulser/parametrized/paramobj.py +++ b/pulser-core/pulser/parametrized/paramobj.py @@ -24,6 +24,7 @@ import numpy as np +import pulser.math as pm import pulser.parametrized from pulser.json.abstract_repr.serializer import abstract_repr from pulser.json.abstract_repr.signatures import ( @@ -50,10 +51,10 @@ def __abs__(self) -> ParamObj: return ParamObj(operator.abs, self) def __ceil__(self) -> ParamObj: - return ParamObj(np.ceil, self) + return ParamObj(pm.ceil, self) def __floor__(self) -> ParamObj: - return ParamObj(np.floor, self) + return ParamObj(pm.floor, self) def __round__(self, n: int = 0) -> ParamObj: return cast(ParamObj, (self * 10**n).rint() / 10**n) @@ -61,35 +62,35 @@ def __round__(self, n: int = 0) -> ParamObj: def rint(self) -> ParamObj: """Rounds the value to the nearest int.""" # Defined because np.round looks for 'rint' - return ParamObj(np.round, self) + return ParamObj(pm.round, self) def sqrt(self) -> ParamObj: """Calculates the square root of the object.""" - return ParamObj(np.sqrt, self) + return ParamObj(pm.sqrt, self) def exp(self) -> ParamObj: """Calculates the exponential of the object.""" - return ParamObj(np.exp, self) + return ParamObj(pm.exp, self) def log2(self) -> ParamObj: """Calculates the base-2 logarithm of the object.""" - return ParamObj(np.log2, self) + return ParamObj(pm.log2, self) def log(self) -> ParamObj: """Calculates the natural logarithm of the object.""" - return ParamObj(np.log, self) + return ParamObj(pm.log, self) def sin(self) -> ParamObj: """Calculates the trigonometric sine of the object.""" - return ParamObj(np.sin, self) + return ParamObj(pm.sin, self) def cos(self) -> ParamObj: """Calculates the trigonometric cosine of the object.""" - return ParamObj(np.cos, self) + return ParamObj(pm.cos, self) def tan(self) -> ParamObj: """Calculates the trigonometric tangent of the object.""" - return ParamObj(np.tan, self) + return ParamObj(pm.tan, self) # Binary operators def __add__(self, other: Union[int, float], /) -> ParamObj: @@ -210,8 +211,10 @@ def class_to_dict(cls: Callable) -> dict[str, Any]: "Serialization of calls to parametrized objects is not " "supported." ) - elif hasattr(args[0], self.cls.__name__) and inspect.isfunction( - self.cls + elif ( + hasattr(args[0], self.cls.__name__) + and inspect.isfunction(self.cls) + and self.cls.__module__ != "pulser.math" ): # Check for parametrized methods if inspect.isclass(self.args[0]): @@ -245,6 +248,7 @@ def _to_abstract_repr(self) -> dict[str, Any]: self.args # If it is a classmethod the first arg will be the class and hasattr(self.args[0], op_name) and inspect.isfunction(self.cls) + and not self.cls.__module__ == "pulser.math" ): # Check for parametrized methods if inspect.isclass(self.args[0]): @@ -279,7 +283,6 @@ def _to_abstract_repr(self) -> dict[str, Any]: return abstract_repr("Pulse", **all_args) else: return abstract_repr(name, **all_args) - raise NotImplementedError( "Instance or static method serialization is not supported." ) diff --git a/pulser-core/pulser/parametrized/variable.py b/pulser-core/pulser/parametrized/variable.py index 63b08b66..cddf316a 100644 --- a/pulser-core/pulser/parametrized/variable.py +++ b/pulser-core/pulser/parametrized/variable.py @@ -17,11 +17,12 @@ import collections.abc as abc # To use collections.abc.Sequence import dataclasses -from typing import Any, Iterator, Optional, Union, cast +from typing import Any, Iterator, Union import numpy as np from numpy.typing import ArrayLike +import pulser.math as pm from pulser.json.utils import obj_to_dict from pulser.parametrized import Parametrized from pulser.parametrized.paramobj import OpSupport @@ -72,8 +73,8 @@ def _assign(self, value: Union[ArrayLike, float, int]) -> None: def _validate_value( self, value: Union[ArrayLike, float, int] - ) -> np.ndarray: - val = np.array(value, dtype=self.dtype, ndmin=1) + ) -> pm.AbstractArray: + val = pm.AbstractArray(value, dtype=self.dtype, force_array=True) if val.size != self.size: raise ValueError( f"Can't assign array of size {val.size} to " @@ -81,9 +82,9 @@ def _validate_value( ) return val - def build(self) -> ArrayLike: + def build(self) -> pm.AbstractArray: """Returns the variable's current value.""" - self.value: Optional[ArrayLike] + self.value: pm.AbstractArray | None if self.value is None: raise ValueError(f"No value assigned to variable '{self.name}'.") return self.value @@ -147,12 +148,9 @@ def variables(self) -> dict[str, Variable]: """All the variables involved with this object.""" return self.var.variables - def build(self) -> Union[ArrayLike, float, int]: + def build(self) -> pm.AbstractArray: """Return the variable's item(s) values.""" - built_var = cast(abc.Sequence, self.var.build()) - if isinstance(self.key, abc.Sequence): - return [built_var[k] for k in self.key] - return built_var[self.key] + return self.var.build()[self.key] def _to_dict(self) -> dict[str, Any]: return obj_to_dict( diff --git a/pulser-core/pulser/pulse.py b/pulser-core/pulser/pulse.py index 7a94c481..8bf05b95 100644 --- a/pulser-core/pulser/pulse.py +++ b/pulser-core/pulser/pulse.py @@ -18,12 +18,13 @@ import functools import itertools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast import matplotlib.pyplot as plt import numpy as np import pulser +import pulser.math as pm from pulser.json.abstract_repr.serializer import abstract_repr from pulser.json.utils import obj_to_dict from pulser.parametrized import Parametrized, ParamObj @@ -75,7 +76,7 @@ class Pulse: amplitude: Waveform = field(init=False) detuning: Waveform = field(init=False) - phase: float = field(init=False) + phase: pm.AbstractArray = field(init=False) post_phase_shift: float = field(default=0.0, init=False) def __new__(cls, *args, **kwargs): # type: ignore @@ -88,10 +89,10 @@ def __new__(cls, *args, **kwargs): # type: ignore def __init__( self, - amplitude: Union[Waveform, Parametrized], - detuning: Union[Waveform, Parametrized], - phase: Union[float, Parametrized], - post_phase_shift: Union[float, Parametrized] = 0.0, + amplitude: Waveform | Parametrized, + detuning: Waveform | Parametrized, + phase: float | pm.TensorLike | Parametrized, + post_phase_shift: float | Parametrized = 0.0, ): """Initializes a new Pulse.""" if not ( @@ -103,15 +104,17 @@ def __init__( raise ValueError( "The duration of detuning and amplitude waveforms must match." ) - if np.any(amplitude.samples < 0): + if np.any(amplitude.samples.as_array(detach=True) < 0): raise ValueError( "All samples of an amplitude waveform must be " "greater than or equal to zero." ) object.__setattr__(self, "amplitude", amplitude) object.__setattr__(self, "detuning", detuning) - phase = cast(float, phase) - object.__setattr__(self, "phase", float(phase) % (2 * np.pi)) + assert not isinstance(phase, Parametrized) + if (phase_ := pm.AbstractArray(phase, dtype=float)).size != 1: + raise TypeError(f"'phase' must be a single float, not {phase!r}.") + object.__setattr__(self, "phase", phase_ % (2 * np.pi)) post_phase_shift = cast(float, post_phase_shift) object.__setattr__( self, "post_phase_shift", float(post_phase_shift) % (2 * np.pi) @@ -126,10 +129,10 @@ def duration(self) -> int: @parametrize def ConstantDetuning( cls, - amplitude: Union[Waveform, Parametrized], - detuning: Union[float, Parametrized], - phase: Union[float, Parametrized], - post_phase_shift: Union[float, Parametrized] = 0.0, + amplitude: Waveform | Parametrized, + detuning: float | pm.TensorLike | Parametrized, + phase: float | pm.TensorLike | Parametrized, + post_phase_shift: float | Parametrized = 0.0, ) -> Pulse: """Creates a Pulse with an amplitude waveform and a constant detuning. @@ -149,10 +152,10 @@ def ConstantDetuning( @parametrize def ConstantAmplitude( cls, - amplitude: Union[float, Parametrized], - detuning: Union[Waveform, Parametrized], - phase: Union[float, Parametrized], - post_phase_shift: Union[float, Parametrized] = 0.0, + amplitude: float | pm.TensorLike | Parametrized, + detuning: Waveform | Parametrized, + phase: float | pm.TensorLike | Parametrized, + post_phase_shift: float | Parametrized = 0.0, ) -> Pulse: """Pulse with a constant amplitude and a detuning waveform. @@ -171,11 +174,11 @@ def ConstantAmplitude( @classmethod def ConstantPulse( cls, - duration: Union[int, Parametrized], - amplitude: Union[float, Parametrized], - detuning: Union[float, Parametrized], - phase: Union[float, Parametrized], - post_phase_shift: Union[float, Parametrized] = 0.0, + duration: int | Parametrized, + amplitude: float | pm.TensorLike | Parametrized, + detuning: float | pm.TensorLike | Parametrized, + phase: float | pm.TensorLike | Parametrized, + post_phase_shift: float | Parametrized = 0.0, ) -> Pulse: """Pulse with a constant amplitude and a constant detuning. @@ -236,15 +239,15 @@ def ArbitraryPhase( if isinstance(phase, ConstantWaveform): detuning = ConstantWaveform(phase.duration, 0.0) elif isinstance(phase, RampWaveform): - detuning = ConstantWaveform(phase.duration, -phase.slope * 1e3) + detuning = ConstantWaveform(phase.duration, -phase._slope * 1e3) else: - detuning_samples = -np.diff(phase.samples) * 1e3 # rad/ns->rad/µs + detuning_samples = -pm.diff(phase.samples) * 1e3 # rad/ns->rad/µs # Use the same value in the first two detuning samples detuning = CustomWaveform( - np.pad(detuning_samples, (1, 0), mode="edge") + pm.pad(detuning_samples, (1, 0), mode="edge") ) # Adjust phase_c to incorporate the first detuning sample - phase_c = phase.first_value + detuning.first_value * 1e-3 + phase_c = phase[0] + detuning[0] * 1e-3 return cls(amplitude, detuning, phase_c, post_phase_shift) def draw(self) -> None: @@ -319,15 +322,15 @@ def __str__(self) -> str: return ( f"Pulse(Amp={self.amplitude!s} rad/µs, " f"Detuning={self.detuning!s} rad/µs, " - f"Phase={self.phase:.3g})" + f"Phase={float(self.phase):.3g})" ) def __repr__(self) -> str: return ( f"Pulse(amp={self.amplitude!r} rad/µs, " f"detuning={self.detuning!r} rad/µs, " - f"phase={self.phase:.3g}, " - f"post_phase_shift={self.post_phase_shift:.3g})" + f"phase={float(self.phase):.3g}, " + f"post_phase_shift={float(self.post_phase_shift):.3g})" ) def __eq__(self, other: Any) -> bool: @@ -346,7 +349,7 @@ def check_phase_eq(phase1: float, phase2: float) -> np.bool_: return bool( self.amplitude == other.amplitude and self.detuning == other.detuning - and check_phase_eq(self.phase, other.phase) + and check_phase_eq(float(self.phase), float(other.phase)) and check_phase_eq(self.post_phase_shift, other.post_phase_shift) ) diff --git a/pulser-core/pulser/register/_coordinates.py b/pulser-core/pulser/register/_coordinates.py index 575e65cd..404375a3 100644 --- a/pulser-core/pulser/register/_coordinates.py +++ b/pulser-core/pulser/register/_coordinates.py @@ -3,12 +3,15 @@ from __future__ import annotations import hashlib +from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property from typing import cast import numpy as np +import pulser.math as pm + COORD_PRECISION = 6 @@ -24,7 +27,7 @@ class CoordsCollection: _coords: The coordinates. """ - _coords: np.ndarray | list + _coords: pm.AbstractArray | list @property def dimensionality(self) -> int: @@ -35,22 +38,27 @@ def dimensionality(self) -> int: def sorted_coords(self) -> np.ndarray: """The sorted coordinates.""" # Copies to prevent direct access to self._sorted_coords - return self._sorted_coords.copy() + return self._sorted_coords.as_array(detach=True).copy() + + @cached_property + def _coords_arr(self) -> pm.AbstractArray: + return pm.vstack(cast(Sequence, self._coords)) + + @cached_property + def _rounded_coords(self) -> pm.AbstractArray: + return pm.round(self._coords_arr, decimals=COORD_PRECISION) @cached_property # Acts as an attribute in a frozen dataclass - def _sorted_coords(self) -> np.ndarray: - coords = np.array(self._coords, dtype=float) - rounded_coords = np.round(coords, decimals=COORD_PRECISION) + def _sorted_coords(self) -> pm.AbstractArray: sorting = self._calc_sorting_order() - return cast(np.ndarray, rounded_coords[sorting]) + return self._rounded_coords[sorting] def _calc_sorting_order(self) -> np.ndarray: """Calculates the unique order that sorts the coordinates.""" - coords = np.array(self._coords, dtype=float) # Sorting the coordinates 1st left to right, 2nd bottom to top - rounded_coords = np.round(coords, decimals=COORD_PRECISION) - dims = rounded_coords.shape[1] - sorter = [rounded_coords[:, i] for i in range(dims - 1, -1, -1)] + dims = self._rounded_coords.shape[1] + arr = self._rounded_coords.as_array(detach=True) + sorter = [arr[:, i] for i in range(dims - 1, -1, -1)] sorting = np.lexsort(tuple(sorter)) return cast(np.ndarray, sorting) diff --git a/pulser-core/pulser/register/_reg_drawer.py b/pulser-core/pulser/register/_reg_drawer.py index 298e9886..f0ed2701 100644 --- a/pulser-core/pulser/register/_reg_drawer.py +++ b/pulser-core/pulser/register/_reg_drawer.py @@ -353,7 +353,7 @@ def _register_dims( draw_half_radius: bool = False, ) -> np.ndarray: """Returns the dimensions of the register to be drawn.""" - diffs = np.ptp(pos, axis=0) + diffs = np.ptp(pos, axis=0).astype(float) diffs[diffs < 9] *= 1.5 diffs[diffs < 9] += 2 if blockade_radius and draw_half_radius: diff --git a/pulser-core/pulser/register/base_register.py b/pulser-core/pulser/register/base_register.py index eb03c597..d01253db 100644 --- a/pulser-core/pulser/register/base_register.py +++ b/pulser-core/pulser/register/base_register.py @@ -33,6 +33,7 @@ import numpy as np from numpy.typing import ArrayLike +import pulser.math as pm from pulser.json.abstract_repr.serializer import AbstractReprEncoder from pulser.json.abstract_repr.validation import validate_abstract_repr from pulser.json.utils import obj_to_dict @@ -57,7 +58,11 @@ class BaseRegister(ABC, CoordsCollection): """The abstract class for a register.""" @abstractmethod - def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): + def __init__( + self, + qubits: Mapping[str, ArrayLike] | Mapping[int, ArrayLike], + **kwargs: Any, + ): """Initializes a custom Register.""" if not isinstance(qubits, dict): raise TypeError( @@ -68,7 +73,9 @@ def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): raise ValueError( "Cannot create a Register with an empty qubit " "dictionary." ) - super().__init__([np.array(v, dtype=float) for v in qubits.values()]) + super().__init__( + [pm.AbstractArray(v, dtype=float) for v in qubits.values()] + ) self._ids: tuple[QubitId, ...] = tuple(qubits.keys()) self._layout_info: Optional[_LayoutInfo] = None self._init_kwargs(**kwargs) @@ -86,9 +93,9 @@ def _init_kwargs(self, **kwargs: Any) -> None: self._layout_info = _LayoutInfo(layout, trap_ids) @property - def qubits(self) -> dict[QubitId, np.ndarray]: + def qubits(self) -> dict[QubitId, pm.AbstractArray]: """Dictionary of the qubit names and their position coordinates.""" - return dict(zip(self._ids, self._coords)) + return dict(zip(self._ids, self._coords_arr)) @property def qubit_ids(self) -> tuple[QubitId, ...]: @@ -136,7 +143,7 @@ def find_indices(self, id_list: abcSequence[QubitId]) -> list[int]: @classmethod def from_coordinates( cls: Type[T], - coords: np.ndarray, + coords: ArrayLike | pm.TensorLike, center: bool = True, prefix: Optional[str] = None, labels: Optional[abcSequence[QubitId]] = None, @@ -160,11 +167,13 @@ def from_coordinates( Returns: A register with qubits placed on the given coordinates. """ + coords_ = pm.vstack(cast(abcSequence, coords)) if center: - coords = coords - np.mean(coords, axis=0) # Centers the array + coords_ = coords_ - pm.mean(coords_, axis=0) # Centers the array + qubits: dict[str, pm.AbstractArray] if prefix is not None: pre = str(prefix) - qubits = {pre + str(i): pos for i, pos in enumerate(coords)} + qubits = {pre + str(i): pos for i, pos in enumerate(coords_)} if labels is not None: raise NotImplementedError( "It is impossible to specify a prefix and " @@ -172,14 +181,14 @@ def from_coordinates( ) elif labels is not None: - if len(coords) != len(labels): + if len(coords_) != len(labels): raise ValueError( f"Label length ({len(labels)}) does not" - f"match number of coordinates ({len(coords)})" + f"match number of coordinates ({len(coords_)})" ) - qubits = dict(zip(cast(Iterable, labels), coords)) + qubits = dict(zip(cast(Iterable, labels), coords_)) else: - qubits = dict(cast(Iterable, enumerate(coords))) + qubits = dict(cast(Iterable, enumerate(coords_))) return cls(qubits, **kwargs) def _validate_layout( @@ -201,7 +210,9 @@ def _validate_layout( " in the register." ) - for reg_coord, trap_id in zip(self._coords, trap_ids): + for reg_coord, trap_id in zip( + self._coords_arr.as_array(detach=True), trap_ids + ): if np.any(reg_coord != trap_coords[trap_id]): raise ValueError( "The chosen traps from the RegisterLayout don't match this" @@ -230,7 +241,9 @@ def define_detuning_map( " in the register." ) return DetuningMap( - [self.qubits[qubit_id] for qubit_id in detuning_weights], + pm.vstack( + [self.qubits[qubit_id] for qubit_id in detuning_weights] + ), list(detuning_weights.values()), slug, ) @@ -258,7 +271,7 @@ def _to_dict(self) -> dict[str, Any]: return obj_to_dict( self, cls_dict, - [np.ndarray.tolist(qubit_coords) for qubit_coords in self._coords], + [qubit_coords.tolist() for qubit_coords in self._coords_arr], False, None, self._ids, @@ -271,16 +284,14 @@ def __eq__(self, other: Any) -> bool: if type(other) is not type(self): return False - return list(self._ids) == list(other._ids) and all( - ( - np.allclose( # Accounts for rounding errors - self._coords[i], - other._coords[other._ids.index(id)], - ) - for i, id in enumerate(self._ids) - ) + return self._ids == other._ids and np.allclose( + self._coords_arr.as_array(detach=True), + other._coords_arr.as_array(detach=True), ) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.qubits})" + def coords_hex_hash(self) -> str: """Returns the idempotent hash of the coordinates. diff --git a/pulser-core/pulser/register/register.py b/pulser-core/pulser/register/register.py index db6abd4c..69f4002f 100644 --- a/pulser-core/pulser/register/register.py +++ b/pulser-core/pulser/register/register.py @@ -25,6 +25,7 @@ from numpy.typing import ArrayLike import pulser +import pulser.math as pm import pulser.register._patterns as patterns from pulser.json.abstract_repr.deserializer import ( deserialize_abstract_register, @@ -43,11 +44,16 @@ class Register(BaseRegister, RegDrawer): (e.g. {'q0':(2, -1, 0), 'q1':(-5, 10, 0), ...}). """ - def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): + def __init__( + self, + qubits: Mapping[Any, ArrayLike | pm.TensorLike], + **kwargs: Any, + ): """Initializes a custom Register.""" super().__init__(qubits, **kwargs) - if any(c.shape != (self.dimensionality,) for c in self._coords) or ( - self.dimensionality != 2 + if ( + any(c.shape != (self.dimensionality,) for c in self._coords_arr) + or self.dimensionality != 2 ): raise ValueError( "All coordinates must be specified as vectors of size 2." @@ -55,7 +61,10 @@ def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): @classmethod def square( - cls, side: int, spacing: float = 4.0, prefix: Optional[str] = None + cls, + side: int, + spacing: float | pm.TensorLike = 4.0, + prefix: Optional[str] = None, ) -> Register: """Initializes the register with the qubits in a square array. @@ -83,7 +92,7 @@ def rectangle( cls, rows: int, columns: int, - spacing: float = 4.0, + spacing: float | pm.TensorLike = 4.0, prefix: Optional[str] = None, ) -> Register: """Creates a rectangular array of qubits on a square lattice. @@ -106,8 +115,8 @@ def rectangular_lattice( cls, rows: int, columns: int, - row_spacing: float = 4.0, - col_spacing: float = 2.0, + row_spacing: float | pm.TensorLike = 4.0, + col_spacing: float | pm.TensorLike = 2.0, prefix: Optional[str] = None, ) -> Register: """Creates a rectangular array of qubits on a rectangular lattice. @@ -139,13 +148,16 @@ def rectangular_lattice( " must be greater than or equal to 1." ) + row_spacing_ = pm.AbstractArray(row_spacing) + col_spacing_ = pm.AbstractArray(col_spacing) + # Check spacing - if row_spacing <= 0.0 or col_spacing <= 0.0: + if row_spacing_ <= 0.0 or col_spacing_ <= 0.0: raise ValueError("Spacing between atoms must be greater than 0.") - coords = patterns.square_rect(rows, columns) - coords[:, 0] = coords[:, 0] * col_spacing - coords[:, 1] = coords[:, 1] * row_spacing + coords = pm.AbstractArray(patterns.square_rect(rows, columns)) + coords[:, 0] = coords[:, 0] * col_spacing_ + coords[:, 1] = coords[:, 1] * row_spacing_ return cls.from_coordinates(coords, center=True, prefix=prefix) @@ -154,7 +166,7 @@ def triangular_lattice( cls, rows: int, atoms_per_row: int, - spacing: float = 4.0, + spacing: float | pm.TensorLike = 4.0, prefix: Optional[str] = None, ) -> Register: """Initializes the register with the qubits in a triangular lattice. @@ -189,20 +201,26 @@ def triangular_lattice( " must be greater than or equal to 1." ) + spacing_ = pm.AbstractArray(spacing) # Check spacing - if spacing <= 0.0: + if spacing_ <= 0.0: raise ValueError( f"Spacing between atoms (`spacing` = {spacing})" " must be greater than 0." ) - coords = patterns.triangular_rect(rows, atoms_per_row) * spacing - + coords = ( + pm.AbstractArray(patterns.triangular_rect(rows, atoms_per_row)) + * spacing_ + ) return cls.from_coordinates(coords, center=True, prefix=prefix) @classmethod def hexagon( - cls, layers: int, spacing: float = 4.0, prefix: Optional[str] = None + cls, + layers: int, + spacing: float | pm.TensorLike = 4.0, + prefix: Optional[str] = None, ) -> Register: """Initializes the register with the qubits in a hexagonal layout. @@ -223,15 +241,16 @@ def hexagon( " must be greater than or equal to 1." ) + spacing_ = pm.AbstractArray(spacing) # Check spacing - if spacing <= 0.0: + if spacing_ <= 0.0: raise ValueError( f"Spacing between atoms (`spacing` = {spacing})" " must be greater than 0." ) n_atoms = 1 + 3 * (layers**2 + layers) - coords = patterns.triangular_hex(n_atoms) * spacing + coords = pm.AbstractArray(patterns.triangular_hex(n_atoms)) * spacing_ return cls.from_coordinates(coords, center=False, prefix=prefix) @@ -240,7 +259,7 @@ def max_connectivity( cls, n_qubits: int, device: pulser.devices._device_datacls.BaseDevice, - spacing: float | None = None, + spacing: float | pm.TensorLike | None = None, prefix: str | None = None, ) -> Register: """Initializes the register with maximum connectivity for a device. @@ -284,22 +303,24 @@ def max_connectivity( # Default spacing or check minimal distance if spacing is None: - spacing = device.min_atom_distance - elif spacing < device.min_atom_distance: + spacing_ = pm.AbstractArray(device.min_atom_distance) + elif ( + spacing_ := pm.AbstractArray(spacing) + ) < device.min_atom_distance: raise ValueError( f"Spacing between atoms (`spacing = `{spacing})" " must be greater than or equal to the minimal" " distance supported by this device" f" ({device.min_atom_distance})." ) - if spacing <= 0.0: + if spacing_ <= 0.0: # spacing is None or 0.0, device.min_atom_distance is 0.0 raise NotImplementedError( "Maximum connectivity layouts are not well defined for a " "device with 'min_atom_distance=0.0'." ) - coords = patterns.triangular_hex(n_qubits) * spacing + coords = pm.AbstractArray(patterns.triangular_hex(n_qubits)) * spacing_ return cls.from_coordinates(coords, center=False, prefix=prefix) @@ -316,7 +337,7 @@ def rotated(self, degrees: float) -> Register: angle. """ theta = np.deg2rad(degrees) - rot = np.array( + rot = pm.vstack( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) if self.layout is not None: @@ -327,7 +348,7 @@ def rotated(self, degrees: float) -> Register: ) return Register( - dict(zip(self.qubit_ids, [rot @ v for v in self._coords])) + dict(zip(self.qubit_ids, [rot @ v for v in self._coords_arr])) ) def draw( @@ -385,7 +406,7 @@ def draw( draw_half_radius=draw_half_radius, ) - pos = np.array(self._coords) + pos = self._coords_arr.as_array(detach=True) if custom_ax is None: _, custom_ax = self._initialize_fig_axes( pos, @@ -416,7 +437,7 @@ def _to_abstract_repr(self) -> list[dict[str, Union[QubitId, float]]]: names = stringify_qubit_ids(self._ids) return [ {"name": name, "x": x, "y": y} - for name, (x, y) in zip(names, self._coords) + for name, (x, y) in zip(names, self._coords_arr.tolist()) ] @staticmethod diff --git a/pulser-core/pulser/register/register3d.py b/pulser-core/pulser/register/register3d.py index 831c64b7..1cf24621 100644 --- a/pulser-core/pulser/register/register3d.py +++ b/pulser-core/pulser/register/register3d.py @@ -22,6 +22,7 @@ import numpy as np from numpy.typing import ArrayLike +import pulser.math as pm from pulser.json.abstract_repr.deserializer import ( deserialize_abstract_register, ) @@ -40,11 +41,16 @@ class Register3D(BaseRegister, RegDrawer): (e.g. {'q0':(2, -1, 0), 'q1':(-5, 10, 0), ...}). """ - def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): + def __init__( + self, + qubits: Mapping[Any, ArrayLike | pm.TensorLike], + **kwargs: Any, + ): """Initializes a custom Register.""" super().__init__(qubits, **kwargs) - if any(c.shape != (self.dimensionality,) for c in self._coords) or ( - self.dimensionality != 3 + if ( + any(c.shape != (self.dimensionality,) for c in self._coords_arr) + or self.dimensionality != 3 ): raise ValueError( "All coordinates must be specified as vectors of size 3." @@ -52,7 +58,10 @@ def __init__(self, qubits: Mapping[Any, ArrayLike], **kwargs: Any): @classmethod def cubic( - cls, side: int, spacing: float = 4.0, prefix: Optional[str] = None + cls, + side: int, + spacing: float | pm.TensorLike = 4.0, + prefix: Optional[str] = None, ) -> Register3D: """Initializes the register with the qubits in a cubic array. @@ -81,7 +90,7 @@ def cuboid( rows: int, columns: int, layers: int, - spacing: float = 4.0, + spacing: float | pm.TensorLike = 4.0, prefix: Optional[str] = None, ) -> Register3D: """Initializes the register with the qubits in a cuboid array. @@ -120,14 +129,15 @@ def cuboid( ) # Check spacing - if spacing <= 0.0: + spacing_ = pm.AbstractArray(spacing) + if spacing_ <= 0.0: raise ValueError( f"Spacing between atoms (`spacing` = {spacing})" " must be greater than 0." ) coords = ( - np.array( + pm.AbstractArray( [ (x, y, z) for z in range(layers) @@ -136,7 +146,7 @@ def cuboid( ], dtype=float, ) - * spacing + * spacing_ ) return cls.from_coordinates(coords, center=True, prefix=prefix) @@ -155,11 +165,10 @@ def to_2D(self, tol_width: float = 0.0) -> Register: Raises: ValueError: If the atoms are not coplanar. """ - coords = np.array(self._coords) - + coords = self._coords_arr.as_array(detach=True) barycenter = coords.sum(axis=0) / coords.shape[0] # run SVD - u, s, vh = np.linalg.svd(coords - barycenter) + _, _, vh = np.linalg.svd(coords - barycenter) e_z = vh[2, :] perp_extent = [e_z.dot(r) for r in coords] width = np.ptp(perp_extent) @@ -171,8 +180,11 @@ def to_2D(self, tol_width: float = 0.0) -> Register: else: e_x = vh[0, :] e_y = vh[1, :] - coords_2D = np.array( - [np.array([e_x.dot(r), e_y.dot(r)]) for r in coords] + coords_2D = pm.vstack( + [ + pm.hstack([pm.dot(e_x, r), pm.dot(e_y, r)]) + for r in self._coords_arr + ] ) return Register.from_coordinates(coords_2D, labels=self._ids) @@ -225,7 +237,7 @@ def draw( draw_half_radius=draw_half_radius, ) - pos = np.array(self._coords) + pos = self._coords_arr.as_array(detach=True) self._draw_3D( pos, diff --git a/pulser-core/pulser/register/register_layout.py b/pulser-core/pulser/register/register_layout.py index af4e5c6a..8cb2e720 100644 --- a/pulser-core/pulser/register/register_layout.py +++ b/pulser-core/pulser/register/register_layout.py @@ -247,7 +247,7 @@ def _to_dict(self) -> dict[str, Any]: # Allows for serialization of subclasses without a special _to_dict() return obj_to_dict( self, - self._coords, + self._coords_arr.tolist(), slug=self.slug, _module=__name__, _name="RegisterLayout", diff --git a/pulser-core/pulser/register/traps.py b/pulser-core/pulser/register/traps.py index c3c9b6fb..98028527 100644 --- a/pulser-core/pulser/register/traps.py +++ b/pulser-core/pulser/register/traps.py @@ -23,6 +23,7 @@ import numpy as np from numpy.typing import ArrayLike +import pulser.math as pm from pulser.register._coordinates import COORD_PRECISION, CoordsCollection @@ -41,13 +42,15 @@ class Traps(ABC, CoordsCollection): slug: str | None def __init__(self, trap_coordinates: ArrayLike, slug: str | None = None): - """Initializes a RegisterLayout.""" + """Initializes a set of traps.""" array_type_error_msg = ValueError( "'trap_coordinates' must be an array or list of coordinates." ) try: - coords_arr = np.array(trap_coordinates, dtype=float) + coords_arr = pm.AbstractArray( + trap_coordinates, dtype=float + ).as_array(detach=True) except ValueError as e: raise array_type_error_msg from e @@ -60,7 +63,7 @@ def __init__(self, trap_coordinates: ArrayLike, slug: str | None = None): f"Each coordinate must be of size 2 or 3, not {shape[1]}." ) - if len(np.unique(trap_coordinates, axis=0)) != shape[0]: + if len(np.unique(coords_arr, axis=0)) != shape[0]: raise ValueError( "All trap coordinates of a register layout must be unique." ) @@ -68,7 +71,7 @@ def __init__(self, trap_coordinates: ArrayLike, slug: str | None = None): object.__setattr__(self, "slug", slug) @property - def traps_dict(self) -> dict: + def traps_dict(self) -> dict[int, np.ndarray]: """Mapping between trap IDs and coordinates.""" return dict(enumerate(self.sorted_coords)) diff --git a/pulser-core/pulser/register/weight_maps.py b/pulser-core/pulser/register/weight_maps.py index a2d0e446..d740b53f 100644 --- a/pulser-core/pulser/register/weight_maps.py +++ b/pulser-core/pulser/register/weight_maps.py @@ -32,6 +32,8 @@ if TYPE_CHECKING: from pulser.register.base_register import QubitId +import pulser.math as pm + @dataclass(init=False, repr=False, eq=False, frozen=True) class WeightMap(Traps, RegDrawer): @@ -63,7 +65,7 @@ def __init__( @property def trap_coordinates(self) -> np.ndarray: """The array of trap coordinates, in the order they were given.""" - return np.array(self._coords) + return self._coords_arr.as_array(detach=True) @property def sorted_weights(self) -> np.ndarray: @@ -72,7 +74,7 @@ def sorted_weights(self) -> np.ndarray: return cast(np.ndarray, np.array(self.weights)[sorting]) def get_qubit_weight_map( - self, qubits: Mapping[QubitId, np.ndarray] + self, qubits: Mapping[QubitId, ArrayLike] ) -> dict[QubitId, float]: """Creates a map between qubit IDs and the weight on their sites.""" qubit_weight_map = {} @@ -81,7 +83,11 @@ def get_qubit_weight_map( for qid, pos in qubits.items(): matches = np.argwhere( np.all( - np.isclose(coords_arr, pos, atol=10 ** (-COORD_PRECISION)), + np.isclose( + coords_arr, + pm.AbstractArray(pos).as_array(detach=True), + atol=10 ** (-COORD_PRECISION), + ), axis=1, ) ) @@ -159,7 +165,8 @@ def _to_abstract_repr(self) -> dict[str, Any]: traps=[ {"weight": weight, "x": x, "y": y} for weight, (x, y) in zip( - self.sorted_weights, self.sorted_coords + self.sorted_weights, + self.sorted_coords, ) ] ) diff --git a/pulser-core/pulser/sampler/samples.py b/pulser-core/pulser/sampler/samples.py index ad2b1647..9b90be66 100644 --- a/pulser-core/pulser/sampler/samples.py +++ b/pulser-core/pulser/sampler/samples.py @@ -5,10 +5,11 @@ import itertools from collections import defaultdict from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Literal, Optional, cast, get_args import numpy as np +import pulser.math as pm from pulser.channels.base_channel import ( EIGENSTATES, Channel, @@ -39,9 +40,9 @@ def _prepare_dict(N: int, in_xy: bool = False) -> dict: def new_qty_dict() -> dict: return { - _AMP: np.zeros(N), - _DET: np.zeros(N), - _PHASE: np.zeros(N), + _AMP: pm.AbstractArray(np.zeros(N)), + _DET: pm.AbstractArray(np.zeros(N)), + _PHASE: pm.AbstractArray(np.zeros(N)), } def new_qdict() -> dict: @@ -95,15 +96,15 @@ class _SlmMask: class ChannelSamples: """Gathers samples of a channel.""" - amp: np.ndarray - det: np.ndarray - phase: np.ndarray + amp: pm.AbstractArray + det: pm.AbstractArray + phase: pm.AbstractArray slots: list[_PulseTargetSlot] = field(default_factory=list) eom_blocks: list[_EOMSettings] = field(default_factory=list) eom_start_buffers: list[tuple[int, int]] = field(default_factory=list) eom_end_buffers: list[tuple[int, int]] = field(default_factory=list) target_time_slots: list[_TimeSlot] = field(default_factory=list) - _centered_phase: np.ndarray | None = None + _centered_phase: pm.AbstractArray | None = None def __post_init__(self) -> None: assert ( @@ -129,7 +130,7 @@ def initial_targets(self) -> set[QubitId]: ) @property - def centered_phase(self) -> np.ndarray: + def centered_phase(self) -> pm.AbstractArray: """The phase samples centered in ]-π, π].""" if self._centered_phase is not None: return self._centered_phase @@ -138,7 +139,7 @@ def centered_phase(self) -> np.ndarray: return phase_ @property - def phase_modulation(self) -> np.ndarray: + def phase_modulation(self) -> pm.AbstractArray: r"""The phase modulation samples (in rad). Constructed by combining the integral of the detuning samples with the @@ -146,9 +147,7 @@ def phase_modulation(self) -> np.ndarray: .. math:: \phi(t) = \phi_c(t) - \sum_{k=0}^{t} \delta(k) """ - return cast( - np.ndarray, self.centered_phase - np.cumsum(self.det * 1e-3) - ) + return self.centered_phase - pm.cumsum(self.det * 1e-3) def extend_duration(self, new_duration: int) -> ChannelSamples: """Extends the duration of the samples. @@ -167,26 +166,26 @@ def extend_duration(self, new_duration: int) -> ChannelSamples: if extension < 0: raise ValueError("Can't extend samples to a lower duration.") - new_amp = np.pad(self.amp, (0, extension)) + new_amp = pm.pad(self.amp, (0, extension)) # When in EOM mode, we need to keep the detuning at detuning_off if self.eom_blocks and self.eom_blocks[-1].tf is None: - final_detuning = self.eom_blocks[-1].detuning_off + final_detuning = float(self.eom_blocks[-1].detuning_off) else: final_detuning = 0.0 - new_detuning = np.pad( + new_detuning = pm.pad( self.det, (0, extension), - constant_values=(final_detuning,), mode="constant", + constant_values=final_detuning, ) - new_phase = np.pad( + new_phase = pm.pad( self.phase, (0, extension), mode="edge" if self.phase.size > 0 else "constant", ) _new_centered_phase = None if self._centered_phase is not None: - _new_centered_phase = np.pad( + _new_centered_phase = pm.pad( self._centered_phase, (0, extension), mode="edge" if self._centered_phase.size > 0 else "constant", @@ -206,7 +205,11 @@ def is_empty(self) -> bool: The channel is considered empty if all amplitude and detuning samples are zero. """ - return np.count_nonzero(self.amp) + np.count_nonzero(self.det) == 0 + return ( + np.count_nonzero(self.amp.as_array(detach=True)) + + np.count_nonzero(self.det.as_array(detach=True)) + == 0 + ) def _generate_std_samples(self) -> ChannelSamples: new_samples = { @@ -258,10 +261,10 @@ def modulate( """ def masked( - samples: np.ndarray, + samples: pm.AbstractArray, mask: np.ndarray, keep_end_values: bool = False, - ) -> np.ndarray: + ) -> pm.AbstractArray: new_samples = samples.copy() # Extend the mask to fit the size of the samples mask = np.pad(mask, (0, len(new_samples) - len(mask)), mode="edge") @@ -294,9 +297,9 @@ def masked( new_samples[~mask] = 0 return new_samples - new_samples: dict[str, np.ndarray] = {} + new_samples: dict[str, pm.AbstractArray] = {} - eom_samples = { + eom_samples: dict[str, pm.AbstractArray] = { key: getattr(self, key).copy() for key in ("amp", "det") } @@ -356,7 +359,7 @@ def masked( ) else: std_mask = ~eom_mask - modulated_buffer = np.zeros_like(modulated_std) + modulated_buffer = pm.AbstractArray(modulated_std) * 0.0 std = masked(modulated_std, std_mask) buffers = masked( @@ -384,10 +387,13 @@ def masked( # such that the modulation starts off from that value # We then remove the extra value after modulation if eom_mask[0]: - samples_ = np.insert( + samples_ = pm.pad( samples_, - 0, - self.eom_blocks[0].detuning_off, + (1, 0), + "constant", + constant_values=float( + self.eom_blocks[0].detuning_off + ), ) # Finally, the modified EOM samples are modulated modulated_eom = channel_obj.modulate( @@ -408,7 +414,7 @@ def masked( # Extend shortest arrays to match the longest before summing new_samples[key] = sample_arrs[-1] for arr in sample_arrs[:-1]: - arr = np.pad( + arr = pm.pad( arr, (0, sample_arrs[-1].size - arr.size), ) @@ -423,7 +429,9 @@ def masked( self.centered_phase, keep_ends=True ) for key in new_samples: - new_samples[key] = new_samples[key][slice(0, max_duration)] + new_samples[key] = new_samples[key].astype(float)[ + slice(0, max_duration) + ] return replace(self, **new_samples) @@ -435,7 +443,10 @@ class DMMSamples(ChannelSamples): # Although these shouldn't have a default, in this way we can # subclass ChannelSamples detuning_map: DetuningMap | None = None - qubits: dict[QubitId, np.ndarray] = field(default_factory=dict) + qubits: dict[QubitId, pm.AbstractArray] = field(default_factory=dict) + + +_SamplesType = Literal["abstract", "array", "tensor"] @dataclass @@ -500,7 +511,11 @@ def extend_duration(self, new_duration: int) -> SequenceSamples: ], ) - def to_nested_dict(self, all_local: bool = False) -> dict: + def to_nested_dict( + self, + all_local: bool = False, + samples_type: _SamplesType = "array", + ) -> dict: """Format in the nested dictionary form. This is the format expected by `pulser_simulation.Simulation()`. @@ -508,12 +523,21 @@ def to_nested_dict(self, all_local: bool = False) -> dict: Args: all_local: Forces all samples to be distributed by their individual targets, even when applied by a global channel. + samples_type: The array type to return the samples in. Can be + "array" (the default), "tensor" or "abstract". Returns: A nested dictionary splitting the samples according to their addressing ('Global' or 'Local'), the targeted basis and, in the 'Local' case, the targeted qubit. """ + _samples_type_options = get_args(_SamplesType) + if samples_type not in _samples_type_options: + raise ValueError( + f"'samples_type' must be one of {_samples_type_options!r}, " + f"not {samples_type!r}." + ) + d = _prepare_dict(self.max_duration, in_xy=self._in_xy) for chname, samples in zip(self.channels, self.samples_list): cs = ( @@ -563,7 +587,25 @@ def to_nested_dict(self, all_local: bool = False) -> dict: ) d[_LOCAL][basis][t][_PHASE][times] += cs.phase[times] - return _default_to_regular(d) + regular_dict = _default_to_regular(d) + + def cast_arrays(arr_dict: dict) -> dict: + for k in arr_dict: + if isinstance(arr_dict[k], dict): + arr_dict[k] = cast_arrays(arr_dict[k]) + continue + assert isinstance(arr := arr_dict[k], pm.AbstractArray) + arr_dict[k] = ( + arr.as_tensor() + if samples_type == "tensor" + else arr.as_array(detach=True) + ) + return arr_dict + + if samples_type != "abstract": + regular_dict = cast_arrays(regular_dict) + + return regular_dict def __repr__(self) -> str: blocks = [ diff --git a/pulser-core/pulser/sequence/_schedule.py b/pulser-core/pulser/sequence/_schedule.py index 3384c63f..744040ed 100644 --- a/pulser-core/pulser/sequence/_schedule.py +++ b/pulser-core/pulser/sequence/_schedule.py @@ -21,6 +21,7 @@ import numpy as np +import pulser.math as pm from pulser.channels.base_channel import Channel from pulser.channels.dmm import DMM from pulser.channels.eom import RydbergBeam @@ -42,9 +43,9 @@ class _TimeSlot(NamedTuple): @dataclass class _EOMSettings: - rabi_freq: float - detuning_on: float - detuning_off: float + rabi_freq: pm.AbstractArray + detuning_on: pm.AbstractArray + detuning_off: pm.AbstractArray ti: int tf: int | None = None switching_beams: tuple[RydbergBeam, ...] = () @@ -52,10 +53,10 @@ class _EOMSettings: @dataclass class _PhaseDriftParams: - drift_rate: float # rad/µs + drift_rate: pm.AbstractArray # rad/µs ti: int # ns - def calc_phase_drift(self, tf: int) -> float: + def calc_phase_drift(self, tf: int) -> pm.AbstractArray: """Calculate the phase drift during the elapsed time.""" return self.drift_rate * (tf - self.ti) * 1e-3 @@ -97,7 +98,7 @@ def in_eom_mode(self, time_slot: Optional[_TimeSlot] = None) -> bool: @staticmethod def is_detuned_delay(pulse: Pulse) -> bool: """Tells if a pulse is actually a delay with a constant detuning.""" - return ( + return bool( isinstance(pulse, Pulse) and isinstance(pulse.amplitude, ConstantWaveform) and pulse.amplitude[0] == 0.0 @@ -150,7 +151,11 @@ def get_samples( # Keep only pulse slots channel_slots = [s for s in self.slots if isinstance(s.type, Pulse)] dt = self.get_duration() - amp, det, phase = np.zeros(dt), np.zeros(dt), np.zeros(dt) + amp, det, phase = ( + pm.AbstractArray(np.zeros(dt)), + pm.AbstractArray(np.zeros(dt)), + pm.AbstractArray(np.zeros(dt)), + ) slots: list[_PulseTargetSlot] = [] target_time_slots: list[_TimeSlot] = [ s for s in self.slots if s.type == "target" @@ -272,7 +277,7 @@ def __post_init__(self) -> None: def get_samples( self, ignore_detuned_delay_phase: bool = True, - qubits: dict[QubitId, np.ndarray] | None = None, + qubits: dict[QubitId, pm.AbstractArray] | None = None, ) -> DMMSamples: ch_samples = super().get_samples( ignore_detuned_delay_phase=ignore_detuned_delay_phase @@ -336,9 +341,9 @@ def find_slm_mask_times(self) -> list[int]: def enable_eom( self, channel_id: str, - amp_on: float, - detuning_on: float, - detuning_off: float, + amp_on: pm.AbstractArray, + detuning_on: pm.AbstractArray, + detuning_off: pm.AbstractArray, switching_beams: tuple[RydbergBeam, ...] = (), _skip_buffer: bool = False, _skip_wait_for_fall: bool = False, @@ -399,8 +404,8 @@ def add_pulse( protocol: str, phase_drift_params: _PhaseDriftParams | None = None, ) -> None: - def corrected_phase(tf: int) -> float: - phase_drift = ( + def corrected_phase(tf: int) -> pm.AbstractArray: + phase_drift = pm.AbstractArray( phase_drift_params.calc_phase_drift(tf) if phase_drift_params else 0 @@ -544,12 +549,12 @@ def _find_add_delay(self, t0: int, channel: str, protocol: str) -> int: return current_max_t - def _get_last_pulse_phase(self, channel: str) -> float: + def _get_last_pulse_phase(self, channel: str) -> pm.AbstractArray: try: last_pulse = cast(Pulse, self[channel].last_pulse_slot().type) phase = last_pulse.phase except RuntimeError: - phase = 0.0 + phase = pm.AbstractArray(0.0) return phase def _check_duration(self, t: int) -> None: diff --git a/pulser-core/pulser/sequence/_seq_drawer.py b/pulser-core/pulser/sequence/_seq_drawer.py index e26c9d2c..42372f06 100644 --- a/pulser-core/pulser/sequence/_seq_drawer.py +++ b/pulser-core/pulser/sequence/_seq_drawer.py @@ -28,6 +28,7 @@ from scipy.interpolate import CubicSpline import pulser +import pulser.math as pm from pulser import Register, Register3D from pulser.channels.base_channel import Channel from pulser.channels.dmm import DMM @@ -118,6 +119,21 @@ class ChannelDrawContent: phase_modulated: bool = False def __post_init__(self) -> None: + # Make sure there are no tensors in the channel samples + self.samples.amp = pm.AbstractArray( + self.samples.amp.as_array(detach=True) + ) + self.samples.det = pm.AbstractArray( + self.samples.det.as_array(detach=True) + ) + self.samples.phase = pm.AbstractArray( + self.samples.phase.as_array(detach=True) + ) + if self.samples._centered_phase is not None: + self.samples._centered_phase = pm.AbstractArray( + self.samples._centered_phase.as_array(detach=True) + ) + is_dmm = isinstance(self.samples, DMMSamples) self.curves_on = { "amplitude": not is_dmm, @@ -171,7 +187,10 @@ def _give_curves_from_samples( ) -> list[np.ndarray]: curves = [] for qty in CURVES_ORDER: - qty_arr = getattr(samples, self._samples_from_curves[qty]) + qty_arr = cast( + pm.AbstractArray, + getattr(samples, self._samples_from_curves[qty]), + ).as_array(detach=True) if "phase" in qty: qty_arr = qty_arr / (2 * np.pi) curves.append(qty_arr) @@ -370,7 +389,7 @@ def _draw_register_det_maps( ) # Draw masked register if register: - pos = np.array(register._coords) + pos = register._coords_arr.as_array(detach=True) title = ( "Register" if sampled_seq._slm_mask.targets == set() @@ -430,7 +449,7 @@ def _draw_register_det_maps( else cast(DMMSamples, sampled_seq.channel_samples[ch]).qubits ) reg_det_map = det_map.get_qubit_weight_map(qubits) - pos = np.array(list(qubits.values())) + pos = np.array([c.as_array(detach=True) for c in qubits.values()]) if need_init: if det_map.dimensionality == 3: labels = "xyz" @@ -522,15 +541,15 @@ def _draw_channel_content( shown_duration: Total duration to be shown in the X axis. """ - def phase_str(phi: float) -> str: + def phase_str(phi: Any) -> str: """Formats a phase value for printing.""" - value = (((phi + np.pi) % (2 * np.pi)) - np.pi) / np.pi + value = (((float(phi) + np.pi) % (2 * np.pi)) - np.pi) / np.pi if value == -1: return r"$\pi$" elif value == 0: return "0" # pragma: no cover - just for safety else: - return rf"{value:.2g}$\pi$" + return rf"{float(value):.2g}$\pi$" data = gather_data(sampled_seq, shown_duration) n_channels = len(sampled_seq.channels) @@ -724,7 +743,7 @@ def phase_str(phi: float) -> str: area_fmt = ( r"A: $\pi$" if round(area_val, 2) == 1 - else rf"A: {area_val:.2g}$\pi$" + else rf"A: {float(area_val):.2g}$\pi$" ) if not print_phase: txt = area_fmt diff --git a/pulser-core/pulser/sequence/_seq_str.py b/pulser-core/pulser/sequence/_seq_str.py index 33ddee11..21f7695e 100644 --- a/pulser-core/pulser/sequence/_seq_str.py +++ b/pulser-core/pulser/sequence/_seq_str.py @@ -15,7 +15,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from pulser.channels import DMM from pulser.pulse import Pulse @@ -67,18 +67,18 @@ def seq_to_str(sequence: Sequence) -> str: f"{ts.type.detuning!s} rad/µs" if not seq.is_detuned_delay(ts.type) else "{:.3g} rad/µs".format( - cast(float, ts.type.detuning[0]) + float(ts.type.detuning[0]) ) ), tgt_txt, ) elif seq.is_detuned_delay(ts.type): det = ts.type.detuning[0] - full += det_delay_line.format(ts.ti, ts.tf, det) + full += det_delay_line.format(ts.ti, ts.tf, float(det)) else: full += pulse_line.format(ts.ti, ts.tf, ts.type, tgt_txt) elif ts.type == "target": - phase = sequence._basis_ref[basis][tgts[0]].phase[ts.tf] + phase = float(sequence._basis_ref[basis][tgts[0]].phase[ts.tf]) if first_slot: full += ( f"t: 0 | Initial targets: {tgt_txt} | " diff --git a/pulser-core/pulser/sequence/sequence.py b/pulser-core/pulser/sequence/sequence.py index 3869c123..5d3166d8 100644 --- a/pulser-core/pulser/sequence/sequence.py +++ b/pulser-core/pulser/sequence/sequence.py @@ -40,6 +40,7 @@ import pulser import pulser.devices as devices +import pulser.math as pm import pulser.sequence._decorators as seq_decorators from pulser.channels.base_channel import Channel, States, get_states_from_bases from pulser.channels.dmm import DMM, _dmm_id_from_name, _get_dmm_name @@ -214,7 +215,7 @@ def _in_ising(self, value: bool) -> None: self._set_slm_mask_dmm(self._slm_mask_dmm, self._slm_mask_targets) @property - def qubit_info(self) -> dict[QubitId, np.ndarray]: + def qubit_info(self) -> dict[QubitId, pm.AbstractArray]: """Dictionary with the qubit's IDs and positions.""" if self.is_register_mappable(): raise RuntimeError( @@ -490,7 +491,7 @@ def current_phase_ref( f"No declared channel targets the given 'basis' ('{basis}')." ) - return self._basis_ref[basis][qubit].phase.last_phase + return float(self._basis_ref[basis][qubit].phase.last_phase) def set_magnetic_field( self, bx: float = 0.0, by: float = 0.0, bz: float = 30.0 @@ -1096,8 +1097,8 @@ def declare_variable( def enable_eom_mode( self, channel: str, - amp_on: Union[float, Parametrized], - detuning_on: Union[float, Parametrized], + amp_on: Union[float, pm.TensorLike, Parametrized], + detuning_on: Union[float, pm.TensorLike, Parametrized], optimal_detuning_off: Union[float, Parametrized] = 0.0, correct_phase_drift: bool = False, ) -> None: @@ -1148,25 +1149,33 @@ def enable_eom_mode( channel_obj, amp_on, detuning_on, optimal_detuning_off ) if not self.is_parametrized(): - detuning_off = cast(float, detuning_off) + assert not isinstance(amp_on, Parametrized) + amp_on_ = pm.AbstractArray(amp_on) + assert not isinstance(detuning_on, Parametrized) + detuning_on_ = pm.AbstractArray(detuning_on) + assert not isinstance(detuning_off, Parametrized) + detuning_off_ = pm.AbstractArray(detuning_off) + phase_drift_params = _PhaseDriftParams( - drift_rate=-detuning_off, + drift_rate=-detuning_off_, # enable_eom() calls wait for fall, so the block only # starts after fall time ti=self.get_duration(channel, include_fall_time=True), ) self._schedule.enable_eom( channel, - cast(float, amp_on), - cast(float, detuning_on), - detuning_off, + amp_on_, + detuning_on_, + detuning_off_, switching_beams, ) if correct_phase_drift: buffer_slot = self._last(channel) drift = phase_drift_params.calc_phase_drift(buffer_slot.tf) self._phase_shift( - -drift, *buffer_slot.targets, basis=channel_obj.basis + -float(drift), + *buffer_slot.targets, + basis=channel_obj.basis, ) # Manually store the call to "enable_eom_mode" so that the updated @@ -1182,7 +1191,11 @@ def enable_eom_mode( channel=channel, amp_on=amp_on, detuning_on=detuning_on, - optimal_detuning_off=detuning_off, + optimal_detuning_off=( + detuning_off + if isinstance(detuning_off, Parametrized) + else float(detuning_off) + ), correct_phase_drift=correct_phase_drift, ), ) @@ -1229,7 +1242,7 @@ def disable_eom_mode( last_eom_block_tf = cast(int, ch_schedule.eom_blocks[-1].tf) drift_params = self._get_last_eom_pulse_phase_drift(channel) self._phase_shift( - -drift_params.calc_phase_drift(last_eom_block_tf), + -float(drift_params.calc_phase_drift(last_eom_block_tf)), *ch_schedule[-1].targets, basis=ch_schedule.channel_obj.basis, ) @@ -1239,8 +1252,8 @@ def disable_eom_mode( def modify_eom_setpoint( self, channel: str, - amp_on: Union[float, Parametrized], - detuning_on: Union[float, Parametrized], + amp_on: Union[float, pm.TensorLike, Parametrized], + detuning_on: Union[float, pm.TensorLike, Parametrized], optimal_detuning_off: Union[float, Parametrized] = 0.0, correct_phase_drift: bool = False, ) -> None: @@ -1273,20 +1286,26 @@ def modify_eom_setpoint( ) if not self.is_parametrized(): - detuning_off = cast(float, detuning_off) + assert not isinstance(amp_on, Parametrized) + amp_on_ = pm.AbstractArray(amp_on) + assert not isinstance(detuning_on, Parametrized) + detuning_on_ = pm.AbstractArray(detuning_on) + assert not isinstance(detuning_off, Parametrized) + detuning_off_ = pm.AbstractArray(detuning_off) + self._schedule.disable_eom(channel, _skip_buffer=True) old_phase_drift_params = self._get_last_eom_pulse_phase_drift( channel ) new_phase_drift_params = _PhaseDriftParams( - drift_rate=-detuning_off, + drift_rate=-detuning_off_, ti=self.get_duration(channel, include_fall_time=False), ) self._schedule.enable_eom( channel, - cast(float, amp_on), - cast(float, detuning_on), - detuning_off, + amp_on_, + detuning_on_, + detuning_off_, switching_beams, _skip_wait_for_fall=True, ) @@ -1296,7 +1315,9 @@ def modify_eom_setpoint( buffer_slot.ti ) + new_phase_drift_params.calc_phase_drift(buffer_slot.tf) self._phase_shift( - -drift, *buffer_slot.targets, basis=channel_obj.basis + -float(drift), + *buffer_slot.targets, + basis=channel_obj.basis, ) # Manually store the call to "modify_eom_setpoint" so that the updated @@ -1312,7 +1333,11 @@ def modify_eom_setpoint( channel=channel, amp_on=amp_on, detuning_on=detuning_on, - optimal_detuning_off=detuning_off, + optimal_detuning_off=( + detuning_off + if isinstance(detuning_off, Parametrized) + else float(detuning_off) + ), correct_phase_drift=correct_phase_drift, ), ) @@ -1325,7 +1350,7 @@ def add_eom_pulse( self, channel: str, duration: Union[int, Parametrized], - phase: Union[float, Parametrized], + phase: Union[float, pm.TensorLike, Parametrized], post_phase_shift: Union[float, Parametrized] = 0.0, protocol: PROTOCOLS = "min-delay", correct_phase_drift: bool = False, @@ -1375,7 +1400,13 @@ def add_eom_pulse( channel_obj = self.declared_channels[channel] channel_obj.validate_duration(duration) for arg in (phase, post_phase_shift): - if not isinstance(arg, (Parametrized, float, int)): + if isinstance(arg, Parametrized): + continue + try: + if isinstance(arg, str): + raise TypeError + float(pm.AbstractArray(arg, dtype=float)) + except TypeError: raise TypeError("Phase values must be a numeric value.") return @@ -1585,7 +1616,7 @@ def measure(self, basis: str = "ground-rydberg") -> None: @seq_decorators.store def phase_shift( self, - phi: Union[float, Parametrized], + phi: float | Parametrized, *targets: QubitId, basis: str = "digital", ) -> None: @@ -1607,8 +1638,8 @@ def phase_shift( @seq_decorators.store def phase_shift_index( self, - phi: Union[float, Parametrized], - *targets: Union[int, Parametrized], + phi: float | Parametrized, + *targets: int | Parametrized, basis: str = "digital", ) -> None: r"""Shifts the phase of a qubit's reference by 'phi', on a given basis. @@ -1682,7 +1713,7 @@ def build( self, *, qubits: Optional[Mapping[QubitId, int]] = None, - **vars: Union[ArrayLike, float, int], + **vars: Union[ArrayLike, pm.TensorLike, float, int], ) -> Sequence: """Builds a sequence from the programmed instructions. @@ -1731,9 +1762,12 @@ def build( # Eliminates the source of recursiveness errors seq._reset_parametrized() - # Deepcopy the base sequence (what remains) - seq = copy.deepcopy(seq) - # NOTE: Changes to seq are now safe to do + # Recreate the base sequence (what remains) + temp_seq = type(seq)(register=seq._register, device=seq._device) + assert not seq._to_build_calls + for call in seq._calls[1:]: + getattr(temp_seq, call.name)(*call.args, **call.kwargs) + seq = temp_seq if not (self.is_parametrized() or self.is_register_mappable()): warnings.warn( @@ -2172,11 +2206,10 @@ def _add( # The phase correction done to the EOM pulse's phase must # also be done to the phase shift, as the phase reference is # effectively changed by -drift - total_phase_shift = ( - total_phase_shift - - phase_drift_params.calc_phase_drift(new_pulse_slot.ti) + total_phase_shift -= float( + phase_drift_params.calc_phase_drift(new_pulse_slot.ti) ) - if total_phase_shift: + if total_phase_shift != 0.0: self._phase_shift(total_phase_shift, *last.targets, basis=basis) if ( self._in_ising @@ -2202,6 +2235,8 @@ def _target( ) -> None: self._validate_channel(channel, block_eom_mode=True) channel_obj = self._schedule[channel].channel_obj + if isinstance(qubits, pm.AbstractArray): + qubits = qubits.tolist() try: qubits_set = ( set(cast(Collection, qubits)) @@ -2231,7 +2266,7 @@ def _target( if not self.is_parametrized(): basis = channel_obj.basis phase_refs = { - self._basis_ref[basis][q].phase.last_phase + float(self._basis_ref[basis][q].phase.last_phase) for q in qubit_ids_set } if len(phase_refs) != 1: @@ -2259,10 +2294,12 @@ def _check_qubits_give_ids( ) return set() else: - qubits = cast(Tuple[int, ...], qubits) try: return { - self._register.qubit_ids[index] for index in qubits + self._register.qubit_ids[ + int(index) # type: ignore[arg-type] + ] + for index in qubits } except IndexError: raise IndexError("Indices must exist for the register.") @@ -2292,8 +2329,8 @@ def _delay( def _phase_shift( self, - phi: Union[float, Parametrized], - *targets: Union[QubitId, Parametrized], + phi: float | Parametrized, + *targets: QubitId | Parametrized, basis: str, _index: bool = False, ) -> None: @@ -2304,10 +2341,7 @@ def _phase_shift( target_ids = self._check_qubits_give_ids(*targets, _index=_index) if not self.is_parametrized(): - phi = cast(float, phi) - if phi % (2 * np.pi) == 0: - return - + phi = float(cast(float, phi)) for qubit in target_ids: self._basis_ref[basis][qubit].increment_phase(phi) @@ -2381,7 +2415,10 @@ def _validate_channel( ) def _validate_and_adjust_pulse( - self, pulse: Pulse, channel: str, phase_ref: Optional[float] = None + self, + pulse: Pulse, + channel: str, + phase_ref: float | None = None, ) -> Pulse: # Get the channel object and its detuning map if the channel is a DMM channel_obj: Channel @@ -2457,19 +2494,23 @@ def _validate_add_protocol(self, protocol: str) -> None: def _process_eom_parameters( self, channel_obj: Channel, - amp_on: Union[float, Parametrized], - detuning_on: Union[float, Parametrized], + amp_on: Union[float, pm.TensorLike, Parametrized], + detuning_on: Union[float, pm.TensorLike, Parametrized], optimal_detuning_off: Union[float, Parametrized], - ) -> tuple[float | Parametrized, tuple[RydbergBeam, ...]]: + ) -> tuple[ + float | pm.AbstractArray | Parametrized, tuple[RydbergBeam, ...] + ]: on_pulse = Pulse.ConstantPulse( channel_obj.min_duration, amp_on, detuning_on, 0.0 ) - stored_opt_detuning_off = optimal_detuning_off + stored_opt_detuning_off: float | pm.AbstractArray | Parametrized = ( + optimal_detuning_off + ) switching_beams: tuple[RydbergBeam, ...] = () if not isinstance(on_pulse, Parametrized): channel_obj.validate_pulse(on_pulse) - amp_on = cast(float, amp_on) - detuning_on = cast(float, detuning_on) + assert not isinstance(amp_on, Parametrized) + assert not isinstance(detuning_on, Parametrized) eom_config = cast(RydbergEOM, channel_obj.eom_config) if not isinstance(optimal_detuning_off, Parametrized): ( @@ -2478,7 +2519,7 @@ def _process_eom_parameters( ) = eom_config.calculate_detuning_off( amp_on, detuning_on, - optimal_detuning_off, + float(optimal_detuning_off), return_switching_beams=True, ) off_pulse = Pulse.ConstantPulse( diff --git a/pulser-core/pulser/waveforms.py b/pulser-core/pulser/waveforms.py index 4ef560f7..e5d32423 100644 --- a/pulser-core/pulser/waveforms.py +++ b/pulser-core/pulser/waveforms.py @@ -23,7 +23,7 @@ from abc import ABC, abstractmethod from functools import cached_property from types import FunctionType -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar, Union, cast import matplotlib.pyplot as plt import numpy as np @@ -31,6 +31,7 @@ from matplotlib.axes import Axes from numpy.typing import ArrayLike +import pulser.math as pm from pulser.json.abstract_repr.serializer import abstract_repr from pulser.json.exceptions import AbstractReprError from pulser.json.utils import obj_to_dict @@ -51,6 +52,18 @@ "KaiserWaveform", ] +T = TypeVar("T", int, float) + + +def _cast_check(type_: type[T], value: Any, name: str) -> T: + try: + return type_(value) + except (ValueError, TypeError) as e: + raise TypeError( + f"'{name}' needs to be castable to {type_.__name__!s} " + f"but type {type(value)} was provided." + ) from e + class Waveform(ABC): """The abstract class for a pulse's waveform.""" @@ -69,14 +82,9 @@ def __init__(self, duration: Union[int, Parametrized]): Args: duration: The waveforms duration (in ns). """ - duration = cast(int, duration) - try: - _duration = int(duration) - except (TypeError, ValueError): - raise TypeError( - "duration needs to be castable to an int but " - f"type {type(duration)} was provided." - ) + assert not isinstance(duration, Parametrized) + _duration = _cast_check(int, duration, "duration") + if _duration <= 0: raise ValueError( "A waveform must have a positive duration, " @@ -100,11 +108,11 @@ def duration(self) -> int: @cached_property @abstractmethod - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: pass @property - def samples(self) -> np.ndarray: + def samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: @@ -125,7 +133,7 @@ def last_value(self) -> float: @property def integral(self) -> float: """Integral of the waveform (in [waveform units].µs).""" - return float(np.sum(self.samples)) * 1e-3 # ns * rad/µs = 1e-3 + return float(pm.sum(self._samples)) * 1e-3 # ns * rad/µs = 1e-3 def draw( self, @@ -169,7 +177,7 @@ def change_duration(self, new_duration: int) -> Waveform: def modulated_samples( self, channel: Channel, eom: bool = False - ) -> np.ndarray: + ) -> pm.AbstractArray: """The waveform samples as output of a given channel. This duration is adjusted according to the minimal buffer times. @@ -181,11 +189,22 @@ def modulated_samples( Returns: The array of samples after modulation. """ + detach = True # We detach unless... + if self.samples.is_tensor and self.samples.as_tensor().requires_grad: + # ... the samples require grad. In this case, we clear the cache + # so that the modulation is recalculated with the current samples + self._modulated_samples.cache_clear() + detach = False start, end = self.modulation_buffers(channel) mod_samples = self._modulated_samples(channel, eom=eom) tr = channel.rise_time trim = slice(tr - start, len(mod_samples) - tr + end) - return mod_samples[trim] + final_samples = mod_samples[trim] + if detach: + # This ensures that we don't carry the `requires_grad` of a + # cached results + return pm.AbstractArray(final_samples.as_array(detach=True)) + return final_samples @functools.lru_cache() def modulation_buffers( @@ -212,7 +231,7 @@ def modulation_buffers( @functools.lru_cache() def _modulated_samples( self, channel: Channel, eom: bool = False - ) -> np.ndarray: + ) -> pm.AbstractArray: """The waveform samples as output of a given channel. This is not adjusted to the minimal buffer times. Use @@ -245,13 +264,13 @@ def __repr__(self) -> str: def __getitem__( self, index_or_slice: Union[int, slice] - ) -> Union[float, np.ndarray]: + ) -> pm.AbstractArray: if isinstance(index_or_slice, slice): s: slice = self._check_slice(index_or_slice) return self._samples[s] else: index: int = self._check_index(index_or_slice) - return cast(float, self._samples[index]) + return self._samples[index] def _check_index(self, i: int) -> int: if i < -self.duration or i >= self.duration: @@ -295,17 +314,18 @@ def _check_slice(self, s: slice) -> slice: return slice(start, stop) @abstractmethod - def __mul__(self, other: float) -> Waveform: + def __mul__(self, other: float | ArrayLike) -> Waveform: pass def __neg__(self) -> Waveform: return self.__mul__(-1.0) - def __truediv__(self, other: float) -> Waveform: - if other == 0: + def __truediv__(self, other: float | ArrayLike) -> Waveform: + other_ = pm.AbstractArray(other) + if np.any(other_.as_array(detach=True) == 0): raise ZeroDivisionError("Can't divide a waveform by zero.") else: - return self.__mul__(1 / other) + return self.__mul__(1 / other_) def __eq__(self, other: object) -> bool: if not isinstance(other, Waveform): @@ -313,10 +333,17 @@ def __eq__(self, other: object) -> bool: elif self.duration != other.duration: return False else: - return bool(np.all(np.isclose(self.samples, other.samples))) + return bool( + np.all( + np.isclose( + self.samples.as_array(detach=True), + other.samples.as_array(detach=True), + ) + ) + ) def __hash__(self) -> int: - return hash(tuple(self.samples)) + return hash(tuple(self.samples.tolist())) def _plot( self, @@ -332,7 +359,7 @@ def _plot( self.samples if channel is None else self.modulated_samples(channel) - ) + ).as_array(detach=True) ts = np.arange(len(samples)) + start_t if not channel and start_t: # Adds zero on both ends to show rise and fall @@ -385,15 +412,13 @@ def duration(self) -> int: return duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: A numpy array with a value for each time step. """ - return cast( - np.ndarray, np.concatenate([wf.samples for wf in self._waveforms]) - ) + return pm.concatenate([wf.samples for wf in self._waveforms]) @property def waveforms(self) -> list[Waveform]: @@ -422,8 +447,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"CompositeWaveform({self.duration} ns, {self._waveforms!r})" - def __mul__(self, other: float) -> CompositeWaveform: - return CompositeWaveform(*(wf * other for wf in self._waveforms)) + def __mul__(self, other: float | ArrayLike) -> CompositeWaveform: + other_ = pm.AbstractArray(other, dtype=float) + return CompositeWaveform(*(wf * other_ for wf in self._waveforms)) class CustomWaveform(Waveform): @@ -434,19 +460,19 @@ class CustomWaveform(Waveform): The number of samples dictates the duration, in ns. """ - def __init__(self, samples: ArrayLike): + def __init__(self, samples: ArrayLike | pm.TensorLike): """Initializes a custom waveform.""" - samples_arr = np.array(samples, dtype=float) - self._samples_arr: np.ndarray = samples_arr + samples_arr = pm.AbstractArray(samples, dtype=float) + self._samples_arr: pm.AbstractArray = samples_arr super().__init__(len(samples_arr)) @property def duration(self) -> int: """The duration of the pulse (in ns).""" - return self._duration + return int(self._duration) @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: @@ -466,8 +492,10 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"CustomWaveform({self.duration} ns, {self.samples!r})" - def __mul__(self, other: float) -> CustomWaveform: - return CustomWaveform(self._samples * float(other)) + def __mul__(self, other: float | ArrayLike) -> CustomWaveform: + return CustomWaveform( + self._samples * pm.AbstractArray(other, dtype=float) + ) class ConstantWaveform(Waveform): @@ -481,12 +509,13 @@ class ConstantWaveform(Waveform): def __init__( self, duration: Union[int, Parametrized], - value: Union[float, Parametrized], + value: Union[float, pm.TensorLike, Parametrized], ): """Initializes a constant waveform.""" super().__init__(duration) - value = cast(float, value) - self._value = float(value) + assert not isinstance(value, Parametrized) + _cast_check(float, value, "value") + self._value = pm.AbstractArray(value, dtype=float) @property def duration(self) -> int: @@ -494,13 +523,13 @@ def duration(self) -> int: return self._duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: A numpy array with a value for each time step. """ - return np.full(self.duration, self._value) + return self._value * np.ones(self.duration) def change_duration(self, new_duration: int) -> ConstantWaveform: """Returns a new waveform with modified duration. @@ -520,13 +549,17 @@ def _to_abstract_repr(self) -> dict[str, Any]: return abstract_repr("ConstantWaveform", self._duration, self._value) def __str__(self) -> str: - return f"{self._value:.3g}" + return f"{float(self._value):.3g}" def __repr__(self) -> str: - return f"ConstantWaveform({self._duration} ns, {self._value:.3g})" + return ( + f"ConstantWaveform({self._duration} ns, {float(self._value):.3g})" + ) - def __mul__(self, other: float) -> ConstantWaveform: - return ConstantWaveform(self._duration, self._value * float(other)) + def __mul__(self, other: float | ArrayLike) -> ConstantWaveform: + return ConstantWaveform( + self._duration, self._value * pm.AbstractArray(other, dtype=float) + ) class RampWaveform(Waveform): @@ -541,15 +574,17 @@ class RampWaveform(Waveform): def __init__( self, duration: Union[int, Parametrized], - start: Union[float, Parametrized], - stop: Union[float, Parametrized], + start: Union[float, pm.TensorLike, Parametrized], + stop: Union[float, pm.TensorLike, Parametrized], ): """Initializes a ramp waveform.""" super().__init__(duration) - start = cast(float, start) - self._start: float = float(start) - stop = cast(float, stop) - self._stop: float = float(stop) + assert not isinstance(start, Parametrized) + assert not isinstance(stop, Parametrized) + _cast_check(float, start, "start") + _cast_check(float, stop, "stop") + self._start = pm.AbstractArray(start, dtype=float) + self._stop = pm.AbstractArray(stop, dtype=float) @property def duration(self) -> int: @@ -557,18 +592,24 @@ def duration(self) -> int: return self._duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: A numpy array with a value for each time step. """ - return np.linspace(self._start, self._stop, num=self._duration) + return ( + self._slope * np.arange(self._duration, dtype=float) + self._start + ) + + @property + def _slope(self) -> pm.AbstractArray: + return (self._stop - self._start) / (self._duration - 1) @property def slope(self) -> float: r"""Slope of the ramp, in [waveform units] / ns.""" - return (self._stop - self._start) / (self._duration - 1) + return float(self._slope) def change_duration(self, new_duration: int) -> RampWaveform: """Returns a new waveform with modified duration. @@ -590,16 +631,16 @@ def _to_abstract_repr(self) -> dict[str, Any]: ) def __str__(self) -> str: - return f"Ramp({self._start:.3g}->{self._stop:.3g})" + return f"Ramp({float(self._start):.3g}->{float(self._stop):.3g})" def __repr__(self) -> str: return ( f"RampWaveform({self._duration} ns, " - + f"{self._start:.3g}->{self._stop:.3g})" + f"{float(self._start):.3g}->{float(self._stop):.3g})" ) - def __mul__(self, other: float) -> RampWaveform: - k = float(other) + def __mul__(self, other: float | ArrayLike) -> RampWaveform: + k = pm.AbstractArray(other, dtype=float) return RampWaveform(self._duration, self._start * k, self._stop * k) @@ -621,31 +662,25 @@ class BlackmanWaveform(Waveform): def __init__( self, duration: Union[int, Parametrized], - area: Union[float, Parametrized], + area: Union[float, pm.TensorLike, Parametrized], ): """Initializes a Blackman waveform.""" super().__init__(duration) - try: - self._area: float = float(cast(float, area)) - except (TypeError, ValueError): - raise TypeError( - "area needs to be castable to a float but " - f"type {type(area)} was provided." - ) + assert not isinstance(area, Parametrized) + _cast_check(float, area, "area") + self._area = pm.AbstractArray(area, dtype=float) - self._norm_samples: np.ndarray = np.clip( - np.blackman(self._duration), 0, np.inf - ) - self._scaling: float = ( - self._area / float(np.sum(self._norm_samples)) / 1e-3 + self._norm_samples = pm.AbstractArray( + np.clip(np.blackman(self._duration), 0, np.inf) ) + self._scaling = self._area / pm.sum(self._norm_samples) * 1e3 @classmethod @parametrize def from_max_val( cls, max_val: Union[float, Parametrized], - area: Union[float, Parametrized], + area: Union[float, pm.TensorLike, Parametrized], ) -> BlackmanWaveform: """Creates a Blackman waveform with a threshold on the maximum value. @@ -666,24 +701,25 @@ def from_max_val( area: The area under the waveform. """ max_val = cast(float, max_val) - area = cast(float, area) - area_sign = np.sign(area) + assert not isinstance(area, Parametrized) + area_float = _cast_check(float, area, "area") + area_sign = np.sign(area_float) if np.sign(max_val) != area_sign: raise ValueError( - "The maximum value and the area must have " "matching signs." + "The maximum value and the area must have matching signs." ) # Deal only with positive areas - area *= float(area_sign) + area = pm.AbstractArray(area, dtype=float) * float(area_sign) max_val *= float(area_sign) # A normalized Blackman waveform has an area of 0.42 * duration - duration = np.ceil(area / (0.42 * max_val) * 1e3) # in ns + duration = np.ceil(float(area) / (0.42 * max_val) * 1e3) # in ns wf = cls(duration, area) previous_wf = None # Adjust for rounding errors to make sure max_val is not surpassed - while wf._scaling > max_val: + while float(wf._scaling) > max_val: duration += 1 previous_wf = wf wf = cls(duration, area) @@ -694,7 +730,9 @@ def from_max_val( if ( previous_wf is not None and duration % 2 == 1 - and np.max(wf.samples) < np.max(previous_wf.samples) <= max_val + and np.max(wf.samples.as_array(detach=True)) + < np.max(previous_wf.samples.as_array(detach=True)) + <= max_val ): wf = previous_wf @@ -707,13 +745,13 @@ def duration(self) -> int: return self._duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: A numpy array with a value for each time step. """ - return cast(np.ndarray, self._norm_samples * self._scaling) + return self._norm_samples * self._scaling def change_duration(self, new_duration: int) -> BlackmanWaveform: """Returns a new waveform with modified duration. @@ -734,13 +772,18 @@ def _to_abstract_repr(self) -> dict[str, Any]: return abstract_repr("BlackmanWaveform", self._duration, self._area) def __str__(self) -> str: - return f"Blackman(Area: {self._area:.3g})" + return f"Blackman(Area: {float(self._area):.3g})" def __repr__(self) -> str: - return f"BlackmanWaveform({self._duration} ns, Area: {self._area:.3g})" + return ( + f"BlackmanWaveform({self._duration} ns, " + f"Area: {float(self._area):.3g})" + ) - def __mul__(self, other: float) -> BlackmanWaveform: - return BlackmanWaveform(self._duration, self._area * float(other)) + def __mul__(self, other: float | ArrayLike) -> BlackmanWaveform: + return BlackmanWaveform( + self._duration, self._area * pm.AbstractArray(other, dtype=float) + ) class InterpolatedWaveform(Waveform): @@ -826,14 +869,14 @@ def duration(self) -> int: return self._duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform.""" samples = self._interp_func(np.arange(self._duration)) value_range = np.max(np.abs(samples)) decimals = int( min(np.finfo(samples.dtype).precision - np.log10(value_range), 9) ) # Reduces decimal values below 9 for large ranges - return cast(np.ndarray, np.round(samples, decimals=decimals)) + return pm.AbstractArray(np.round(samples, decimals=decimals)) @property def interp_function( @@ -907,9 +950,11 @@ def __repr__(self) -> str: interp_str = f", Interpolator={self._kwargs['interpolator']})" return self.__str__()[:-1] + interp_str - def __mul__(self, other: float) -> InterpolatedWaveform: + def __mul__(self, other: float | ArrayLike) -> InterpolatedWaveform: return InterpolatedWaveform( - self._duration, self._values * other, **self._kwargs + self._duration, + self._values * np.array(other, dtype=float), + **self._kwargs, ) @@ -938,27 +983,20 @@ class KaiserWaveform(Waveform): def __init__( self, duration: Union[int, Parametrized], - area: Union[float, Parametrized], + area: Union[float, pm.TensorLike, Parametrized], beta: Optional[Union[float, Parametrized]] = 14.0, ): """Initializes a Kaiser waveform.""" super().__init__(duration) - try: - self._area: float = float(cast(float, area)) - except (TypeError, ValueError): - raise TypeError( - "area needs to be castable to a float but " - f"type {type(area)} was provided." - ) + assert not isinstance(area, Parametrized) + _cast_check(float, area, "area") + self._area = pm.AbstractArray(area, dtype=float) - try: - self._beta: float = float(cast(float, beta)) - except (TypeError, ValueError): - raise TypeError( - "beta needs to be castable to a float but " - f"type {type(beta)} was provided." - ) + beta = cast(float, beta) + # This makes sure 'beta' is not a tensor that requires grad + pm.AbstractArray(beta).as_array() + self._beta = _cast_check(float, beta, "beta") if self._beta < 0.0: raise ValueError( @@ -966,20 +1004,18 @@ def __init__( " must be greater than 0." ) - self._norm_samples: np.ndarray = np.clip( - np.kaiser(self._duration, self._beta), 0, np.inf + self._norm_samples = pm.AbstractArray( + np.clip(np.kaiser(self._duration, self._beta), 0, np.inf) ) - self._scaling: float = ( - self._area / float(np.sum(self._norm_samples)) / 1e-3 - ) + self._scaling = self._area / pm.sum(self._norm_samples) * 1e3 @classmethod @parametrize def from_max_val( cls, max_val: Union[float, Parametrized], - area: Union[float, Parametrized], + area: Union[float, pm.TensorLike, Parametrized], beta: Optional[Union[float, Parametrized]] = 14.0, ) -> KaiserWaveform: """Creates a Kaiser waveform with a threshold on the maximum value. @@ -1003,26 +1039,27 @@ def from_max_val( The default value is 14. """ max_val = cast(float, max_val) - area = cast(float, area) + assert not isinstance(area, Parametrized) + area_float = _cast_check(float, area, "area") beta = cast(float, beta) - if np.sign(max_val) != np.sign(area): + if np.sign(max_val) != np.sign(area_float): raise ValueError( "The maximum value and the area must have matching signs." ) # All computations will be done on a positive area - - is_negative: bool = area < 0 + area = pm.AbstractArray(area, dtype=float) + is_negative: bool = area_float < 0 if is_negative: - area = -area + area_float = -area_float max_val = -max_val # Compute the ratio area / duration for a long duration # and use this value for a first guess of the best duration ratio: float = max_val * np.sum(np.kaiser(100, beta)) / 100 - duration_guess: int = int(area * 1000.0 / ratio) + duration_guess: int = int(area_float * 1000.0 / ratio) duration_best: int = 0 @@ -1033,7 +1070,7 @@ def from_max_val( max_val_best: float = 0 for duration in range(1, 16): kaiser_temp = np.kaiser(duration, beta) - scaling_temp = 1000 * area / np.sum(kaiser_temp) + scaling_temp = 1000 * area_float / np.sum(kaiser_temp) max_val_temp = np.max(kaiser_temp) * scaling_temp if max_val_best < max_val_temp <= max_val: max_val_best = max_val_temp @@ -1043,7 +1080,7 @@ def from_max_val( # Start with a waveform based on the duration guess kaiser_guess = np.kaiser(duration_guess, beta) - scaling_guess = 1000 * area / np.sum(kaiser_guess) + scaling_guess = 1000 * area_float / np.sum(kaiser_guess) max_val_temp = np.max(kaiser_guess) * scaling_guess # Increase or decrease duration depending on @@ -1055,16 +1092,11 @@ def from_max_val( while np.sign(max_val_temp - max_val) == step: duration += step kaiser_temp = np.kaiser(duration, beta) - scaling = 1000 * area / np.sum(kaiser_temp) + scaling = 1000 * area_float / np.sum(kaiser_temp) max_val_temp = np.max(kaiser_temp) * scaling duration_best = duration if step == 1 else duration + 1 - # Restore the original area if it was negative - - if is_negative: - area = -area - return cls(duration_best, area, beta) @property @@ -1073,13 +1105,13 @@ def duration(self) -> int: return self._duration @cached_property - def _samples(self) -> np.ndarray: + def _samples(self) -> pm.AbstractArray: """The value at each time step that describes the waveform. Returns: A numpy array with a value for each time step. """ - return cast(np.ndarray, self._norm_samples * self._scaling) + return self._norm_samples * self._scaling def change_duration(self, new_duration: int) -> KaiserWaveform: """Returns a new waveform with modified duration. @@ -1104,18 +1136,20 @@ def _to_abstract_repr(self) -> dict[str, Any]: def __str__(self) -> str: return ( f"Kaiser({self._duration} ns, " - f"Area: {self._area:.3g}, Beta: {self._beta:.3g})" + f"Area: {float(self._area):.3g}, Beta: {self._beta:.3g})" ) def __repr__(self) -> str: return ( f"KaiserWaveform(duration: {self._duration}, " - f"area: {self._area:.3g}, beta: {self._beta:.3g})" + f"area: {float(self._area):.3g}, beta: {self._beta:.3g})" ) - def __mul__(self, other: float) -> KaiserWaveform: + def __mul__(self, other: float | ArrayLike) -> KaiserWaveform: return KaiserWaveform( - self._duration, self._area * float(other), self._beta + self._duration, + self._area * pm.AbstractArray(other, dtype=float), + self._beta, ) diff --git a/pulser-core/setup.py b/pulser-core/setup.py index 6db9e8e0..cb658234 100644 --- a/pulser-core/setup.py +++ b/pulser-core/setup.py @@ -45,6 +45,7 @@ name=distribution_name, version=__version__, install_requires=requirements, + extras_require={"torch": ["torch ~= 2.0"]}, packages=find_packages(), package_data={package_name: ["py.typed"]}, include_package_data=True, diff --git a/pulser-pasqal/pulser_pasqal/pasqal_cloud.py b/pulser-pasqal/pulser_pasqal/pasqal_cloud.py index e0faa009..25e6ab92 100644 --- a/pulser-pasqal/pulser_pasqal/pasqal_cloud.py +++ b/pulser-pasqal/pulser_pasqal/pasqal_cloud.py @@ -14,7 +14,6 @@ """Allows to connect to PASQAL's cloud platform to run sequences.""" from __future__ import annotations -import copy import json from dataclasses import fields from typing import Any, Type, cast @@ -110,8 +109,12 @@ def submit( "The measurement basis can't be implicitly determined " "for a sequence not addressing a single basis." ) - # The copy prevents changing the input sequence - sequence = copy.deepcopy(sequence) + # This is equivalent to performing a deepcopy + # All tensors are converted to arrays but that's ok, it would + # have happened anyway later on + sequence = Sequence.from_abstract_repr( + sequence.to_abstract_repr(skip_validation=True) + ) sequence.measure(bases[0]) emulator = kwargs.get("emulator", None) diff --git a/pulser-simulation/pulser_simulation/hamiltonian.py b/pulser-simulation/pulser_simulation/hamiltonian.py index 605c0ab7..a770cee9 100644 --- a/pulser-simulation/pulser_simulation/hamiltonian.py +++ b/pulser-simulation/pulser_simulation/hamiltonian.py @@ -23,6 +23,7 @@ import numpy as np import qutip +import pulser.math as pm from pulser.channels.base_channel import STATES_RANK, States from pulser.devices._device_datacls import BaseDevice from pulser.noise_model import NoiseModel @@ -47,14 +48,14 @@ class Hamiltonian: def __init__( self, samples_obj: SequenceSamples, - qdict: dict[QubitId, np.ndarray], + qdict: dict[QubitId, pm.AbstractArray], device: BaseDevice, sampling_rate: float, config: NoiseModel, ) -> None: """Instantiates a Hamiltonian object.""" self.samples_obj = samples_obj - self._qdict = qdict + self._qdict = {k: v.as_array(detach=True) for k, v in qdict.items()} self._device = device self._sampling_rate = sampling_rate diff --git a/pulser-simulation/pulser_simulation/simulation.py b/pulser-simulation/pulser_simulation/simulation.py index 77a3d11c..4cffff32 100644 --- a/pulser-simulation/pulser_simulation/simulation.py +++ b/pulser-simulation/pulser_simulation/simulation.py @@ -504,7 +504,10 @@ def run( def get_min_variation(ch_sample: ChannelSamples) -> int: end_point = ch_sample.duration - 1 min_variations: list[int] = [] - for sample in (ch_sample.amp, ch_sample.det): + for sample in ( + ch_sample.amp.as_array(detach=True), + ch_sample.det.as_array(detach=True), + ): min_variations.append( int( np.min( diff --git a/setup.py b/setup.py index 07c53d7b..2e094929 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ name="pulser", version=__version__, install_requires=requirements, + extras_require={"torch": [f"pulser-core[torch] == {__version__}"]}, description="A pulse-level composer for neutral-atom quantum devices.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/tests/test_abstract_repr.py b/tests/test_abstract_repr.py index 2b3b5ebb..a72466a4 100644 --- a/tests/test_abstract_repr.py +++ b/tests/test_abstract_repr.py @@ -1128,11 +1128,12 @@ def test_dmm_slm_mask(self, triangular_lattice, is_empty): assert abstract["operations"][1]["op"] == "config_detuning_map" assert abstract["operations"][1]["dmm_id"] == "dmm_0" + reg_coords = reg._coords_arr.as_array() assert abstract["operations"][1]["detuning_map"]["traps"] == [ { "weight": weight, - "x": reg._coords[i][0], - "y": reg._coords[i][1], + "x": reg_coords[i][0], + "y": reg_coords[i][1], } for i, weight in enumerate(list(det_map.values())) ] @@ -1244,7 +1245,12 @@ def _check_roundtrip(serialized_seq: dict[str, Any]): reconstructed_wf = wf_cls( *(op[wf][qty] for qty in wf_args) ) - op[wf] = reconstructed_wf._to_abstract_repr() + op[wf] = json.loads( + json.dumps( + reconstructed_wf._to_abstract_repr(), + cls=AbstractReprEncoder, + ) + ) elif ( "eom" in op["op"] and not op.get("correct_phase_drift") @@ -1344,7 +1350,9 @@ def test_deserialize_register(self, layout_coords): # Check layout if layout_coords is not None: assert seq.register.layout == reg_layout - q_coords = list(seq.qubit_info.values()) + q_coords = [ + q_coords.tolist() for q_coords in seq.qubit_info.values() + ] assert seq.register._layout_info.trap_ids == tuple( reg_layout.get_traps_from_coordinates(*q_coords) ) @@ -1824,7 +1832,7 @@ def test_deserialize_parametrized_op(self, op): operations=[op], variables={ "var1": {"type": "int", "value": [0]}, - "var2": {"type": "int", "value": [42]}, + "var2": {"type": "int", "value": [44]}, }, ) _check_roundtrip(s) @@ -2088,8 +2096,8 @@ def test_deserialize_eom_ops(self, correct_phase_drift, var_detuning_on): else: enable_eom_call = seq._calls[-1] eom_conf = seq.declared_channels["global"].eom_config - optimal_det_off = eom_conf.calculate_detuning_off( - 3.0, detuning_on, -1.0 + optimal_det_off = float( + eom_conf.calculate_detuning_off(3.0, detuning_on, -1.0) ) # Roundtrip will only match if the optimal detuning off matches diff --git a/tests/test_channels.py b/tests/test_channels.py index 479a30fc..bbf1a321 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -271,22 +271,32 @@ def test_modulation_errors(): (_eom_rydberg, _eom_config.rise_time, True, 0), ], ) -def test_modulation(channel, tr, eom, side_buffer_len): - wf = ConstantWaveform(100, 1) +@pytest.mark.parametrize("requires_grad", [False, True]) +def test_modulation(channel, tr, eom, side_buffer_len, requires_grad): + wf_vals = [1, np.pi] + if requires_grad: + wf_vals = pytest.importorskip("torch").tensor( + wf_vals, requires_grad=True + ) + wf = ConstantWaveform(100, wf_vals[0]) out_ = channel.modulate(wf.samples, eom=eom) assert len(out_) == wf.duration + 2 * tr assert channel.calc_modulation_buffer(wf.samples, out_, eom=eom) == ( tr, tr, ) + if requires_grad: + assert out_.as_tensor().requires_grad - wf2 = BlackmanWaveform(800, np.pi) + wf2 = BlackmanWaveform(800, wf_vals[1]) out_ = channel.modulate(wf2.samples, eom=eom) assert len(out_) == wf2.duration + 2 * tr # modulate() does not truncate assert channel.calc_modulation_buffer(wf2.samples, out_, eom=eom) == ( side_buffer_len, side_buffer_len, ) + if requires_grad: + assert out_.as_tensor().requires_grad @pytest.mark.parametrize( diff --git a/tests/test_devices.py b/tests/test_devices.py index 7ab67e35..5252fe64 100644 --- a/tests/test_devices.py +++ b/tests/test_devices.py @@ -270,27 +270,39 @@ def test_rydberg_blockade(): ) -def test_validate_register(): +@pytest.mark.parametrize("with_diff", [False, True]) +def test_validate_register(with_diff): + bad_coords1 = [(100.0, 0.0), (-100.0, 0.0)] + bad_coords2 = [(-10, 4, 0), (0, 0, 0)] + good_spacing = 5.0 + if with_diff: + torch = pytest.importorskip("torch") + bad_coords1 = torch.tensor( + bad_coords1, dtype=float, requires_grad=True + ) + bad_coords2 = torch.tensor( + bad_coords2, dtype=float, requires_grad=True + ) + good_spacing = torch.tensor(good_spacing, requires_grad=True) + with pytest.raises(ValueError, match="The number of atoms"): DigitalAnalogDevice.validate_register(Register.square(50)) - coords = [(100, 0), (-100, 0)] with pytest.raises(TypeError): - DigitalAnalogDevice.validate_register(coords) + DigitalAnalogDevice.validate_register(bad_coords1) with pytest.raises(ValueError, match="at most 50 μm away from the center"): DigitalAnalogDevice.validate_register( - Register.from_coordinates(coords) + Register.from_coordinates(bad_coords1) ) with pytest.raises(ValueError, match="at most 2D vectors"): - coords = [(-10, 4, 0), (0, 0, 0)] DigitalAnalogDevice.validate_register( - Register3D(dict(enumerate(coords))) + Register3D(dict(enumerate(bad_coords2))) ) with pytest.raises(ValueError, match="The minimal distance between atoms"): DigitalAnalogDevice.validate_register( - Register.triangular_lattice(3, 4, spacing=3.9) + Register.triangular_lattice(3, 4, spacing=good_spacing // 2) ) with pytest.raises( @@ -301,7 +313,9 @@ def test_validate_register(): tri_layout.hexagonal_register(10) ) - DigitalAnalogDevice.validate_register(Register.rectangle(5, 10, spacing=5)) + DigitalAnalogDevice.validate_register( + Register.rectangle(5, 10, spacing=good_spacing) + ) def test_validate_layout(): @@ -325,7 +339,7 @@ def test_validate_layout(): valid_layout = RegisterLayout( Register.square( int(np.sqrt(DigitalAnalogDevice.max_atom_num * 2)) - )._coords + )._coords_arr ) DigitalAnalogDevice.validate_layout(valid_layout) diff --git a/tests/test_eom.py b/tests/test_eom.py index 58f61833..ea63a4b2 100644 --- a/tests/test_eom.py +++ b/tests/test_eom.py @@ -98,6 +98,7 @@ def test_bad_controlled_beam(params): assert RydbergEOM(**params).controlled_beams == tuple(RydbergBeam) +@pytest.mark.parametrize("requires_grad", [False, True]) @pytest.mark.parametrize("limiting_beam", list(RydbergBeam)) @pytest.mark.parametrize("blue_shift_coeff", [0.5, 1.0, 2.0]) @pytest.mark.parametrize("red_shift_coeff", [0.5, 1.0, 1.8]) @@ -110,7 +111,11 @@ def test_detuning_off( multiple_beam_control, limit_amp_fraction, params, + requires_grad, ): + if requires_grad: + torch = pytest.importorskip("torch") + params["multiple_beam_control"] = multiple_beam_control params["blue_shift_coeff"] = blue_shift_coeff params["red_shift_coeff"] = red_shift_coeff @@ -142,19 +147,24 @@ def calc_offset(amp): limit_amp_ if limiting_beam == RydbergBeam.BLUE else non_limit_amp ) # The offset to have resonance when the pulse is on is -lightshift - return -( + return -float( blue_shift_coeff * blue_amp**2 - red_shift_coeff * red_amp**2 ) / (4 * params["intermediate_detuning"]) # Case where the EOM pulses are resonant detuning_on = 0.0 + if requires_grad: + amp = torch.tensor(amp, requires_grad=True) + detuning_on = torch.tensor(detuning_on, requires_grad=True) + zero_det = calc_offset(amp) # detuning when both beams are off = offset - assert np.isclose(eom._lightshift(amp, *RydbergBeam), -zero_det) + assert np.isclose(float(eom._lightshift(amp, *RydbergBeam)), -zero_det) assert eom._lightshift(amp) == 0.0 det_off_options = eom.detuning_off_options(amp, detuning_on) switching_beams_opts = eom._switching_beams_combos assert len(det_off_options) == len(switching_beams_opts) assert len(det_off_options) == 2 + multiple_beam_control + det_off_options = det_off_options.as_array(detach=True) order = np.argsort(det_off_options) det_off_options = det_off_options[order] switching_beams_opts = [switching_beams_opts[ind] for ind in order] @@ -180,9 +190,11 @@ def calc_offset(amp): ] ) assert calculated_det_off == min(det_off_options, key=abs) + if requires_grad: + assert calculated_det_off.as_tensor().requires_grad # Case where the EOM pulses are off-resonant - detuning_on = 1.0 + detuning_on = detuning_on + 1.0 for beam, ind in [(RydbergBeam.RED, next_), (RydbergBeam.BLUE, 0)]: # When only one beam is controlled, there is a single # detuning_off option @@ -192,7 +204,11 @@ def calc_offset(amp): assert len(off_options) == 1 # The new detuning_off is shifted by the new detuning_on, # since that changes the offset compared the resonant case - assert np.isclose(off_options[0], det_off_options[ind] + detuning_on) + assert np.isclose( + float(off_options[0]), det_off_options[ind] + float(detuning_on) + ) assert off_options[0] == eom_.calculate_detuning_off( amp, detuning_on, optimal_detuning_off=0.0 ) + if requires_grad: + assert off_options.as_tensor().requires_grad diff --git a/tests/test_math.py b/tests/test_math.py new file mode 100644 index 00000000..75aa0d50 --- /dev/null +++ b/tests/test_math.py @@ -0,0 +1,336 @@ +# Copyright 2024 Pulser Development Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import contextlib +import json +import sys + +import numpy as np +import pytest + +import pulser.math as pm +from pulser.json.abstract_repr.serializer import AbstractReprEncoder +from pulser.json.coders import PulserDecoder, PulserEncoder + + +@pytest.mark.parametrize( + "cast_to, requires_grad", + [(None, False), ("array", False), ("tensor", False), ("tensor", True)], +) +def test_pad(cast_to, requires_grad): + """Explicitly tested because it's the extensively rewritten.""" + arr = [1.0, 2.0, 3.0] + if cast_to == "array": + arr = np.array(arr) + elif cast_to == "tensor": + torch = pytest.importorskip("torch") + arr = torch.tensor(arr, requires_grad=requires_grad) + + def check_match(arr1: pm.AbstractArray, arr2): + if requires_grad: + assert arr1.as_tensor().requires_grad + np.testing.assert_array_equal( + arr1.as_array(detach=requires_grad), arr2 + ) + + # "constant" mode + + check_match( + pm.pad(arr, 2, mode="constant"), [0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0] + ) + check_match( + pm.pad(arr, (2, 1), mode="constant"), [0.0, 0.0, 1.0, 2.0, 3.0, 0.0] + ) + check_match( + pm.pad(arr, 1, mode="constant", constant_values=-1.0), + [-1.0, 1.0, 2.0, 3.0, -1.0], + ) + check_match( + pm.pad(arr, (1, 2), mode="constant", constant_values=-1.0), + [-1.0, 1.0, 2.0, 3.0, -1.0, -1.0], + ) + check_match( + pm.pad(arr, (1, 2), mode="constant", constant_values=(-1.0, 4.0)), + [-1.0, 1.0, 2.0, 3.0, 4.0, 4.0], + ) + + # "edge" mode + + check_match( + pm.pad(arr, 2, mode="edge"), [1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0] + ) + check_match( + pm.pad(arr, (2, 1), mode="edge"), [1.0, 1.0, 1.0, 2.0, 3.0, 3.0] + ) + check_match(pm.pad(arr, (0, 2), mode="edge"), [1.0, 2.0, 3.0, 3.0, 3.0]) + + +class TestAbstractArray: + + @pytest.mark.parametrize("force_array", [False, True]) + def test_no_torch(self, monkeypatch, force_array): + monkeypatch.setitem(sys.modules, "torch", None) + pm.AbstractArray.has_torch.cache_clear() + + val = 3.2 + arr = pm.AbstractArray(val, force_array=force_array, dtype=float) + assert not arr.is_tensor + with pytest.raises(RuntimeError, match="`torch` is not installed"): + arr.as_tensor() + + assert arr.size == 1 + assert arr.shape == ((1,) if force_array else ()) + assert arr.ndim == int(force_array) + assert arr.real == 3.2 + assert arr.dtype is np.dtype(float) + assert repr(arr) == repr(np.array(arr)) + assert arr.detach() == arr + + @pytest.mark.parametrize("force_array", [False, True]) + @pytest.mark.parametrize("requires_grad", [False, True]) + def test_with_torch(self, force_array, requires_grad): + pm.AbstractArray.has_torch.cache_clear() + torch = pytest.importorskip("torch") + + t = torch.tensor(1.0, requires_grad=requires_grad) + arr = pm.AbstractArray(t, force_array=force_array) + assert arr.is_tensor + assert arr.as_tensor() == t + assert arr.as_array(detach=requires_grad) == t.detach().numpy() + assert arr.detach() == pm.AbstractArray(t.detach()) + assert repr(arr) == repr(t[None] if force_array else t) + + @pytest.mark.parametrize("requires_grad", [False, True]) + def test_casting(self, requires_grad): + val = 4.1 + if requires_grad: + torch = pytest.importorskip("torch") + val = torch.tensor(val, requires_grad=True) + + arr = pm.AbstractArray(val) + assert int(arr) == int(val) + assert float(arr) == float(val) + assert bool(arr) == bool(val) + + @pytest.mark.parametrize("scalar", [False, True]) + @pytest.mark.parametrize("use_tensor", [False, True]) + def test_unary_ops(self, use_tensor, scalar): + val = np.linspace(-1, 1) + if scalar: + val = val[13] + if use_tensor: + torch = pytest.importorskip("torch") + val = torch.tensor(val) + lib = torch + else: + lib = np + + arr = pm.AbstractArray(val) + np.testing.assert_array_equal(-arr, -val) + np.testing.assert_array_equal(abs(arr), abs(val)) + np.testing.assert_array_equal(round(arr), lib.round(val)) + np.testing.assert_array_equal( + round(arr, 2), lib.round(val, decimals=2) + ) + + @pytest.mark.parametrize("scalar", [False, True]) + @pytest.mark.parametrize("use_tensor", [False, True]) + def test_comparison_ops(self, use_tensor, scalar): + min_, max_ = -1, 1 + val = np.linspace(min_, max_, endpoint=True) + if scalar: + val = val[13] + if use_tensor: + torch = pytest.importorskip("torch") + val = torch.tensor(val, requires_grad=True) + + arr = pm.AbstractArray(val) + assert np.all(arr < max_ + 1e-12) + assert np.all(arr <= max_) + assert np.all(arr > min_ - 1e-12) + assert np.all(arr >= min_) + assert np.all(arr == val) + assert np.all(arr != val * 5) + + @pytest.mark.parametrize("scalar", [False, True]) + @pytest.mark.parametrize("use_tensor", [False, True]) + def test_binary_ops(self, use_tensor, scalar): + values = np.linspace(-1, 1, endpoint=True) + if scalar: + val = values[13] + assert val != 0 + else: + val = values + if use_tensor: + torch = pytest.importorskip("torch") + val = torch.tensor(val) + + arr = pm.AbstractArray(val) + # add + np.testing.assert_array_equal(arr + 5.0, val + 5.0) + np.testing.assert_array_equal(arr + values, val + values) + np.testing.assert_array_equal(2.0 + arr, val + 2.0) + + # sub + np.testing.assert_array_equal(arr - 5.0, val - 5.0) + np.testing.assert_array_equal(arr - values, val - values) + np.testing.assert_array_equal(2.0 - arr, 2.0 - val) + + # mul + np.testing.assert_array_equal(arr * 5.0, val * 5.0) + np.testing.assert_array_equal(arr * values, val * values) + np.testing.assert_array_equal(2.0 * arr, val * 2.0) + + # truediv + np.testing.assert_array_equal(arr / 5.0, val / 5.0) + # Avoid zero division + np.testing.assert_array_equal( + arr / (values + 2.0), val / (values + 2.0) + ) + np.testing.assert_array_equal(2.0 / arr, 2.0 / val) + + # floordiv + np.testing.assert_array_equal(arr // 5.0, val // 5.0) + np.testing.assert_array_equal( + arr // (values + 2.0), val // (values + 2.0) + ) + np.testing.assert_array_equal(2.0 // arr, 2.0 // val) + + # pow + np.testing.assert_array_equal(arr**5.0, val**5.0) + + np.testing.assert_array_almost_equal( + abs(arr) ** values, abs(val) ** values + ) # rounding errors here + np.testing.assert_array_equal(2.0**arr, 2.0**val) + + # mod + np.testing.assert_array_equal(arr % 5.0, val % 5.0) + np.testing.assert_array_equal(arr % values, val % values) + np.testing.assert_array_equal(2.0 % arr, 2.0 % val) + + # matmul + if not scalar: + id_ = np.eye(len(arr)).tolist() + np.testing.assert_array_almost_equal(arr @ id_, val) + np.testing.assert_array_almost_equal(id_ @ arr, val) + + @pytest.mark.parametrize( + "indices", + [ + 4, + slice(None, -1), + slice(2, 8), + slice(9, None), + [1, -5, 8], + np.array([1, 2, 4]), + np.random.random(10) > 0.5, + ], + ) + @pytest.mark.parametrize( + "use_tensor, requires_grad", + [(False, False), (True, False), (True, True)], + ) + def test_items(self, use_tensor, requires_grad, indices): + val = np.linspace(-1, 1, endpoint=True, num=10) + if use_tensor: + torch = pytest.importorskip("torch") + val = torch.tensor(val, requires_grad=requires_grad) + + arr = pm.AbstractArray(val) + + # getitem + assert np.all(arr[indices] == pm.AbstractArray(val[indices])) + assert arr[indices].is_tensor == use_tensor + + # iter + for i, item in enumerate(arr): + assert item == val[i] + assert isinstance(item, pm.AbstractArray) + assert item.is_tensor == use_tensor + if use_tensor: + assert item.as_tensor().requires_grad == requires_grad + + # setitem + if not requires_grad: + arr[indices] = np.ones(len(val))[indices] + val[indices] = 1.0 + assert np.all(arr == val) + assert arr.is_tensor == use_tensor + + arr[indices] = np.pi + val[indices] = np.pi + assert np.all(arr == val) + assert arr.is_tensor == use_tensor + else: + with pytest.raises( + RuntimeError, + match="Failed to modify a tensor that requires grad in place.", + ): + arr[indices] = np.ones(len(val))[indices] + + if use_tensor: + # Check that a np.array is converted to tensor if assign a tensor + new_val = arr.as_array(detach=True) + arr_np = pm.AbstractArray(new_val) + assert not arr_np.is_tensor + arr_np[indices] = torch.zeros_like( + val, requires_grad=requires_grad + )[indices] + new_val[indices] = 0.0 + assert np.all(arr_np == new_val) + assert arr_np.is_tensor + # The resulting tensor requires grad if the assing one did + assert arr_np.as_tensor().requires_grad == requires_grad + + @pytest.mark.parametrize("scalar", [False, True]) + @pytest.mark.parametrize( + "use_tensor, requires_grad", + [(False, False), (True, False), (True, True)], + ) + def test_serialization(self, scalar, use_tensor, requires_grad): + values = np.linspace(-1, 1, endpoint=True) + if scalar: + val = values[13] + assert val != 0 + else: + val = values + + if use_tensor: + torch = pytest.importorskip("torch") + val = torch.tensor(val, requires_grad=requires_grad) + + arr = pm.AbstractArray(val) + + context = ( + pytest.raises( + NotImplementedError, + match="can't be serialized without losing the " + "computational graph", + ) + if requires_grad + else contextlib.nullcontext() + ) + + with context: + assert json.dumps(arr, cls=AbstractReprEncoder) == str( + float(val) if scalar else val.tolist() + ) + + with context: + legacy_ser = json.dumps(arr, cls=PulserEncoder) + deserialized = json.loads(legacy_ser, cls=PulserDecoder) + assert isinstance(deserialized, pm.AbstractArray) + np.testing.assert_array_equal(deserialized, val) diff --git a/tests/test_parametrized.py b/tests/test_parametrized.py index b94a7de0..7d0c4ccc 100644 --- a/tests/test_parametrized.py +++ b/tests/test_parametrized.py @@ -97,6 +97,19 @@ def test_var(a, b): b[[-3, 1]] +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_var_diff(a, b, requires_grad): + torch = pytest.importorskip("torch") + a._assign(torch.tensor(1.23, requires_grad=requires_grad)) + b._assign(torch.tensor([-1.0, 1.0], requires_grad=requires_grad)) + + for var in [a, b]: + assert ( + a.value is not None + and a.value.as_tensor().requires_grad == requires_grad + ) + + def test_varitem(a, b, d): a0 = a[0] b1 = b[1] @@ -116,8 +129,8 @@ def test_varitem(a, b, d): assert d0.build() == 0.5 with pytest.raises(FrozenInstanceError): b1.key = 0 - np.testing.assert_equal(b01.build(), b01_2.build()) - np.testing.assert_equal(b01_2.build(), b01_3.build()) + np.testing.assert_equal(b01.build().as_array(), b01_2.build().as_array()) + np.testing.assert_equal(b01_2.build().as_array(), b01_3.build().as_array()) with pytest.raises( TypeError, match=re.escape("len() of unsized variable item 'b[1]'") ): @@ -150,13 +163,32 @@ def test_paramobj(bwf, t, a, b): assert origin.build() == 0.0 -def test_opsupport(a, b): +@pytest.mark.parametrize("with_diff_tensor", [False, True]) +def test_opsupport(a, b, with_diff_tensor): + def check_var_grad(var): + if with_diff_tensor: + assert var.build().as_tensor().requires_grad + a._assign(-2.0) + if with_diff_tensor: + torch = pytest.importorskip("torch") + a._assign( + torch.tensor( + a.build().as_array().astype(float), requires_grad=True + ) + ) + # We need to make b's dtype=float so that it preserves the grad + bval = b.build().as_array().astype(float) + b = Variable("b", float, size=2) + b._assign(torch.tensor(bval, requires_grad=True)) + check_var_grad(a) + check_var_grad(b) u = 5 + a u = b - u # u = [-4, -2] u = u / 2 u = 8 * u # u = [-16, -8] u = -u // 3 # u = [5, 2] + check_var_grad(u) assert np.all(u.build() == [5.0, 2.0]) v = a**a @@ -167,6 +199,7 @@ def test_opsupport(a, b): assert v.build() == 1.0 v = -v assert v.build() == -1.0 + check_var_grad(v) x = a + 11 assert x.build() == 9 @@ -182,35 +215,70 @@ def test_opsupport(a, b): assert x.build() == 0.125 x = np.log2(x) assert x.build() == -3.0 + check_var_grad(x) # Trigonometric functions pi = -a * np.pi / 2 x = np.sin(pi) - np.testing.assert_almost_equal(x.build(), 0.0) + check_var_grad(x) + np.testing.assert_almost_equal( + x.build().as_array(detach=with_diff_tensor), 0.0 + ) x = np.cos(pi) - np.testing.assert_almost_equal(x.build(), -1.0) + check_var_grad(x) + np.testing.assert_almost_equal( + x.build().as_array(detach=with_diff_tensor), -1.0 + ) x = np.tan(pi / 4) - np.testing.assert_almost_equal(x.build(), 1.0) + check_var_grad(x) + np.testing.assert_almost_equal( + x.build().as_array(detach=with_diff_tensor), 1.0 + ) # Other transcendentals y = np.exp(b) - np.testing.assert_almost_equal(y.build(), [1 / np.e, np.e]) + check_var_grad(y) + np.testing.assert_almost_equal( + y.build().as_array(detach=with_diff_tensor), [1 / np.e, np.e] + ) y = np.log(y) - np.testing.assert_almost_equal(y.build(), b.build()) + check_var_grad(y) + np.testing.assert_almost_equal( + y.build().as_array(detach=with_diff_tensor), + b.build().as_array(detach=with_diff_tensor), + ) y_ = y + 0.4 # y_ = [-0.6, 1.4] y = np.round(y_, 1) - np.testing.assert_array_equal(y.build(), np.round(y_.build(), 1)) - np.testing.assert_array_equal(round(y_).build(), np.round(y_).build()) - np.testing.assert_array_equal(round(y_, 1).build(), y.build()) + np.testing.assert_array_equal( + y.build().as_array(detach=with_diff_tensor), + np.round(y_.build().as_array(detach=with_diff_tensor), 1), + ) + np.testing.assert_array_equal( + round(y_).build().as_array(detach=with_diff_tensor), + np.round(y_).build().as_array(detach=with_diff_tensor), + ) + np.testing.assert_array_equal( + round(y_, 1).build().as_array(detach=with_diff_tensor), + y.build().as_array(detach=with_diff_tensor), + ) y = round(y) - np.testing.assert_array_equal(y.build(), [-1.0, 1.0]) + np.testing.assert_array_equal( + y.build().as_array(detach=with_diff_tensor), [-1.0, 1.0] + ) y = np.floor(y + 0.1) - np.testing.assert_array_equal(y.build(), [-1.0, 1.0]) + np.testing.assert_array_equal( + y.build().as_array(detach=with_diff_tensor), [-1.0, 1.0] + ) y = np.ceil(y + 0.1) - np.testing.assert_array_equal(y.build(), [0.0, 2.0]) + np.testing.assert_array_equal( + y.build().as_array(detach=with_diff_tensor), [0.0, 2.0] + ) y = np.sqrt((y - 1) ** 2) - np.testing.assert_array_equal(y.build(), [1.0, 1.0]) + np.testing.assert_array_equal( + y.build().as_array(detach=with_diff_tensor), [1.0, 1.0] + ) + check_var_grad(y) # Test serialization support for operations def encode_decode(obj): @@ -223,19 +291,29 @@ def encode_decode(obj): assert set(u2.variables) == {"a", "b"} u2.variables["a"]._assign(a.value) u2.variables["b"]._assign(b.value) - np.testing.assert_array_equal(u2.build(), u.build()) + np.testing.assert_array_equal( + u2.build().as_array(detach=with_diff_tensor), + u.build().as_array(detach=with_diff_tensor), + ) + check_var_grad(u2) v2 = encode_decode(v) assert list(v2.variables) == ["a"] v2.variables["a"]._assign(a.value) assert v2.build() == v.build() + check_var_grad(v2) x2 = encode_decode(x) assert list(x2.variables) == ["a"] x2.variables["a"]._assign(a.value) assert x2.build() == x.build() + check_var_grad(x2) y2 = encode_decode(y) assert list(y2.variables) == ["b"] y2.variables["b"]._assign(b.value) - np.testing.assert_array_equal(y2.build(), y.build()) + np.testing.assert_array_equal( + y2.build().as_array(detach=with_diff_tensor), + y.build().as_array(detach=with_diff_tensor), + ) + check_var_grad(y2) diff --git a/tests/test_pasqal.py b/tests/test_pasqal.py index 76106194..5e134cea 100644 --- a/tests/test_pasqal.py +++ b/tests/test_pasqal.py @@ -224,7 +224,7 @@ def test_submit( ) mod_test_device = dataclasses.replace(test_device, max_atom_num=1000) seq3 = seq.switch_device(mod_test_device).switch_register( - pulser.Register.square(11, spacing=5) + pulser.Register.square(11, spacing=5, prefix="q") ) with pytest.raises( ValueError, @@ -233,7 +233,9 @@ def test_submit( fixt.pasqal_cloud.submit( seq3, job_params=[dict(runs=10)], mimic_qpu=mimic_qpu ) - seq4 = seq3.switch_register(pulser.Register.square(4, spacing=5)) + seq4 = seq3.switch_register( + pulser.Register.square(4, spacing=5, prefix="q") + ) # The sequence goes through QPUBackend.validate_sequence() with pytest.raises( ValueError, match="defined from a `RegisterLayout`" diff --git a/tests/test_pulse.py b/tests/test_pulse.py index 8c575a2b..fe51866a 100644 --- a/tests/test_pulse.py +++ b/tests/test_pulse.py @@ -54,6 +54,9 @@ def test_creation(): Pulse.ConstantAmplitude(-1, cwf, 0) Pulse.ConstantPulse(100, -1, 0, 0) + with pytest.raises(TypeError, match="'phase' must be a single float"): + Pulse(bwf, rwf, [0.0, 1.0, 2.0]) + assert pls.phase == 0 assert pls2 == pls3 assert pls != pls4 @@ -167,15 +170,18 @@ def test_arbitrary_phase(phase_wf, det_wf, phase_0): pls_ = Pulse.ArbitraryPhase(bwf, phase_wf) assert pls_ == Pulse(bwf, det_wf, phase_0) - calculated_phase = -np.cumsum(pls_.detuning.samples * 1e-3) + phase_0 + calculated_phase = -np.cumsum( + pls_.detuning.samples.as_array() * 1e-3 + ) + float(phase_0) + phase_samples = phase_wf.samples.as_array() assert np.allclose( calculated_phase % (2 * np.pi), - phase_wf.samples % (2 * np.pi), + phase_samples % (2 * np.pi), atol=PHASE_PRECISION, # The shift makes sure we don't fail around the wrapping point ) or np.allclose( (calculated_phase + 1) % (2 * np.pi), - (phase_wf.samples + 1) % (2 * np.pi), + (phase_samples + 1) % (2 * np.pi), atol=PHASE_PRECISION, ) @@ -225,3 +231,51 @@ def test_eq(): post_phase_shift=-1e-6, ) assert pls_ != repr(pls_) + + +def _assert_pulse_requires_grad(pulse: Pulse, invert: bool = False) -> None: + assert pulse.amplitude.samples.as_tensor().requires_grad == (not invert) + assert pulse.detuning.samples.as_tensor().requires_grad == (not invert) + assert pulse.phase.as_tensor().requires_grad == (not invert) + + +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_pulse_diff(requires_grad, eom_channel, patch_plt_show): + torch = pytest.importorskip("torch") + + duration = 1000 + diff_val = torch.tensor(1.0, requires_grad=requires_grad) + constant_wf = ConstantWaveform(duration, diff_val) + phase = torch.tensor(3.14, requires_grad=requires_grad) + phase_wf = RampWaveform( + duration, + phase - diff_val * 1e-3, + phase - diff_val * duration * 1e-3, + ) + assert torch.isclose(torch.tensor(phase_wf.slope), -diff_val * 1e-3) + + pulses: list[Pulse] = [ + Pulse(constant_wf, constant_wf, phase), + Pulse.ConstantDetuning(constant_wf, diff_val, phase), + Pulse.ConstantAmplitude(diff_val, constant_wf, phase), + Pulse.ConstantPulse(constant_wf.duration, diff_val, diff_val, phase), + Pulse.ArbitraryPhase(constant_wf, phase_wf), + ] + for i, pulse in enumerate(pulses): + _assert_pulse_requires_grad(pulse, invert=not requires_grad) + # Check other methods still work + assert pulse.duration == duration + assert pulse.get_full_duration( + eom_channel + ) == duration + pulse.fall_time(eom_channel) + + # Check all pulses are equal (by design) + for pulse2 in pulses[1:]: + assert str(pulses[0]) == str(pulse2) + assert repr(pulses[0]) == repr(pulse2) + assert pulses[0] == pulse2 + + # Extra checks for ArbitraryPhase (since it's more complex) + bwf = BlackmanWaveform(duration, diff_val) + phase_pulse = Pulse.ArbitraryPhase(constant_wf, bwf) + _assert_pulse_requires_grad(phase_pulse, invert=not requires_grad) diff --git a/tests/test_register.py b/tests/test_register.py index 03b571ec..294bff8f 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from unittest.mock import patch import numpy as np @@ -84,6 +86,17 @@ def test_creation(): Register(qubits, spacing=10, layout="square", trap_ids=(0, 1, 3)) +def test_repr(): + assert ( + repr(Register(dict(q0=(1.0, 0.0), q1=(-1, 5)))) + == "Register({'q0': array([1., 0.]), 'q1': array([-1., 5.])})" + ) + assert ( + repr(Register3D(dict(q0=(1, 2, 3)))) + == "Register3D({'q0': array([1., 2., 3.])})" + ) + + def test_rectangular_lattice(): # Check rows with pytest.raises(ValueError, match="The number of rows"): @@ -292,7 +305,9 @@ def test_rotation(): reg = Register.square(2, spacing=np.sqrt(2)) rot_reg = reg.rotated(45) new_coords_ = np.array([(0, -1), (1, 0), (-1, 0), (0, 1)], dtype=float) - np.testing.assert_allclose(rot_reg._coords, new_coords_, atol=1e-15) + np.testing.assert_allclose( + rot_reg._coords_arr.as_array(), new_coords_, atol=1e-15 + ) assert rot_reg != reg @@ -466,8 +481,8 @@ def test_coords_hash(): reg1 = Register.square(2, prefix="foo") reg2 = Register.rectangle(2, 2, prefix="bar") assert reg1 != reg2 # Ids are different - coords1 = list(reg1.qubits.values()) - coords2 = list(reg2.qubits.values()) + coords1 = list(c.as_array() for c in reg1.qubits.values()) + coords2 = list(c.as_array() for c in reg2.qubits.values()) np.testing.assert_equal(coords1, coords2) # But coords are the same assert reg1.coords_hex_hash() == reg2.coords_hex_hash() @@ -484,3 +499,91 @@ def test_coords_hash(): coords1[0][1] += 1e-6 reg5 = Register.from_coordinates(coords1) assert reg1.coords_hex_hash() != reg5.coords_hex_hash() + + +def _assert_reg_requires_grad( + reg: Register | Register3D, invert: bool = False +) -> None: + for coords in reg.qubits.values(): + if invert: + assert not coords.as_tensor().requires_grad + else: + assert coords.is_tensor and coords.as_tensor().requires_grad + + +@pytest.mark.parametrize( + "register_type, coords", + [ + (Register, [[1.0, -4.0], [0.0, 0.0]]), + (Register3D, [[1.0, -4.0, 5.0], [0.0, 0.0, 0.0]]), + ], +) +def test_custom_register_torch(register_type, coords, patch_plt_show): + torch = pytest.importorskip("torch") + + diff_qubit = torch.tensor(coords[0], requires_grad=True) + + reg1 = register_type({"q0": diff_qubit, "q1": coords[1]}) + reg2 = register_type.from_coordinates( + [diff_qubit, coords[1]], center=False, prefix="q" + ) + assert reg1 == reg2 + + # Also check that centering keeps the grad + reg3 = register_type.from_coordinates([diff_qubit, coords[1]], center=True) + assert torch.all(reg3.qubits[0].as_tensor() == diff_qubit / 2) + + for r in [reg1, reg2, reg3]: + _assert_reg_requires_grad(r) + if r.dimensionality == 2: + # Check after rotation + _assert_reg_requires_grad(r.rotated(30)) + else: + # Check after conversion to 2D + _assert_reg_requires_grad(r.to_2D(0.1)) + + # Check that drawing still works too + r.draw() + + +@pytest.mark.parametrize( + "reg_classmethod, param_name, extra_params", + [ + (Register.square, "spacing", {"side": 2}), + (Register.rectangle, "spacing", {"rows": 1, "columns": 3}), + ( + Register.rectangular_lattice, + "row_spacing", + {"rows": 1, "columns": 3}, + ), + ( + Register.rectangular_lattice, + "col_spacing", + {"rows": 1, "columns": 3}, + ), + ( + Register.triangular_lattice, + "spacing", + {"rows": 3, "atoms_per_row": 5}, + ), + (Register.hexagon, "spacing", {"layers": 5}), + ( + Register.max_connectivity, + "spacing", + {"n_qubits": 20, "device": DigitalAnalogDevice}, + ), + (Register3D.cubic, "spacing", {"side": 3}), + (Register3D.cuboid, "spacing", {"rows": 4, "columns": 2, "layers": 5}), + ], +) +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_register_recipes_torch( + reg_classmethod, param_name, extra_params, requires_grad +): + torch = pytest.importorskip("torch") + kwargs = { + param_name: torch.tensor(6.0, requires_grad=requires_grad), + **extra_params, + } + reg = reg_classmethod(**kwargs) + _assert_reg_requires_grad(reg, invert=not requires_grad) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 876cd56e..5979c464 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -73,7 +73,7 @@ def test_init(reg, device): Sequence(reg, Device) seq = Sequence(reg, device) - assert seq.qubit_info == reg.qubits + assert Register(seq.qubit_info) == reg assert seq.declared_channels == {} assert ( seq.available_channels.keys() @@ -1381,7 +1381,6 @@ def test_str(reg, device, mod_device, det_map): ) measure_msg = "\n\nMeasured in basis: digital" - print(seq) assert seq.__str__() == msg_ch0 + msg_ch1 + msg_det_map + measure_msg seq2 = Sequence(Register({"q0": (0, 0), 1: (5, 5)}), device) @@ -2338,7 +2337,7 @@ def test_eom_mode( ) assert np.isclose( seq.current_phase_ref("q0", basis="ground-rydberg"), - phase_ref % (2 * np.pi), + float(phase_ref) % (2 * np.pi), ) # Add delay to test the phase drift correction in disable_eom_mode @@ -2349,7 +2348,7 @@ def test_eom_mode( phase_ref += new_eom_block.detuning_off * last_delay_time * 1e-3 assert np.isclose( seq.current_phase_ref("q0", basis="ground-rydberg"), - phase_ref % (2 * np.pi), + float(phase_ref) % (2 * np.pi), ) # Test drawing in eom mode @@ -2495,3 +2494,62 @@ def test_add_to_dmm_fails(reg, device, det_map): seq.declare_channel("ryd", "rydberg_global") with pytest.raises(ValueError, match="not the name of a DMM channel"): seq.add_dmm_detuning(pulse.detuning, "ryd") + + +@pytest.mark.parametrize( + "with_eom, with_modulation", [(True, True), (True, False), (False, False)] +) +@pytest.mark.parametrize("parametrized", [True, False]) +def test_sequence_diff(device, parametrized, with_modulation, with_eom): + torch = pytest.importorskip("torch") + reg = Register( + {"q0": torch.tensor([0.0, 0.0], requires_grad=True), "q1": (-5.0, 5.0)} + ) + seq = Sequence(reg, AnalogDevice if with_eom else device) + seq.declare_channel("ryd_global", "rydberg_global") + + if parametrized: + amp = seq.declare_variable("amp", dtype=float) + dets = seq.declare_variable("dets", dtype=float, size=2) + else: + amp = torch.tensor(1.0, requires_grad=True) + dets = torch.tensor([-2.0, -1.0], requires_grad=True) + + # The phase is never a variable so we're sure the gradient + # is kept after build + phase = torch.tensor(2.0, requires_grad=True) + + if with_eom: + seq.enable_eom_mode("ryd_global", amp, dets[0], dets[1]) + seq.add_eom_pulse("ryd_global", 100, phase, correct_phase_drift=False) + seq.delay(100, "ryd_global") + seq.modify_eom_setpoint("ryd_global", amp * 2, dets[1], -dets[0]) + seq.add_eom_pulse("ryd_global", 100, -phase, correct_phase_drift=True) + seq.disable_eom_mode("ryd_global") + + else: + pulse = Pulse.ConstantDetuning( + BlackmanWaveform(1000, amp), dets[0], phase + ) + seq.add(pulse, "ryd_global") + det_map = reg.define_detuning_map({"q0": 1.0}) + seq.config_detuning_map(det_map, "dmm_0") + seq.add_dmm_detuning(RampWaveform(2000, *dets), "dmm_0") + + if parametrized: + seq = seq.build( + amp=torch.tensor(1.0, requires_grad=True), + dets=torch.tensor([-2.0, -1.0], requires_grad=True), + ) + + seq_samples = sample(seq, modulation=with_modulation) + ryd_ch_samples = seq_samples.channel_samples["ryd_global"] + assert ryd_ch_samples.amp.as_tensor().requires_grad + assert ryd_ch_samples.det.as_tensor().requires_grad + assert ryd_ch_samples.phase.as_tensor().requires_grad + if "dmm_0" in seq_samples.channel_samples: + dmm_ch_samples = seq_samples.channel_samples["dmm_0"] + # Only detuning is modulated + assert not dmm_ch_samples.amp.as_tensor().requires_grad + assert dmm_ch_samples.det.as_tensor().requires_grad + assert not dmm_ch_samples.phase.as_tensor().requires_grad diff --git a/tests/test_sequence_sampler.py b/tests/test_sequence_sampler.py index d78f4a29..8363825e 100644 --- a/tests/test_sequence_sampler.py +++ b/tests/test_sequence_sampler.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import re from copy import deepcopy from dataclasses import replace from typing import Literal @@ -21,6 +22,7 @@ import pytest import pulser +import pulser.math as pm import pulser_simulation from pulser.channels.dmm import DMM from pulser.devices import Device, MockDevice @@ -168,12 +170,12 @@ def test_modulation(mod_seq: pulser.Sequence) -> None: blackman = np.clip(np.blackman(N), 0, np.inf) input = (np.pi / 2) / (np.sum(blackman) / N) * blackman - want_amp = chan.modulate(input) + want_amp = chan.modulate(input).as_array() mod_samples = sample(mod_seq, modulation=True) got_amp = mod_samples.to_nested_dict()["Global"]["ground-rydberg"]["amp"] - np.testing.assert_array_equal(got_amp, want_amp) + np.testing.assert_allclose(got_amp, want_amp) - want_det = chan.modulate(np.ones(N), keep_ends=True) + want_det = chan.modulate(np.ones(N), keep_ends=True).as_array() got_det = mod_samples.to_nested_dict()["Global"]["ground-rydberg"]["det"] np.testing.assert_array_equal(got_det, want_det) @@ -189,8 +191,8 @@ def test_modulation(mod_seq: pulser.Sequence) -> None: for qty in ("amp", "det", "phase", "centered_phase"): np.testing.assert_array_equal( - getattr(input_ch_samples.modulate(chan), qty), - getattr(output_ch_samples, qty), + getattr(input_ch_samples.modulate(chan), qty).as_array(), + getattr(output_ch_samples, qty).as_array(), ) # input samples don't have a custom centered phase, output samples do @@ -294,11 +296,12 @@ def test_eom_modulation(mod_device, disable_eom): want = eom_output + aom_output # Check that modulation through sample() = sample() + modulation - got = getattr(mod_samples.channel_samples["ch0"], qty) - alt_got = getattr(input_samples.modulate(chan, full_duration), qty) + got = getattr(mod_samples.channel_samples["ch0"], qty).as_array() + alt_got = getattr( + input_samples.modulate(chan, full_duration), qty + ).as_array() np.testing.assert_array_equal(got, alt_got) - - np.testing.assert_allclose(want, got, atol=1e-10) + np.testing.assert_allclose(want.as_array(), got, atol=1e-10) def test_seq_with_DMM_and_map_reg(): @@ -422,12 +425,12 @@ def test_extend_duration(seq_rydberg, with_custom_centered_phase): extended_short = short.extend_duration(long.duration) assert extended_short.duration == long.duration for qty in ("amp", "det", "phase", "centered_phase"): - new_qty_samples = getattr(extended_short, qty) - old_qty_samples = getattr(short, qty) + new_qty_samples = getattr(extended_short, qty).as_array() + old_qty_samples = getattr(short, qty).as_array() np.testing.assert_array_equal( new_qty_samples[: short.duration], old_qty_samples ) - np.testing.assert_equal( + np.testing.assert_array_equal( new_qty_samples[short.duration :], old_qty_samples[-1] if "phase" in qty else 0.0, ) @@ -471,16 +474,20 @@ def test_phase_sampling(mod_device): expected_phase[transition3_4:] = 4.0 got_phase = (ch_samples_ := sample(seq).channel_samples["ch0"]).phase - np.testing.assert_array_equal(expected_phase, got_phase) + np.testing.assert_array_equal(expected_phase, got_phase.as_array()) # Test centered phase expected_phase[expected_phase > np.pi] -= 2 * np.pi np.testing.assert_array_equal(expected_phase, ch_samples_.centered_phase) +@pytest.mark.parametrize("with_diff", [False, True]) @pytest.mark.parametrize("off_center", [False, True]) -def test_phase_modulation(off_center): +def test_phase_modulation(off_center, with_diff): start_phase = np.pi / 2 + np.pi * off_center + if with_diff: + torch = pytest.importorskip("torch") + start_phase = torch.tensor(start_phase, requires_grad=True) phase1 = pulser.RampWaveform(400, start_phase, 0) phase2 = pulser.BlackmanWaveform(500, np.pi) phase3 = pulser.InterpolatedWaveform(500, [0, 11, 1, 5]) @@ -494,9 +501,17 @@ def test_phase_modulation(off_center): seq.add(pulse, "rydberg_global") seq_samples = sample(seq).channel_samples["rydberg_global"] + if with_diff: + assert full_phase.samples.as_tensor().requires_grad + assert not seq_samples.amp.as_tensor().requires_grad + assert seq_samples.det.as_tensor().requires_grad + assert seq_samples.phase.as_tensor().requires_grad + assert seq_samples.phase_modulation.as_tensor().requires_grad + np.testing.assert_allclose( - seq_samples.phase_modulation + 2 * np.pi * off_center, - full_phase.samples, + seq_samples.phase_modulation.as_array(detach=with_diff) + + 2 * np.pi * off_center, + full_phase.samples.as_array(detach=with_diff), atol=PHASE_PRECISION, ) @@ -526,6 +541,44 @@ def test_draw_samples( ) +@pytest.mark.parametrize("all_local", [False, True]) +@pytest.mark.parametrize("samples_type", ["array", "abstract", "tensor"]) +def test_to_nested_dict_samples_type(mod_seq, samples_type, all_local): + samples = sample(mod_seq) + with pytest.raises( + ValueError, + match=re.escape( + "'samples_type' must be one of ('abstract', 'array', 'tensor')," + " not 'jax'." + ), + ): + samples.to_nested_dict(samples_type="jax") + + if samples_type == "tensor": + expected_type = pytest.importorskip("torch").Tensor + elif samples_type == "array": + expected_type = np.ndarray + else: + assert samples_type == "abstract" + expected_type = pm.AbstractArray + + nested_dict = samples.to_nested_dict( + samples_type=samples_type, all_local=all_local + ) + + if all_local: + assert not nested_dict["Global"] + samples_per_qubit = nested_dict["Local"]["ground-rydberg"] + for qsamples in samples_per_qubit.values(): + for arr_ in qsamples.values(): + assert isinstance(arr_, expected_type) + else: + assert not nested_dict["Local"] + samples_arrs = nested_dict["Global"]["ground-rydberg"] + for arr_ in samples_arrs.values(): + assert isinstance(arr_, expected_type) + + # Fixtures diff --git a/tests/test_simresults.py b/tests/test_simresults.py index 3941923f..e366fa5a 100644 --- a/tests/test_simresults.py +++ b/tests/test_simresults.py @@ -236,7 +236,7 @@ def test_get_state_float_time(results): results.get_state(mean, t_tol=diff / 2) state = results.get_state(mean, t_tol=3 * diff / 2) assert state == results.get_state(results._sim_times[-2]) - assert np.isclose( + np.testing.assert_allclose( state.full(), np.array( [ @@ -246,7 +246,8 @@ def test_get_state_float_time(results): [-0.27977172 - 0.11031832j], ] ), - ).all() + atol=1e-5, + ) def test_expect(results, pi_pulse, reg): diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 6c7d9cc0..dfd4aec0 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -159,7 +159,7 @@ def test_initialization_and_construction_of_hamiltonian(seq, mod_device): for ch in sampled_seq.channels ] ) - assert sim._hamiltonian._qdict == seq.qubit_info + assert Register(sim._hamiltonian._qdict) == Register(seq.qubit_info) assert sim._hamiltonian._size == len(seq.qubit_info) assert sim._tot_duration == 9000 # seq has 9 pulses of 1µs assert sim._hamiltonian._qid_index == { @@ -218,35 +218,35 @@ def test_extraction_of_sequences(seq): for slot in seq._schedule[channel]: if isinstance(slot.type, Pulse): samples = sim._hamiltonian.samples[addr][basis] - assert ( + assert np.all( samples["amp"][slot.ti : slot.tf] == slot.type.amplitude.samples - ).all() - assert ( + ) + assert np.all( samples["det"][slot.ti : slot.tf] == slot.type.detuning.samples - ).all() - assert ( + ) + assert np.all( samples["phase"][slot.ti : slot.tf] == slot.type.phase - ).all() + ) elif addr == "Local": for slot in seq._schedule[channel]: if isinstance(slot.type, Pulse): for qubit in slot.targets: # TO DO: multiaddressing?? samples = sim._hamiltonian.samples[addr][basis][qubit] - assert ( + assert np.all( samples["amp"][slot.ti : slot.tf] == slot.type.amplitude.samples - ).all() - assert ( + ) + assert np.all( samples["det"][slot.ti : slot.tf] == slot.type.detuning.samples - ).all() - assert ( + ) + assert np.all( samples["phase"][slot.ti : slot.tf] == slot.type.phase - ).all() + ) @pytest.mark.parametrize("leakage", [False, True]) @@ -482,7 +482,7 @@ def test_get_hamiltonian(): simple_seq, config=SimConfig(noise="doppler", temperature=20000) ) simple_ham_noise = simple_sim_noise.get_hamiltonian(144) - assert np.isclose( + np.testing.assert_allclose( simple_ham_noise.full(), np.array( [ @@ -507,7 +507,7 @@ def test_get_hamiltonian(): [0.0 + 0.0j, 0.09606404 + 0.0j, 0.09606404 + 0.0j, 0.0 + 0.0j], ] ), - ).all() + ) def test_single_atom_simulation(): @@ -1593,7 +1593,7 @@ def test_simulation_with_modulation(mod_device, reg, patch_plt_show): seq.add(pulse1, "ch1") seq.add(pulse1, "ch0") ch1_obj = seq.declared_channels["ch1"] - pulse1_mod_samples = ch1_obj.modulate(pulse1.amplitude.samples) + pulse1_mod_samples = ch1_obj.modulate(pulse1.amplitude.samples).as_array() mod_dt = pulse1.duration + pulse1.fall_time(ch1_obj) assert pulse1_mod_samples.size == mod_dt @@ -1621,11 +1621,11 @@ def test_simulation_with_modulation(mod_device, reg, patch_plt_show): sim._hamiltonian._doppler_detune[qid], ) np.testing.assert_allclose( - raman_samples[qid]["phase"][time_slice], pulse1.phase + raman_samples[qid]["phase"][time_slice], float(pulse1.phase) ) def pos_factor(qid): - r = np.linalg.norm(reg.qubits[qid]) + r = np.linalg.norm(reg.qubits[qid].as_array()) w0 = sim_config.laser_waist return np.exp(-((r / w0) ** 2)) @@ -1645,7 +1645,7 @@ def pos_factor(qid): sim._hamiltonian._doppler_detune[qid], ) np.testing.assert_allclose( - rydberg_samples[qid]["phase"][time_slice], pulse1.phase + rydberg_samples[qid]["phase"][time_slice], float(pulse1.phase) ) with pytest.warns( DeprecationWarning, match="The `Simulation` class is deprecated" diff --git a/tests/test_waveforms.py b/tests/test_waveforms.py index bdfc7bf4..8357d8d4 100644 --- a/tests/test_waveforms.py +++ b/tests/test_waveforms.py @@ -46,7 +46,7 @@ def test_duration(): - with pytest.raises(TypeError, match="needs to be castable to an int"): + with pytest.raises(TypeError, match="needs to be castable to int"): ConstantWaveform("s", -1) RampWaveform([0, 1, 3], 1, 0) @@ -84,11 +84,11 @@ def test_change_duration(): def test_samples(): - assert np.all(constant.samples == -3) + assert np.all(constant.samples.as_array() == -3) bm_samples = np.clip(np.blackman(40), 0, np.inf) bm_samples *= np.pi / np.sum(bm_samples) / 1e-3 comp_samples = np.concatenate([bm_samples, np.full(100, -3), arb_samples]) - assert np.all(np.isclose(composite.samples, comp_samples)) + assert np.all(np.isclose(composite.samples.as_array(), comp_samples)) def test_integral(): @@ -232,10 +232,14 @@ def test_interpolated(): dt, [0, 1], interpolator="interp1d", kind="linear" ) assert isinstance(interp_wf.interp_function, interp1d) - np.testing.assert_allclose(interp_wf.samples, np.linspace(0, 1.0, num=dt)) + np.testing.assert_allclose( + interp_wf.samples.as_array(), np.linspace(0, 1.0, num=dt) + ) interp_wf *= 2 - np.testing.assert_allclose(interp_wf.samples, np.linspace(0, 2.0, num=dt)) + np.testing.assert_allclose( + interp_wf.samples.as_array(), np.linspace(0, 2.0, num=dt) + ) wf_str = "InterpolatedWaveform(Points: (0, 0), (999, 2)" assert str(interp_wf) == wf_str + ")" @@ -246,14 +250,16 @@ def test_interpolated(): dt, vals, interpolator="interp1d", kind="quadratic" ) np.testing.assert_allclose( - interp_wf2.samples, np.linspace(0, 1, num=dt) ** 2, atol=1e-3 + interp_wf2.samples.as_array(), + np.linspace(0, 1, num=dt) ** 2, + atol=1e-3, ) # Test rounding when range of values is large wf = InterpolatedWaveform( 1000, times=[0.0, 0.5, 1.0], values=[0, 2.6e7, 0] ) - assert np.all(wf.samples >= 0) + assert np.all((wf.samples >= 0).as_array()) def test_kaiser(): @@ -262,6 +268,7 @@ def test_kaiser(): beta: float = 14.0 wf: KaiserWaveform = KaiserWaveform(duration, area, beta) + wf_samples = wf.samples.as_array() # Check type error on area with pytest.raises(TypeError): @@ -284,17 +291,19 @@ def test_kaiser(): kaiser_beta_14: np.ndarray = np.kaiser(duration, 14.0) kaiser_beta_14 *= area / float(np.sum(kaiser_beta_14)) / 1e-3 np.testing.assert_allclose( - wf_default_beta.samples, kaiser_beta_14, atol=1e-3 + wf_default_beta.samples.as_array(), kaiser_beta_14, atol=1e-3 ) # Check area - assert np.isclose(np.sum(wf.samples), area * 1000.0) + assert np.isclose(np.sum(wf_samples), area * 1000.0) # Check duration change new_duration = duration * 2 wf_change_duration = wf.change_duration(new_duration) assert wf_change_duration.samples.size == new_duration - assert np.isclose(np.sum(wf.samples), np.sum(wf_change_duration.samples)) + assert np.isclose( + np.sum(wf_samples), np.sum(wf_change_duration.samples.as_array()) + ) # Check __str__ assert str(wf) == ( @@ -309,7 +318,7 @@ def test_kaiser(): # Check multiplication wf_multiplication = wf * 2 - assert (wf_multiplication.samples == wf.samples * 2).all() + assert np.all(wf_multiplication.samples == wf_samples * 2) # Check area and max_val must have matching signs with pytest.raises(ValueError, match="must have matching signs"): @@ -319,11 +328,11 @@ def test_kaiser(): for max_val in range(1, 501, 50): for beta in range(1, 20): wf = KaiserWaveform.from_max_val(max_val, area, beta) - assert np.isclose(np.sum(wf.samples), area * 1000.0) - assert np.max(wf.samples) <= max_val + assert np.isclose(np.sum(wf.samples.as_array()), area * 1000.0) + assert np.max(wf.samples.as_array()) <= max_val wf = KaiserWaveform.from_max_val(-max_val, -area, beta) - assert np.isclose(np.sum(wf.samples), -area * 1000.0) - assert np.min(wf.samples) >= -max_val + assert np.isclose(np.sum(wf.samples.as_array()), -area * 1000.0) + assert np.min(wf.samples.as_array()) >= -max_val def test_ops(): @@ -386,44 +395,48 @@ def test_get_item(): # Check with slices - assert (wf[0:duration] == samples).all() - assert (wf[0:-1] == samples[0:-1]).all() - assert (wf[0:] == samples).all() - assert (wf[-1:] == samples[-1:]).all() - assert (wf[:duration] == samples).all() - assert (wf[:] == samples).all() - assert ( + assert np.all(wf[0:duration] == samples) + assert np.all(wf[0:-1] == samples[0:-1]) + assert np.all(wf[0:] == samples) + assert np.all(wf[-1:] == samples[-1:]) + assert np.all(wf[:duration] == samples) + assert np.all(wf[:] == samples) + assert np.all( wf[duration14:duration34] == samples[duration14:duration34] - ).all() - assert ( + ) + assert np.all( wf[-duration34:-duration14] == samples[-duration34:-duration14] - ).all() + ) # Check with out of bounds slices - assert (wf[: duration * 2] == samples).all() - assert (wf[-duration * 2 :] == samples).all() - assert (wf[-duration * 2 : duration * 2] == samples).all() - assert ( + assert np.all(wf[: duration * 2] == samples) + assert np.all(wf[-duration * 2 :] == samples) + assert np.all(wf[-duration * 2 : duration * 2] == samples) + assert np.all( wf[duration // 2 : duration * 2] == samples[duration // 2 : duration * 2] - ).all() - assert ( + ) + assert np.all( wf[-duration * 2 : duration // 2] == samples[-duration * 2 : duration // 2] - ).all() + ) assert wf[2:1].size == 0 assert wf[duration * 2 :].size == 0 assert wf[duration * 2 : duration * 3].size == 0 assert wf[-duration * 3 : -duration * 2].size == 0 -def test_modulation(): - rydberg_global = Rydberg.Global( +@pytest.fixture +def rydberg_global(): + return Rydberg.Global( 2 * np.pi * 20, 2 * np.pi * 2.5, mod_bandwidth=4, # MHz ) - mod_samples = constant.modulated_samples(rydberg_global) + + +def test_modulation(rydberg_global): + mod_samples = constant.modulated_samples(rydberg_global).as_array() assert np.all(mod_samples == rydberg_global.modulate(constant.samples)) assert constant.modulation_buffers(rydberg_global) == ( rydberg_global.rise_time, @@ -432,3 +445,74 @@ def test_modulation(): assert len(mod_samples) == constant.duration + 2 * rydberg_global.rise_time assert np.isclose(np.sum(mod_samples) * 1e-3, constant.integral) assert max(np.abs(mod_samples)) < np.abs(constant[0]) + + +@pytest.mark.parametrize( + "wf_type, diff_param_name, diff_param_value, extra_params", + [ + (CustomWaveform, "samples", np.arange(-10.0, 10.0), {}), + (ConstantWaveform, "value", -3.14, {"duration": 20}), + (RampWaveform, "start", -10.0, {"duration": 10, "stop": 10}), + (RampWaveform, "stop", -10.0, {"duration": 10, "start": 10}), + (BlackmanWaveform, "area", 2.0, {"duration": 200}), + (BlackmanWaveform.from_max_val, "area", -2.0, {"max_val": -1}), + (KaiserWaveform, "area", -2.0, {"duration": 200}), + (KaiserWaveform.from_max_val, "area", 2.0, {"max_val": 1}), + ], +) +@pytest.mark.parametrize("requires_grad", [True, False]) +@pytest.mark.parametrize("composite", [True, False]) +def test_waveform_diff( + wf_type, + diff_param_name, + diff_param_value, + extra_params, + requires_grad, + composite, + rydberg_global, + patch_plt_show, +): + torch = pytest.importorskip("torch") + kwargs = { + diff_param_name: torch.tensor( + diff_param_value, requires_grad=requires_grad + ), + **extra_params, + } + wf = wf_type(**kwargs) + if composite: + wf = CompositeWaveform(wf, ConstantWaveform(100, 1.0)) + + samples_tensor = wf.samples.as_tensor() + assert samples_tensor.requires_grad == requires_grad + assert ( + wf.modulated_samples(rydberg_global).as_tensor().requires_grad + == requires_grad + ) + wfx2_tensor = (-wf * 2).samples.as_tensor() + assert torch.equal(wfx2_tensor, samples_tensor * -2.0) + assert wfx2_tensor.requires_grad == requires_grad + + wfdiv2 = wf / torch.tensor(2.0, requires_grad=True) + assert torch.equal(wfdiv2.samples.as_tensor(), samples_tensor / 2.0) + # Should always be true because it was divided by diff tensor + assert wfdiv2.samples.as_tensor().requires_grad + + assert wf[-1].as_tensor().requires_grad == requires_grad + + try: + assert ( + wf.change_duration(1000).samples.as_tensor().requires_grad + == requires_grad + ) + except NotImplementedError: + pass + + # Check that all non-related methods still work + wf.draw(output_channel=rydberg_global) + repr(wf) + str(wf) + hash(wf) + wf._to_dict() + wf._to_abstract_repr() + assert isinstance(wf.integral, float)