Skip to content

Commit

Permalink
Raise a warning when applying a pulse gradient transform to a QNode d…
Browse files Browse the repository at this point in the history
…irectly (#4241)

* introduce warning and tests

* changelog, recommendation

* docstrings

* fix test

* switch to raising an error instead

* Apply suggestions from code review

Co-authored-by: Tom Bromley <[email protected]>

* fix tests

---------

Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Tom Bromley <[email protected]>
  • Loading branch information
3 people authored Jun 16, 2023
1 parent 241045e commit 3bf2964
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 11 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@

<h3>Improvements 🛠</h3>

* The pulse differentiation methods, `pulse_generator` and `stoch_pulse_grad` now raise an error when they
are applied to a `QNode` directly. Instead, use differentiation via a JAX entry point (`jax.grad`, `jax.jacobian`, ...).
[(4241)](https://github.com/PennyLaneAI/pennylane/pull/4241)

* `pulse.ParametrizedEvolution` now raises an error if the number of input parameters does not match the number
of parametrized coefficients in the `ParametrizedHamiltonian` that generates it. An exception is made for
`HardwareHamiltonian`s which are not checked.
Expand Down
22 changes: 20 additions & 2 deletions pennylane/gradients/pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pennylane.measurements import Shots

from .parameter_shift import _make_zero_rep
from .pulse_gradient import _assert_has_jax
from .pulse_gradient import _assert_has_jax, raise_pulse_diff_on_qnode
from .gradient_transform import (
_all_zero_grad,
assert_active_return,
Expand Down Expand Up @@ -455,7 +455,14 @@ def _pulse_generator(tape, argnum=None, shots=None, atol=1e-7):
This function requires the JAX interface and does not work with other autodiff interfaces
commonly encountered with PennyLane.
In addition, this transform is only JIT-compatible with pulses that only have scalar parameters.
In addition, this transform is only JIT-compatible with pulses that only have scalar
parameters.
.. warning::
This transform may not be applied directly to QNodes. Use JAX entrypoints
(``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape
level. Also see the examples below.
**Example**
Expand Down Expand Up @@ -702,3 +709,14 @@ def expand_invalid_trainable_pulse_generator(x, *args, **kwargs):
pulse_generator = gradient_transform(
_pulse_generator, expand_fn=expand_invalid_trainable_pulse_generator
)


@pulse_generator.custom_qnode_wrapper
def pulse_generator_qnode_wrapper(self, qnode, targs, tkwargs):
"""A custom QNode wrapper for the gradient transform :func:`~.pulse_generator`.
It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly
is not supported.
"""
# pylint:disable=unused-argument
transform_name = "pulse generator parameter-shift"
raise_pulse_diff_on_qnode(transform_name)
31 changes: 28 additions & 3 deletions pennylane/gradients/pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ def _assert_has_jax(transform_name):
)


def raise_pulse_diff_on_qnode(transform_name):
"""Raises an error as the gradient transform with the provided name does
not support direct application to QNodes.
"""
msg = (
f"Applying the {transform_name} gradient transform to a QNode directly is currently "
"not supported. Please use differentiation via a JAX entry point "
"(jax.grad, jax.jacobian, ...) instead.",
UserWarning,
)
raise NotImplementedError(msg)


def _split_evol_ops(op, ob, tau):
r"""Randomly split a ``ParametrizedEvolution`` with respect to time into two operations and
insert a Pauli rotation using a given Pauli word and rotation angles :math:`\pm\pi/2`.
Expand Down Expand Up @@ -313,10 +326,11 @@ def _stoch_pulse_grad(
rules when used with simple pulses (see details and examples below), potentially leading
to imprecise results and/or unnecessarily large computational efforts.
.. note::
.. warning::
Currently this function only supports pulses for which each *parametrized* term is a
simple Pauli word. More general Hamiltonian terms are not supported yet.
This transform may not be applied directly to QNodes. Use JAX entrypoints
(``jax.grad``, ``jax.jacobian``, ...) instead or apply the transform on the tape level.
Also see the examples below.
**Examples**
Expand Down Expand Up @@ -682,3 +696,14 @@ def expand_invalid_trainable_stoch_pulse_grad(x, *args, **kwargs):
stoch_pulse_grad = gradient_transform(
_stoch_pulse_grad, expand_fn=expand_invalid_trainable_stoch_pulse_grad
)


@stoch_pulse_grad.custom_qnode_wrapper
def stoch_pulse_grad_qnode_wrapper(self, qnode, targs, tkwargs):
"""A custom QNode wrapper for the gradient transform :func:`~.stoch_pulse_grad`.
It raises an error, so that applying ``pulse_generator`` to a ``QNode`` directly
is not supported.
"""
# pylint:disable=unused-argument
transform_name = "stochastic pulse parameter-shift"
raise_pulse_diff_on_qnode(transform_name)
24 changes: 20 additions & 4 deletions tests/gradients/core/test_pulse_generator_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,24 @@ def circuit(par):
class TestPulseGeneratorQNode:
"""Test that pulse_generator integrates correctly with QNodes."""

def test_raises_for_application_to_qnodes(self):
"""Test that an error is raised when applying ``stoch_pulse_grad``
to a QNode directly."""

dev = qml.device("default.qubit.jax", wires=1)
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(dev, interface="jax")
def circuit(params):
qml.evolve(ham_single_q_const)([params], 0.2)
return qml.expval(qml.PauliZ(0))

_match = "pulse generator parameter-shift gradient transform to a QNode directly"
with pytest.raises(NotImplementedError, match=_match):
pulse_generator(circuit)

# TODO: include the following tests when #4225 is resolved.
@pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_expval_single_par(self):
"""Test that a simple qnode that returns an expectation value
can be differentiated with pulse_generator."""
Expand All @@ -1146,8 +1164,7 @@ def circuit(params):
assert jnp.allclose(grad, exp_grad)
assert tracker.totals["executions"] == 2 # two shifted tapes

# Applying QNode-level gradient transforms with non-scalar parameters is not supported yet
@pytest.mark.xfail
@pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_expval_probs_single_par(self):
"""Test that a simple qnode that returns an expectation value
can be differentiated with pulse_generator."""
Expand Down Expand Up @@ -1179,8 +1196,7 @@ def circuit(params):
for j, e in zip(jac, exp_jac):
assert qml.math.allclose(j, e)

# Applying QNode-level gradient transforms with non-scalar parameters is not supported yet
@pytest.mark.xfail
@pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_probs_expval_multi_par(self):
"""Test that a simple qnode that returns probabilities
can be differentiated with pulse_generator."""
Expand Down
52 changes: 50 additions & 2 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,56 @@ def test_shots_attribute(self, shots):


@pytest.mark.jax
class TestStochPulseGradQNodeIntegration:
"""Test that stoch_pulse_grad integrates correctly with QNodes."""
class TestStochPulseGradQNode:
"""Test that pulse_generator integrates correctly with QNodes."""

def test_raises_for_application_to_qnodes(self):
"""Test that an error is raised when applying ``stoch_pulse_grad``
to a QNode directly."""
dev = qml.device("default.qubit.jax", wires=1)
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(dev, interface="jax")
def circuit(params):
qml.evolve(ham_single_q_const)([params], 0.2)
return qml.expval(qml.PauliZ(0))

_match = "stochastic pulse parameter-shift gradient transform to a QNode directly"
with pytest.raises(NotImplementedError, match=_match):
stoch_pulse_grad(circuit, num_split_times=2)

# TODO: include the following tests when #4225 is resolved.
@pytest.mark.skip("Applying this gradient transform to QNodes directly is not supported.")
def test_qnode_expval_single_par(self):
"""Test that a simple qnode that returns an expectation value
can be differentiated with pulse_generator."""
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
dev = qml.device("default.qubit.jax", wires=1)
T = 0.2
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(dev, interface="jax")
def circuit(params):
qml.evolve(ham_single_q_const)([params], T)
return qml.expval(qml.PauliZ(0))

params = jnp.array(0.4)
with qml.Tracker(dev) as tracker:
_match = "stochastic pulse parameter-shift .* scalar pulse parameters."
grad = stoch_pulse_grad(circuit, num_split_times=2)(params)

p = params * T
exp_grad = -2 * jnp.sin(2 * p) * T
assert jnp.allclose(grad, exp_grad)
assert tracker.totals["executions"] == 4 # two shifted tapes, two splitting times


@pytest.mark.jax
class TestStochPulseGradIntegration:
"""Test that stoch_pulse_grad integrates correctly with QNodes and ML interfaces."""

@pytest.mark.parametrize("shots, tol", [(None, 1e-4), (100, 0.1), ([100, 99], 0.1)])
@pytest.mark.parametrize("num_split_times", [1, 2])
Expand Down

0 comments on commit 3bf2964

Please sign in to comment.