Skip to content

Commit

Permalink
Use pauli_word_prefactor in stoch_pulse_grad (#4156)
Browse files Browse the repository at this point in the history
* bugfix, test, changelog

* comment -> dev comment

* warnings business
  • Loading branch information
dwierichs authored Jun 5, 2023
1 parent cf70b7a commit 6f3b080
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@

<h3>Bug fixes 🐛</h3>

* Fixes a bug where `stoch_pulse_grad` would ignore prefactors of rescaled Pauli words in the
generating terms of a pulse Hamiltonian.
[(4156)](https://github.com/PennyLaneAI/pennylane/pull/4156)

* Fixes a bug where the wire ordering of the `wires` argument to `qml.density_matrix`
was not taken into account.
[(#4072)](https://github.com/PennyLaneAI/pennylane/pull/4072)
Expand Down
13 changes: 11 additions & 2 deletions pennylane/gradients/pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains functions for computing the stochastic parameter-shift gradient
of pulse sequences in a qubit-based quantum tape.
"""
import warnings
import numpy as np

import pennylane as qml
Expand Down Expand Up @@ -90,12 +91,17 @@ def _split_evol_ops(op, ob, tau):
after_t = jax.numpy.array([tau, t1])

if qml.pauli.is_pauli_word(ob):
prefactor = next(iter(qml.pauli.pauli_sentence(ob).values()))
prefactor = qml.pauli.pauli_word_prefactor(ob)
word = qml.pauli.pauli_word_to_string(ob)
insert_ops = [qml.PauliRot(shift, word, ob.wires) for shift in [np.pi / 2, -np.pi / 2]]
coeffs = [prefactor, -prefactor]
else:
eigvals = qml.eigvals(ob)
with warnings.catch_warnings():
if len(ob.wires) <= 4:
warnings.filterwarnings(
"ignore", ".*the eigenvalues will be computed numerically.*"
)
eigvals = qml.eigvals(ob)
coeffs, shifts = zip(*generate_shift_rule(eigvals_to_frequencies(tuple(eigvals))))
insert_ops = [qml.exp(qml.dot([-1j * shift], [ob])) for shift in shifts]

Expand Down Expand Up @@ -568,8 +574,11 @@ def _generate_tapes_and_cjacs(tape, idx, key, num_split_times, use_broadcasting)
"""Generate the tapes and compute the classical Jacobians for one given
generating Hamiltonian term of one pulse.
"""
# Obtain the operation into which the indicated parameter feeds, its position in the tape,
# and the index of the parameter within the operation
op, op_idx, term_idx = tape.get_operation(idx)
if not isinstance(op, ParametrizedEvolution):
# Only pulses are supported
raise ValueError(
"stoch_pulse_grad does not support differentiating parameters of "
"other operations than pulses."
Expand Down
28 changes: 25 additions & 3 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Tests for the gradients.pulse_gradient module.
"""

import warnings
import copy
import pytest
import numpy as np
Expand Down Expand Up @@ -179,6 +180,27 @@ def test_with_general_ob(self, ham, params, time, ob):
# Check that the inserted exponential is correct
assert qml.equal(qml.exp(qml.dot([-1j * exp_shift], [ob])), _ops[1])

def test_warnings(self):
"""Test that a warning is raised for computing eigenvalues of a Hamiltonian
for more than four wires but not for fewer wires."""
import jax

jax.config.update("jax_enable_x64", True)
ham = qml.pulse.constant * qml.PauliY(0)
op = qml.evolve(ham)([0.3], 2.0)
ob = qml.Hamiltonian(
[0.4, 0.2], [qml.operation.Tensor(*[qml.PauliY(i) for i in range(5)]), qml.PauliX(0)]
)
with pytest.warns(UserWarning, match="the eigenvalues will be computed numerically"):
_split_evol_ops(op, ob, tau=0.4)

ob = qml.Hamiltonian(
[0.4, 0.2], [qml.operation.Tensor(*[qml.PauliY(i) for i in range(4)]), qml.PauliX(0)]
)
with warnings.catch_warnings():
warnings.simplefilter("error")
_split_evol_ops(op, ob, tau=0.4)


@pytest.mark.jax
class TestSplitEvolTapes:
Expand Down Expand Up @@ -396,7 +418,7 @@ def test_some_zero_grads(self):

@pytest.mark.parametrize("num_split_times", [1, 3])
@pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)])
def test_constant_pauliword(self, num_split_times, t):
def test_constant_ry(self, num_split_times, t):
"""Test that the derivative of a pulse generated by a constant Hamiltonian,
which is a Pauli word, is computed correctly."""
import jax
Expand All @@ -422,7 +444,7 @@ def test_constant_pauliword(self, num_split_times, t):

@pytest.mark.parametrize("num_split_times", [1, 3])
@pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)])
def test_constant_paulisentence(self, num_split_times, t):
def test_constant_ry_rescaled(self, num_split_times, t):
"""Test that the derivative of a pulse generated by a constant Hamiltonian,
which is a Pauli sentence, is computed correctly."""
import jax
Expand Down Expand Up @@ -451,7 +473,7 @@ def test_constant_paulisentence(self, num_split_times, t):
assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * delta_t * prefactor)

@pytest.mark.parametrize("t", [0.02, (0.5, 0.6)])
def test_sin_envelope_rx_expval(self, t):
def test_sin_envelope_rz_expval(self, t):
"""Test that the derivative of a pulse with a sine wave envelope
is computed correctly when returning an expectation value."""
import jax.numpy as jnp
Expand Down

0 comments on commit 6f3b080

Please sign in to comment.