Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shortcut to find out if an AbstractArray is differentiable #784

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pulser-core/pulser/math/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def is_tensor(self) -> bool:
"""Whether the stored array is a tensor."""
return self.has_torch() and isinstance(self._array, torch.Tensor)

@property
def requires_grad(self) -> bool:
"""Whether the stored array is a tensor that needs a gradient."""
return self.is_tensor and cast(torch.Tensor, self._array).requires_grad

def astype(self, dtype: DTypeLike) -> AbstractArray:
"""Casts the data type of the array contents."""
if self.is_tensor:
Expand Down Expand Up @@ -271,10 +276,7 @@ def __setitem__(self, indices: Any, values: AbstractArrayLike) -> None:
self._process_indices(indices)
] = values # type: ignore[assignment]
except RuntimeError as e:
if (
self.is_tensor
and cast(torch.Tensor, self._array).requires_grad
):
if self.requires_grad:
raise RuntimeError(
"Failed to modify a tensor that requires grad in place."
) from e
Expand Down
1 change: 0 additions & 1 deletion pulser-core/pulser/register/base_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(
)
self._ids: tuple[QubitId, ...] = tuple(qubits.keys())
if any(not isinstance(id, str) for id in self._ids):
warnings.simplefilter("always")
warnings.warn(
"Usage of `int`s or any non-`str`types as `QubitId`s will be "
"deprecated. Define your `QubitId`s as `str`s, prefer setting "
Expand Down
6 changes: 2 additions & 4 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad):
tr,
tr,
)
if requires_grad:
assert out_.as_tensor().requires_grad
assert out_.requires_grad == requires_grad

wf2 = BlackmanWaveform(800, wf_vals[1])
out_ = channel.modulate(wf2.samples, eom=eom)
Expand All @@ -302,8 +301,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad):
side_buffer_len,
side_buffer_len,
)
if requires_grad:
assert out_.as_tensor().requires_grad
assert out_.requires_grad == requires_grad


@pytest.mark.parametrize(
Expand Down
6 changes: 2 additions & 4 deletions tests/test_eom.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def calc_offset(amp):
]
)
assert calculated_det_off == min(det_off_options, key=abs)
if requires_grad:
assert calculated_det_off.as_tensor().requires_grad
assert calculated_det_off.requires_grad == requires_grad

# Case where the EOM pulses are off-resonant
detuning_on = detuning_on + 1.0
Expand All @@ -210,5 +209,4 @@ def calc_offset(amp):
assert off_options[0] == eom_.calculate_detuning_off(
amp, detuning_on, optimal_detuning_off=0.0
)
if requires_grad:
assert off_options.as_tensor().requires_grad
assert off_options.requires_grad == requires_grad
10 changes: 4 additions & 6 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def test_pad(cast_to, requires_grad):
arr = torch.tensor(arr, requires_grad=requires_grad)

def check_match(arr1: pm.AbstractArray, arr2):
if requires_grad:
assert arr1.as_tensor().requires_grad
assert arr1.requires_grad == requires_grad
np.testing.assert_array_equal(
arr1.as_array(detach=requires_grad), arr2
)
Expand Down Expand Up @@ -260,8 +259,7 @@ def test_items(self, use_tensor, requires_grad, indices):
assert item == val[i]
assert isinstance(item, pm.AbstractArray)
assert item.is_tensor == use_tensor
if use_tensor:
assert item.as_tensor().requires_grad == requires_grad
assert item.requires_grad == requires_grad

# setitem
if not requires_grad:
Expand Down Expand Up @@ -292,8 +290,8 @@ def test_items(self, use_tensor, requires_grad, indices):
new_val[indices] = 0.0
assert np.all(arr_np == new_val)
assert arr_np.is_tensor
# The resulting tensor requires grad if the assing one did
assert arr_np.as_tensor().requires_grad == requires_grad
# The resulting tensor requires grad if the assigned one did
assert arr_np.requires_grad == requires_grad

