diff --git a/pulser-core/pulser/math/abstract_array.py b/pulser-core/pulser/math/abstract_array.py index c74805a6..eb38bcdd 100644 --- a/pulser-core/pulser/math/abstract_array.py +++ b/pulser-core/pulser/math/abstract_array.py @@ -71,6 +71,11 @@ def is_tensor(self) -> bool: """Whether the stored array is a tensor.""" return self.has_torch() and isinstance(self._array, torch.Tensor) + @property + def requires_grad(self) -> bool: + """Whether the stored array is a tensor that needs a gradient.""" + return self.is_tensor and cast(torch.Tensor, self._array).requires_grad + def astype(self, dtype: DTypeLike) -> AbstractArray: """Casts the data type of the array contents.""" if self.is_tensor: @@ -271,10 +276,7 @@ def __setitem__(self, indices: Any, values: AbstractArrayLike) -> None: self._process_indices(indices) ] = values # type: ignore[assignment] except RuntimeError as e: - if ( - self.is_tensor - and cast(torch.Tensor, self._array).requires_grad - ): + if self.requires_grad: raise RuntimeError( "Failed to modify a tensor that requires grad in place." ) from e diff --git a/pulser-core/pulser/register/base_register.py b/pulser-core/pulser/register/base_register.py index 9b30c33b..ef73e43a 100644 --- a/pulser-core/pulser/register/base_register.py +++ b/pulser-core/pulser/register/base_register.py @@ -79,7 +79,6 @@ def __init__( ) self._ids: tuple[QubitId, ...] = tuple(qubits.keys()) if any(not isinstance(id, str) for id in self._ids): - warnings.simplefilter("always") warnings.warn( "Usage of `int`s or any non-`str`types as `QubitId`s will be " "deprecated. Define your `QubitId`s as `str`s, prefer setting " diff --git a/tests/test_channels.py b/tests/test_channels.py index a3e93e24..582deb23 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -292,8 +292,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad): tr, tr, ) - if requires_grad: - assert out_.as_tensor().requires_grad + assert out_.requires_grad == requires_grad wf2 = BlackmanWaveform(800, wf_vals[1]) out_ = channel.modulate(wf2.samples, eom=eom) @@ -302,8 +301,7 @@ def test_modulation(channel, tr, eom, side_buffer_len, requires_grad): side_buffer_len, side_buffer_len, ) - if requires_grad: - assert out_.as_tensor().requires_grad + assert out_.requires_grad == requires_grad @pytest.mark.parametrize( diff --git a/tests/test_eom.py b/tests/test_eom.py index ea63a4b2..e10d508b 100644 --- a/tests/test_eom.py +++ b/tests/test_eom.py @@ -190,8 +190,7 @@ def calc_offset(amp): ] ) assert calculated_det_off == min(det_off_options, key=abs) - if requires_grad: - assert calculated_det_off.as_tensor().requires_grad + assert calculated_det_off.requires_grad == requires_grad # Case where the EOM pulses are off-resonant detuning_on = detuning_on + 1.0 @@ -210,5 +209,4 @@ def calc_offset(amp): assert off_options[0] == eom_.calculate_detuning_off( amp, detuning_on, optimal_detuning_off=0.0 ) - if requires_grad: - assert off_options.as_tensor().requires_grad + assert off_options.requires_grad == requires_grad diff --git a/tests/test_math.py b/tests/test_math.py index 75aa0d50..51b8abb3 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -39,8 +39,7 @@ def test_pad(cast_to, requires_grad): arr = torch.tensor(arr, requires_grad=requires_grad) def check_match(arr1: pm.AbstractArray, arr2): - if requires_grad: - assert arr1.as_tensor().requires_grad + assert arr1.requires_grad == requires_grad np.testing.assert_array_equal( arr1.as_array(detach=requires_grad), arr2 ) @@ -260,8 +259,7 @@ def test_items(self, use_tensor, requires_grad, indices): assert item == val[i] assert isinstance(item, pm.AbstractArray) assert item.is_tensor == use_tensor - if use_tensor: - assert item.as_tensor().requires_grad == requires_grad + assert item.requires_grad == requires_grad # setitem if not requires_grad: @@ -292,8 +290,8 @@ def test_items(self, use_tensor, requires_grad, indices): new_val[indices] = 0.0 assert np.all(arr_np == new_val) assert arr_np.is_tensor - # The resulting tensor requires grad if the assing one did - assert arr_np.as_tensor().requires_grad == requires_grad + # The resulting tensor requires grad if the assigned one did + assert arr_np.requires_grad == requires_grad @pytest.mark.parametrize("scalar", [False, True]) @pytest.mark.parametrize( diff --git a/tests/test_parametrized.py b/tests/test_parametrized.py index 7d0c4ccc..87e55584 100644 --- a/tests/test_parametrized.py +++ b/tests/test_parametrized.py @@ -104,10 +104,7 @@ def test_var_diff(a, b, requires_grad): b._assign(torch.tensor([-1.0, 1.0], requires_grad=requires_grad)) for var in [a, b]: - assert ( - a.value is not None - and a.value.as_tensor().requires_grad == requires_grad - ) + assert a.value is not None and a.value.requires_grad == requires_grad def test_varitem(a, b, d): @@ -167,7 +164,7 @@ def test_paramobj(bwf, t, a, b): def test_opsupport(a, b, with_diff_tensor): def check_var_grad(var): if with_diff_tensor: - assert var.build().as_tensor().requires_grad + assert var.build().requires_grad a._assign(-2.0) if with_diff_tensor: diff --git a/tests/test_pulse.py b/tests/test_pulse.py index fe51866a..e5a26566 100644 --- a/tests/test_pulse.py +++ b/tests/test_pulse.py @@ -234,9 +234,9 @@ def test_eq(): def _assert_pulse_requires_grad(pulse: Pulse, invert: bool = False) -> None: - assert pulse.amplitude.samples.as_tensor().requires_grad == (not invert) - assert pulse.detuning.samples.as_tensor().requires_grad == (not invert) - assert pulse.phase.as_tensor().requires_grad == (not invert) + assert pulse.amplitude.samples.requires_grad == (not invert) + assert pulse.detuning.samples.requires_grad == (not invert) + assert pulse.phase.requires_grad == (not invert) @pytest.mark.parametrize("requires_grad", [True, False]) diff --git a/tests/test_register.py b/tests/test_register.py index 5cbacf0a..c7c387a7 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -508,9 +508,9 @@ def _assert_reg_requires_grad( ) -> None: for coords in reg.qubits.values(): if invert: - assert not coords.as_tensor().requires_grad + assert not coords.requires_grad else: - assert coords.is_tensor and coords.as_tensor().requires_grad + assert coords.is_tensor and coords.requires_grad @pytest.mark.parametrize( diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 87b0f745..a5daffd6 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2886,12 +2886,12 @@ def test_sequence_diff(device, parametrized, with_modulation, with_eom): seq_samples = sample(seq, modulation=with_modulation) ryd_ch_samples = seq_samples.channel_samples["ryd_global"] - assert ryd_ch_samples.amp.as_tensor().requires_grad - assert ryd_ch_samples.det.as_tensor().requires_grad - assert ryd_ch_samples.phase.as_tensor().requires_grad + assert ryd_ch_samples.amp.requires_grad + assert ryd_ch_samples.det.requires_grad + assert ryd_ch_samples.phase.requires_grad if "dmm_0" in seq_samples.channel_samples: dmm_ch_samples = seq_samples.channel_samples["dmm_0"] # Only detuning is modulated - assert not dmm_ch_samples.amp.as_tensor().requires_grad - assert dmm_ch_samples.det.as_tensor().requires_grad - assert not dmm_ch_samples.phase.as_tensor().requires_grad + assert not dmm_ch_samples.amp.requires_grad + assert dmm_ch_samples.det.requires_grad + assert not dmm_ch_samples.phase.requires_grad diff --git a/tests/test_sequence_sampler.py b/tests/test_sequence_sampler.py index fb539ca6..cc7b67f0 100644 --- a/tests/test_sequence_sampler.py +++ b/tests/test_sequence_sampler.py @@ -523,11 +523,11 @@ def test_phase_modulation(off_center, with_diff): seq_samples = sample(seq).channel_samples["rydberg_global"] if with_diff: - assert full_phase.samples.as_tensor().requires_grad - assert not seq_samples.amp.as_tensor().requires_grad - assert seq_samples.det.as_tensor().requires_grad - assert seq_samples.phase.as_tensor().requires_grad - assert seq_samples.phase_modulation.as_tensor().requires_grad + assert full_phase.samples.requires_grad + assert not seq_samples.amp.requires_grad + assert seq_samples.det.requires_grad + assert seq_samples.phase.requires_grad + assert seq_samples.phase_modulation.requires_grad np.testing.assert_allclose( seq_samples.phase_modulation.as_array(detach=with_diff) diff --git a/tests/test_waveforms.py b/tests/test_waveforms.py index 59648cfb..5d46de56 100644 --- a/tests/test_waveforms.py +++ b/tests/test_waveforms.py @@ -490,10 +490,7 @@ def test_waveform_diff( samples_tensor = wf.samples.as_tensor() assert samples_tensor.requires_grad == requires_grad - assert ( - wf.modulated_samples(rydberg_global).as_tensor().requires_grad - == requires_grad - ) + assert wf.modulated_samples(rydberg_global).requires_grad == requires_grad wfx2_tensor = (-wf * 2).samples.as_tensor() assert torch.equal(wfx2_tensor, samples_tensor * -2.0) assert wfx2_tensor.requires_grad == requires_grad @@ -501,15 +498,12 @@ def test_waveform_diff( wfdiv2 = wf / torch.tensor(2.0, requires_grad=True) assert torch.equal(wfdiv2.samples.as_tensor(), samples_tensor / 2.0) # Should always be true because it was divided by diff tensor - assert wfdiv2.samples.as_tensor().requires_grad + assert wfdiv2.samples.requires_grad - assert wf[-1].as_tensor().requires_grad == requires_grad + assert wf[-1].requires_grad == requires_grad try: - assert ( - wf.change_duration(1000).samples.as_tensor().requires_grad - == requires_grad - ) + assert wf.change_duration(1000).samples.requires_grad == requires_grad except NotImplementedError: pass