Skip to content

Commit

Permalink
Fix expval of Sum with broadcasting (#4275)
Browse files Browse the repository at this point in the history
* fix bug and add test

* changelog addition
  • Loading branch information
dwierichs authored and mudit2812 committed Jun 21, 2023
1 parent 8b32119 commit 99c3409
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-0.31.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@
* Allow for `Sum` observables with trainable parameters.
[(#4251)](https://github.com/PennyLaneAI/pennylane/pull/4251)
[(#4275)](https://github.com/PennyLaneAI/pennylane/pull/4275)
<h3>Contributors ✍️</h3>
Expand Down
9 changes: 6 additions & 3 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,13 @@ def expval(self, observable, shot_range=None, bin_size=None):
Hamiltonian is not NumPy or Autograd
"""
is_state_batched = self._ndim(self.state) == 2
# intercept Sums
if isinstance(observable, Sum) and not self.shots:
return measure(
ExpectationMP(observable.map_wires(self.wire_map)), self._pre_rotated_state
ExpectationMP(observable.map_wires(self.wire_map)),
self._pre_rotated_state,
is_state_batched,
)

# intercept other Hamiltonians
Expand All @@ -592,7 +595,7 @@ def expval(self, observable, shot_range=None, bin_size=None):

if backprop_mode:
# TODO[dwierichs]: This branch is not adapted to broadcasting yet
if self._ndim(self.state) == 2:
if is_state_batched:
raise NotImplementedError(
"Expectation values of Hamiltonians for interface!=None are "
"not supported together with parameter broadcasting yet"
Expand Down Expand Up @@ -632,7 +635,7 @@ def expval(self, observable, shot_range=None, bin_size=None):
Hmat = observable.sparse_matrix(wire_order=self.wires)

state = qml.math.toarray(self.state)
if self._ndim(state) == 2:
if is_state_batched:
res = qml.math.array(
[
csr_matrix.dot(
Expand Down
45 changes: 27 additions & 18 deletions tests/devices/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=protected-access,cell-var-from-loop
import math

from functools import partial
import pytest

import pennylane as qml
Expand Down Expand Up @@ -2361,22 +2362,28 @@ def test_Hamiltonian_filtered_from_rotations(self, mocker):
assert qml.equal(call_args.measurements[0], qml.expval(qml.PauliX(0)))


@pytest.mark.parametrize("is_state_batched", [False, True])
class TestSumSupport:
"""Tests for custom Sum support in DefaultQubit."""

expected_grad = [-np.sin(1.3), np.cos(1.3)]
@staticmethod
def expected_grad(is_state_batched):
if is_state_batched:
return [[-np.sin(1.3), -np.sin(0.4)], [np.cos(1.3), np.cos(0.4)]]
return [-np.sin(1.3), np.cos(1.3)]

@staticmethod
def circuit(y, z):
qml.RX(1.3, 0)
def circuit(y, z, is_state_batched):
rx_param = [1.3, 0.4] if is_state_batched else 1.3
qml.RX(rx_param, 0)
return qml.expval(
qml.sum(
qml.s_prod(y, qml.PauliY(0)),
qml.s_prod(z, qml.PauliZ(0)),
)
)

def test_super_expval_not_called(self, mocker):
def test_super_expval_not_called(self, is_state_batched, mocker):
"""Tests basic expval result, and ensures QubitDevice.expval is not called."""
dev = qml.device("default.qubit", wires=1)
spy = mocker.spy(qml.QubitDevice, "expval")
Expand All @@ -2385,49 +2392,51 @@ def test_super_expval_not_called(self, mocker):
spy.assert_not_called()

@pytest.mark.autograd
def test_trainable_autograd(self):
def test_trainable_autograd(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with autograd."""
if is_state_batched:
pytest.skip(msg="Broadcasting, qml.jacobian and new return types do not work together")
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="autograd")
y, z = np.array([1.1, 2.2])
actual = qml.grad(qnode)(y, z)
assert np.allclose(actual, self.expected_grad)
actual = qml.grad(qnode, argnum=[0, 1])(y, z, is_state_batched)
assert np.allclose(actual, self.expected_grad(is_state_batched))

@pytest.mark.torch
def test_trainable_torch(self):
def test_trainable_torch(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with torch."""
import torch

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="torch")
y, z = torch.tensor(1.1, requires_grad=True), torch.tensor(2.2, requires_grad=True)
qnode(y, z).backward()
actual = [y.grad, z.grad]
assert np.allclose(actual, self.expected_grad)
_qnode = partial(qnode, is_state_batched=is_state_batched)
actual = torch.stack(torch.autograd.functional.jacobian(_qnode, (y, z)))
assert np.allclose(actual, self.expected_grad(is_state_batched))

@pytest.mark.tf
def test_trainable_tf(self):
def test_trainable_tf(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with tf."""
import tensorflow as tf

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="tensorflow")
y, z = tf.Variable(1.1, dtype=tf.float64), tf.Variable(2.2, dtype=tf.float64)
with tf.GradientTape() as tape:
res = qnode(y, z)
actual = tape.gradient(res, [y, z])
assert np.allclose(actual, self.expected_grad)
res = qnode(y, z, is_state_batched)
actual = tape.jacobian(res, [y, z])
assert np.allclose(actual, self.expected_grad(is_state_batched))

@pytest.mark.jax
def test_trainable_jax(self):
def test_trainable_jax(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with jax."""
import jax

dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="jax")
y, z = jax.numpy.array([1.1, 2.2])
actual = jax.grad(qnode, argnums=[0, 1])(y, z)
assert np.allclose(actual, self.expected_grad)
actual = jax.jacobian(qnode, argnums=[0, 1])(y, z, is_state_batched)
assert np.allclose(actual, self.expected_grad(is_state_batched))


class TestGetBatchSize:
Expand Down

0 comments on commit 99c3409

Please sign in to comment.