@pytest.mark.parametrize("scalar", [False, True])
@pytest.mark.parametrize(
Expand Down
7 changes: 2 additions & 5 deletions tests/test_parametrized.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ def test_var_diff(a, b, requires_grad):
b._assign(torch.tensor([-1.0, 1.0], requires_grad=requires_grad))

for var in [a, b]:
assert (
a.value is not None
and a.value.as_tensor().requires_grad == requires_grad
)
assert a.value is not None and a.value.requires_grad == requires_grad


def test_varitem(a, b, d):
Expand Down Expand Up @@ -167,7 +164,7 @@ def test_paramobj(bwf, t, a, b):
def test_opsupport(a, b, with_diff_tensor):
def check_var_grad(var):
if with_diff_tensor:
assert var.build().as_tensor().requires_grad
assert var.build().requires_grad

a._assign(-2.0)
if with_diff_tensor:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ def test_eq():


def _assert_pulse_requires_grad(pulse: Pulse, invert: bool = False) -> None:
assert pulse.amplitude.samples.as_tensor().requires_grad == (not invert)
assert pulse.detuning.samples.as_tensor().requires_grad == (not invert)
assert pulse.phase.as_tensor().requires_grad == (not invert)
assert pulse.amplitude.samples.requires_grad == (not invert)
assert pulse.detuning.samples.requires_grad == (not invert)
assert pulse.phase.requires_grad == (not invert)


@pytest.mark.parametrize("requires_grad", [True, False])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,9 +508,9 @@ def _assert_reg_requires_grad(
) -> None:
for coords in reg.qubits.values():
if invert:
assert not coords.as_tensor().requires_grad
assert not coords.requires_grad
else:
assert coords.is_tensor and coords.as_tensor().requires_grad
assert coords.is_tensor and coords.requires_grad


@pytest.mark.parametrize(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,12 +2886,12 @@ def test_sequence_diff(device, parametrized, with_modulation, with_eom):

seq_samples = sample(seq, modulation=with_modulation)
ryd_ch_samples = seq_samples.channel_samples["ryd_global"]
assert ryd_ch_samples.amp.as_tensor().requires_grad
assert ryd_ch_samples.det.as_tensor().requires_grad
assert ryd_ch_samples.phase.as_tensor().requires_grad
assert ryd_ch_samples.amp.requires_grad
assert ryd_ch_samples.det.requires_grad
assert ryd_ch_samples.phase.requires_grad
if "dmm_0" in seq_samples.channel_samples:
dmm_ch_samples = seq_samples.channel_samples["dmm_0"]
# Only detuning is modulated
assert not dmm_ch_samples.amp.as_tensor().requires_grad
assert dmm_ch_samples.det.as_tensor().requires_grad
assert not dmm_ch_samples.phase.as_tensor().requires_grad
assert not dmm_ch_samples.amp.requires_grad
assert dmm_ch_samples.det.requires_grad
assert not dmm_ch_samples.phase.requires_grad
10 changes: 5 additions & 5 deletions tests/test_sequence_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,11 @@ def test_phase_modulation(off_center, with_diff):
seq_samples = sample(seq).channel_samples["rydberg_global"]

if with_diff:
assert full_phase.samples.as_tensor().requires_grad
assert not seq_samples.amp.as_tensor().requires_grad
assert seq_samples.det.as_tensor().requires_grad
assert seq_samples.phase.as_tensor().requires_grad
assert seq_samples.phase_modulation.as_tensor().requires_grad
assert full_phase.samples.requires_grad
assert not seq_samples.amp.requires_grad
assert seq_samples.det.requires_grad
assert seq_samples.phase.requires_grad
assert seq_samples.phase_modulation.requires_grad

np.testing.assert_allclose(
seq_samples.phase_modulation.as_array(detach=with_diff)
Expand Down
14 changes: 4 additions & 10 deletions tests/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,26 +490,20 @@ def test_waveform_diff(

samples_tensor = wf.samples.as_tensor()
assert samples_tensor.requires_grad == requires_grad
assert (
wf.modulated_samples(rydberg_global).as_tensor().requires_grad
== requires_grad
)
assert wf.modulated_samples(rydberg_global).requires_grad == requires_grad
wfx2_tensor = (-wf * 2).samples.as_tensor()
assert torch.equal(wfx2_tensor, samples_tensor * -2.0)
assert wfx2_tensor.requires_grad == requires_grad

wfdiv2 = wf / torch.tensor(2.0, requires_grad=True)
assert torch.equal(wfdiv2.samples.as_tensor(), samples_tensor / 2.0)
# Should always be true because it was divided by diff tensor
assert wfdiv2.samples.as_tensor().requires_grad
assert wfdiv2.samples.requires_grad

assert wf[-1].as_tensor().requires_grad == requires_grad
assert wf[-1].requires_grad == requires_grad

try:
assert (
wf.change_duration(1000).samples.as_tensor().requires_grad
== requires_grad
)
assert wf.change_duration(1000).samples.requires_grad == requires_grad
except NotImplementedError:
pass

Expand Down
Loading