Skip to content

Commit

Permalink
Support differentiability through Torch tensors (#703)
Browse files Browse the repository at this point in the history
* Defining pulser.math and AbstractArray

* POC: Differentiable constant pulse amp

Typing is still failing

* Fix typing in waveforms

* Fix all typing errors in POC

* Pass all existing UTs

* Pass all UTs without array support

* Fix typing

* All tests pass with torch installed

* Add support for pulser-diff backend (#686)

* works with basic features of pulser-diff

* Fixed phase attribute setting; removed debugging code; reverted unnecessary changes

* Modified register creation code to work with AbstractArray; register coordinates are differentiable with pulser-diff

* Fixed type hints

* Minor fixes and refactoring

* Modified ParamObj code to work with quantum model training in pulser-diff

* Minor refactoring; add possibility to ensure 0D AbstractArray is reshaped into 1D

* Force array only for scalars

* Fix UTs after pulser-diff changes

* Avoid using AbstractArrayLike outside of pulser.math

* Preserve gradient in EOM mode

* Add torch as an optional requirement

* Support waveform multiplication with abstract array

* Explicitly marking the differentiable parameters

* Remove __array_wrap__

* Pass relevant UTs without array support

* Support new features

* Using pm.Differentiable whenever possible

* Simplifying Waveform.__getitem__() type hint

* UTs for new features outside of pulser.math

* Write torch UTs for registers

* Write UTs for waveforms

* UTs for pulse

* UTs for EOM

* UTs on internal functionality

* UTs for Sequence with autograd

* Implicitly cover math functions

* Removing AbstractArray.__hash__() and differentiable phase shifts

* Finish unit tests

* Update CI to run tests with and without torch

* Fix CI errors

* Fix failing no-torch UT

* Minor corrections

* Include pulser[torch] installation in the README

* Fix warning in UT after merge

* Incorporating the latest changes

* Fix typing

* Addressing review comments

* Including `detach()` in Differentiable protocol

* Differentiable -> TensorLike

* Tentatively allow waveform division by array

* Full coverage

---------

Co-authored-by: Vytautas Abramavicius <[email protected]>
  • Loading branch information
HGSilveri and vytautas-a authored Sep 17, 2024
1 parent 6c12156 commit e21d3a8
Show file tree
Hide file tree
Showing 49 changed files with 2,224 additions and 563 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ per-file-ignores =
tests/*: D100, D101, D102, D103
__init__.py: F401
pulser-core/pulser/backends.py: F401
pulser-core/pulser/math/__init__.py: D103
setup.py: D100
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.12"]
with-torch: ["with-torch", "no-torch"]
steps:
- name: Check out Pulser
uses: actions/checkout@v4
Expand All @@ -67,8 +68,13 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
extra-packages: pytest
with-torch: ${{ matrix.with-torch }}
- name: Run the unit tests & generate coverage report
if: ${{ matrix.with-torch == 'with-torch' }}
run: pytest --cov --cov-fail-under=100
- name: Run the unit tests without torch installed
if: ${{ matrix.with-torch != 'with-torch' }}
run: pytest --cov
- name: Test validation with legacy jsonschema
run: |
pip install jsonschema==4.17.3
Expand Down
13 changes: 12 additions & 1 deletion .github/workflows/pulser-setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ inputs:
description: Extra packages to install (give to grep)
required: false
default: ""
with-torch:
description: Whether to include pytorch
required: false
default: "with-torch"
runs:
using: "composite"
steps:
Expand All @@ -17,11 +21,18 @@ runs:
with:
python-version: ${{ inputs.python-version }}
cache: "pip"
- name: Install Pulser
- name: Install Pulser (with torch)
if: ${{ inputs.with-torch == 'with-torch' }}
shell: bash
run: |
python -m pip install --upgrade pip
make dev-install
- name: Install Pulser (without torch)
if: ${{ inputs.with-torch != 'with-torch' }}
shell: bash
run: |
python -m pip install --upgrade pip
make dev-install-no-torch
- name: Install extra packages from the dev requirements
if: "${{ inputs.extra-packages != '' }}"
shell: bash
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
# Python 3.8 and 3.9 does not run on macos-latest (14)
# Uses macos-13 for 3.8 and 3.9 and macos-latest for >=3.10
os: [ubuntu-latest, macos-13, macos-latest, windows-latest]
with-torch: ["with-torch", "no-torch"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
exclude:
- os: macos-latest
Expand All @@ -38,5 +39,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
extra-packages: pytest
with-torch: ${{ matrix.with-torch }}
- name: Run the unit tests & generate coverage report
run: pytest --cov --cov-fail-under=100
run: pytest --cov
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
.PHONY: dev-install
dev-install: dev-install-core dev-install-simulation dev-install-pasqal

.PHONY: dev-install-no-torch
dev-install-no-torch: dev-install-core-no-torch dev-install-simulation dev-install-pasqal

.PHONY: dev-install-core
dev-install-core:
pip install -e ./pulser-core[torch]

.PHONY: dev-install-core-no-torch
dev-install-core-no-torch:
pip install -e ./pulser-core

.PHONY: dev-install-simulation
Expand Down
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ If you wish to install only the core ``pulser`` features, you can instead run:
pip install pulser-core
```

### Including PyTorch

To include PyTorch in your installation, append the ``[torch]`` suffix to the commands outlined above, i.e.

```bash
pip install pulser[torch]
```

for the standard ``pulser`` distribution with PyTorch, **or**

```bash
pip install pulser-core[torch]
```

for just the core features plus PyTorch support.

### Development install

If you wish to **install the development version of Pulser from source** instead, do the following from within this repository after cloning it:

```bash
Expand Down
42 changes: 24 additions & 18 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import numpy as np
from numpy.typing import ArrayLike
from scipy.fft import fft, fftfreq, ifft

import pulser.math as pm
from pulser.channels.eom import MODBW_TO_TR, BaseEOM
from pulser.json.utils import get_dataclass_defaults, obj_to_dict
from pulser.pulse import Pulse
Expand Down Expand Up @@ -420,22 +420,24 @@ def validate_pulse(self, pulse: Pulse) -> None:
f"'pulse' must be of type Pulse, not of type {type(pulse)}."
)

if self.max_amp is not None and np.any(
pulse.amplitude.samples > self.max_amp
):
amp_samples_np = pulse.amplitude.samples.as_array(detach=True)
if self.max_amp is not None and np.any(amp_samples_np > self.max_amp):
raise ValueError(
"The pulse's amplitude goes over the maximum "
"value allowed for the chosen channel."
)
if self.max_abs_detuning is not None and np.any(
np.round(np.abs(pulse.detuning.samples), decimals=6)
np.round(
np.abs(pulse.detuning.samples.as_array(detach=True)),
decimals=6,
)
> self.max_abs_detuning
):
raise ValueError(
"The pulse's detuning values go out of the range "
"allowed for the chosen channel."
)
avg_amp = np.average(pulse.amplitude.samples)
avg_amp = np.average(amp_samples_np)
if 0 < avg_amp < self.min_avg_amp:
raise ValueError(
"The pulse's average amplitude is below the chosen "
Expand All @@ -453,10 +455,10 @@ def _modulation_padding(self) -> int:

def modulate(
self,
input_samples: np.ndarray,
input_samples: ArrayLike,
keep_ends: bool = False,
eom: bool = False,
) -> np.ndarray:
) -> pm.AbstractArray:
"""Modulates the input according to the channel's modulation bandwidth.
Args:
Expand All @@ -482,17 +484,17 @@ def modulate(
" 'Channel.modulate()' returns the 'input_samples' unchanged.",
stacklevel=2,
)
return input_samples
return pm.AbstractArray(input_samples)
else:
mod_bandwidth = self.mod_bandwidth
mod_padding = self._modulation_padding

if keep_ends:
samples = np.pad(
samples = pm.pad(
input_samples, mod_padding + self.rise_time, mode="edge"
)
else:
samples = np.pad(input_samples, mod_padding)
samples = pm.pad(input_samples, mod_padding)
mod_samples = self.apply_modulation(samples, mod_bandwidth)
if keep_ends:
# Cut off the extra ends
Expand All @@ -501,8 +503,8 @@ def modulate(

@staticmethod
def apply_modulation(
input_samples: np.ndarray, mod_bandwidth: float
) -> np.ndarray:
input_samples: ArrayLike, mod_bandwidth: float
) -> pm.AbstractArray:
"""Applies the modulation transfer fuction to the input samples.
Note:
Expand All @@ -516,10 +518,11 @@ def apply_modulation(
"""
# The cutoff frequency (fc) and the modulation transfer function
# are defined in https://tinyurl.com/bdeumc8k
input_samples = pm.AbstractArray(input_samples)
fc = mod_bandwidth * 1e-3 / np.sqrt(np.log(2))
freqs = fftfreq(input_samples.size)
modulation = np.exp(-(freqs**2) / fc**2)
return cast(np.ndarray, ifft(fft(input_samples) * modulation).real)
freqs = pm.fftfreq(input_samples.size)
modulation = pm.exp(-(freqs**2) / fc**2)
return pm.ifft(pm.fft(input_samples) * modulation).real

def calc_modulation_buffer(
self,
Expand Down Expand Up @@ -553,8 +556,11 @@ def calc_modulation_buffer(
f"The channel {self} doesn't have a modulation bandwidth."
)
tr = self.rise_time
samples = np.pad(input_samples, tr)
diffs = np.abs(samples - mod_samples) <= max_allowed_diff
samples = pm.pad(input_samples, tr)
diffs = (
abs(samples - mod_samples).as_array(detach=True)
<= max_allowed_diff
)
try:
# Finds the last index in the start buffer that's below the max
# allowed diff. Considers that the waveform could start at the next
Expand Down
5 changes: 4 additions & 1 deletion pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np

import pulser.math as pm
from pulser.channels.base_channel import Channel
from pulser.json.utils import get_dataclass_defaults
from pulser.pulse import Pulse
Expand Down Expand Up @@ -112,7 +113,9 @@ def validate_pulse(
(defaults to a detuning map with weight 1.0).
"""
super().validate_pulse(pulse)
round_detuning = np.round(pulse.detuning.samples, decimals=6)
round_detuning = pm.round(pulse.detuning.samples, 6).as_array(
detach=True
)
# Check that detuning is negative
if np.any(round_detuning > 0):
raise ValueError("The detuning in a DMM must not be positive.")
Expand Down
57 changes: 32 additions & 25 deletions pulser-core/pulser/channels/eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np

import pulser.math as pm
from pulser.json.utils import get_dataclass_defaults, obj_to_dict

# Conversion factor from modulation bandwith to rise time
Expand Down Expand Up @@ -210,30 +211,30 @@ def _switching_beams_combos(self) -> list[tuple[RydbergBeam, ...]]:
@overload
def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: Literal[False],
) -> float:
) -> pm.AbstractArray:
pass

@overload
def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: Literal[True],
) -> tuple[float, tuple[RydbergBeam, ...]]:
) -> tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]:
pass

def calculate_detuning_off(
self,
amp_on: float,
detuning_on: float,
amp_on: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
optimal_detuning_off: float,
return_switching_beams: bool = False,
) -> float | tuple[float, tuple[RydbergBeam, ...]]:
) -> pm.AbstractArray | tuple[pm.AbstractArray, tuple[RydbergBeam, ...]]:
"""Calculates the detuning when the amplitude is off in EOM mode.
Args:
Expand All @@ -246,17 +247,19 @@ def calculate_detuning_off(
on and off.
"""
off_options = self.detuning_off_options(amp_on, detuning_on)
closest_option = np.abs(off_options - optimal_detuning_off).argmin()
best_det_off = cast(float, off_options[closest_option])
closest_option = np.abs(
off_options.as_array(detach=True) - optimal_detuning_off
).argmin()
best_det_off = off_options[closest_option]
if not return_switching_beams:
return best_det_off
return best_det_off, self._switching_beams_combos[closest_option]

def detuning_off_options(
self,
rabi_frequency: float,
detuning_on: float,
) -> np.ndarray:
rabi_frequency: float | pm.TensorLike,
detuning_on: float | pm.TensorLike,
) -> pm.AbstractArray:
"""Calculates the possible detuning values when the amplitude is off.
Args:
Expand All @@ -267,11 +270,14 @@ def detuning_off_options(
Returns:
The possible detuning values when in between pulses.
"""
rabi_frequency = pm.AbstractArray(rabi_frequency)
# detuning = offset + lightshift

# offset takes into account the lightshift when both beams are on
# which is not zero when the Rabi freq of both beams is not equal
offset = detuning_on - self._lightshift(rabi_frequency, *RydbergBeam)
offset = pm.AbstractArray(detuning_on) - self._lightshift(
rabi_frequency, *RydbergBeam
)
all_beams: set[RydbergBeam] = set(RydbergBeam)
lightshifts = []
for beams_off in self._switching_beams_combos:
Expand All @@ -280,25 +286,26 @@ def detuning_off_options(
lightshifts.append(self._lightshift(rabi_frequency, *beams_on))

# We sum the offset to all lightshifts to get the effective detuning
return np.array(lightshifts) + offset
return pm.flatten(pm.vstack(lightshifts)) + offset

def _lightshift(
self, rabi_frequency: float, *beams_on: RydbergBeam
) -> float:
self, rabi_frequency: pm.AbstractArray, *beams_on: RydbergBeam
) -> pm.AbstractArray:
# lightshift = (rabi_blue**2 - rabi_red**2) / 4 * int_detuning
rabi_freqs = self._rabi_freq_per_beam(rabi_frequency)
bias = {
RydbergBeam.RED: -self.red_shift_coeff,
RydbergBeam.BLUE: self.blue_shift_coeff,
}
# beam off -> beam_rabi_freq = 0
return sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on) / (
4 * self.intermediate_detuning
return pm.AbstractArray(
sum(bias[beam] * rabi_freqs[beam] ** 2 for beam in beams_on)
/ (4 * self.intermediate_detuning)
)

def _rabi_freq_per_beam(
self, rabi_frequency: float
) -> dict[RydbergBeam, float]:
self, rabi_frequency: pm.AbstractArray
) -> dict[RydbergBeam, pm.AbstractArray]:
shift_factor = np.sqrt(
self.red_shift_coeff / self.blue_shift_coeff
if self.limiting_beam == RydbergBeam.RED
Expand All @@ -315,14 +322,14 @@ def _rabi_freq_per_beam(
if rabi_frequency <= limit_rabi_freq:
base_amp_squared = 2 * rabi_frequency * self.intermediate_detuning
return {
self.limiting_beam: np.sqrt(base_amp_squared / shift_factor),
~self.limiting_beam: np.sqrt(base_amp_squared * shift_factor),
self.limiting_beam: pm.sqrt(base_amp_squared / shift_factor),
~self.limiting_beam: pm.sqrt(base_amp_squared * shift_factor),
}

# The limiting beam is at its maximum amplitude while the other
# has the necessary amplitude to reach the desired effective rabi freq
return {
self.limiting_beam: self.max_limiting_amp,
self.limiting_beam: pm.AbstractArray(self.max_limiting_amp),
~self.limiting_beam: 2
* self.intermediate_detuning
* rabi_frequency
Expand Down
Loading

0 comments on commit e21d3a8

Please sign in to comment.