From 6f3b080f8f9b18ff9198a435635654a39318ddbc Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Mon, 5 Jun 2023 14:21:21 +0200 Subject: [PATCH] Use `pauli_word_prefactor` in `stoch_pulse_grad` (#4156) * bugfix, test, changelog * comment -> dev comment * warnings business --- doc/releases/changelog-dev.md | 4 +++ pennylane/gradients/pulse_gradient.py | 13 ++++++++-- tests/gradients/core/test_pulse_gradient.py | 28 ++++++++++++++++++--- 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9d0cd23edc6..557ea542822 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -217,6 +217,10 @@

Bug fixes 🐛

+* Fixes a bug where `stoch_pulse_grad` would ignore prefactors of rescaled Pauli words in the + generating terms of a pulse Hamiltonian. + [(4156)](https://github.com/PennyLaneAI/pennylane/pull/4156) + * Fixes a bug where the wire ordering of the `wires` argument to `qml.density_matrix` was not taken into account. [(#4072)](https://github.com/PennyLaneAI/pennylane/pull/4072) diff --git a/pennylane/gradients/pulse_gradient.py b/pennylane/gradients/pulse_gradient.py index 5df80e41852..f087a9e5c2c 100644 --- a/pennylane/gradients/pulse_gradient.py +++ b/pennylane/gradients/pulse_gradient.py @@ -15,6 +15,7 @@ This module contains functions for computing the stochastic parameter-shift gradient of pulse sequences in a qubit-based quantum tape. """ +import warnings import numpy as np import pennylane as qml @@ -90,12 +91,17 @@ def _split_evol_ops(op, ob, tau): after_t = jax.numpy.array([tau, t1]) if qml.pauli.is_pauli_word(ob): - prefactor = next(iter(qml.pauli.pauli_sentence(ob).values())) + prefactor = qml.pauli.pauli_word_prefactor(ob) word = qml.pauli.pauli_word_to_string(ob) insert_ops = [qml.PauliRot(shift, word, ob.wires) for shift in [np.pi / 2, -np.pi / 2]] coeffs = [prefactor, -prefactor] else: - eigvals = qml.eigvals(ob) + with warnings.catch_warnings(): + if len(ob.wires) <= 4: + warnings.filterwarnings( + "ignore", ".*the eigenvalues will be computed numerically.*" + ) + eigvals = qml.eigvals(ob) coeffs, shifts = zip(*generate_shift_rule(eigvals_to_frequencies(tuple(eigvals)))) insert_ops = [qml.exp(qml.dot([-1j * shift], [ob])) for shift in shifts] @@ -568,8 +574,11 @@ def _generate_tapes_and_cjacs(tape, idx, key, num_split_times, use_broadcasting) """Generate the tapes and compute the classical Jacobians for one given generating Hamiltonian term of one pulse. """ + # Obtain the operation into which the indicated parameter feeds, its position in the tape, + # and the index of the parameter within the operation op, op_idx, term_idx = tape.get_operation(idx) if not isinstance(op, ParametrizedEvolution): + # Only pulses are supported raise ValueError( "stoch_pulse_grad does not support differentiating parameters of " "other operations than pulses." diff --git a/tests/gradients/core/test_pulse_gradient.py b/tests/gradients/core/test_pulse_gradient.py index 013d62d6c37..e9a89bac0cc 100644 --- a/tests/gradients/core/test_pulse_gradient.py +++ b/tests/gradients/core/test_pulse_gradient.py @@ -15,6 +15,7 @@ Tests for the gradients.pulse_gradient module. """ +import warnings import copy import pytest import numpy as np @@ -179,6 +180,27 @@ def test_with_general_ob(self, ham, params, time, ob): # Check that the inserted exponential is correct assert qml.equal(qml.exp(qml.dot([-1j * exp_shift], [ob])), _ops[1]) + def test_warnings(self): + """Test that a warning is raised for computing eigenvalues of a Hamiltonian + for more than four wires but not for fewer wires.""" + import jax + + jax.config.update("jax_enable_x64", True) + ham = qml.pulse.constant * qml.PauliY(0) + op = qml.evolve(ham)([0.3], 2.0) + ob = qml.Hamiltonian( + [0.4, 0.2], [qml.operation.Tensor(*[qml.PauliY(i) for i in range(5)]), qml.PauliX(0)] + ) + with pytest.warns(UserWarning, match="the eigenvalues will be computed numerically"): + _split_evol_ops(op, ob, tau=0.4) + + ob = qml.Hamiltonian( + [0.4, 0.2], [qml.operation.Tensor(*[qml.PauliY(i) for i in range(4)]), qml.PauliX(0)] + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + _split_evol_ops(op, ob, tau=0.4) + @pytest.mark.jax class TestSplitEvolTapes: @@ -396,7 +418,7 @@ def test_some_zero_grads(self): @pytest.mark.parametrize("num_split_times", [1, 3]) @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)]) - def test_constant_pauliword(self, num_split_times, t): + def test_constant_ry(self, num_split_times, t): """Test that the derivative of a pulse generated by a constant Hamiltonian, which is a Pauli word, is computed correctly.""" import jax @@ -422,7 +444,7 @@ def test_constant_pauliword(self, num_split_times, t): @pytest.mark.parametrize("num_split_times", [1, 3]) @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)]) - def test_constant_paulisentence(self, num_split_times, t): + def test_constant_ry_rescaled(self, num_split_times, t): """Test that the derivative of a pulse generated by a constant Hamiltonian, which is a Pauli sentence, is computed correctly.""" import jax @@ -451,7 +473,7 @@ def test_constant_paulisentence(self, num_split_times, t): assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * delta_t * prefactor) @pytest.mark.parametrize("t", [0.02, (0.5, 0.6)]) - def test_sin_envelope_rx_expval(self, t): + def test_sin_envelope_rz_expval(self, t): """Test that the derivative of a pulse with a sine wave envelope is computed correctly when returning an expectation value.""" import jax.numpy as jnp