diff --git a/doc/releases/changelog-0.31.0.md b/doc/releases/changelog-0.31.0.md
index f00da904469..4213ac92337 100644
--- a/doc/releases/changelog-0.31.0.md
+++ b/doc/releases/changelog-0.31.0.md
@@ -447,6 +447,7 @@
* Allow for `Sum` observables with trainable parameters.
[(#4251)](https://github.com/PennyLaneAI/pennylane/pull/4251)
+ [(#4275)](https://github.com/PennyLaneAI/pennylane/pull/4275)
Contributors ✍️
diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py
index 35a0a5da1c8..ed33195fbae 100644
--- a/pennylane/devices/default_qubit.py
+++ b/pennylane/devices/default_qubit.py
@@ -570,10 +570,13 @@ def expval(self, observable, shot_range=None, bin_size=None):
Hamiltonian is not NumPy or Autograd
"""
+ is_state_batched = self._ndim(self.state) == 2
# intercept Sums
if isinstance(observable, Sum) and not self.shots:
return measure(
- ExpectationMP(observable.map_wires(self.wire_map)), self._pre_rotated_state
+ ExpectationMP(observable.map_wires(self.wire_map)),
+ self._pre_rotated_state,
+ is_state_batched,
)
# intercept other Hamiltonians
@@ -592,7 +595,7 @@ def expval(self, observable, shot_range=None, bin_size=None):
if backprop_mode:
# TODO[dwierichs]: This branch is not adapted to broadcasting yet
- if self._ndim(self.state) == 2:
+ if is_state_batched:
raise NotImplementedError(
"Expectation values of Hamiltonians for interface!=None are "
"not supported together with parameter broadcasting yet"
@@ -632,7 +635,7 @@ def expval(self, observable, shot_range=None, bin_size=None):
Hmat = observable.sparse_matrix(wire_order=self.wires)
state = qml.math.toarray(self.state)
- if self._ndim(state) == 2:
+ if is_state_batched:
res = qml.math.array(
[
csr_matrix.dot(
diff --git a/tests/devices/test_default_qubit.py b/tests/devices/test_default_qubit.py
index 85a2ef86781..52788c14ba6 100644
--- a/tests/devices/test_default_qubit.py
+++ b/tests/devices/test_default_qubit.py
@@ -19,6 +19,7 @@
# pylint: disable=protected-access,cell-var-from-loop
import math
+from functools import partial
import pytest
import pennylane as qml
@@ -2361,14 +2362,20 @@ def test_Hamiltonian_filtered_from_rotations(self, mocker):
assert qml.equal(call_args.measurements[0], qml.expval(qml.PauliX(0)))
+@pytest.mark.parametrize("is_state_batched", [False, True])
class TestSumSupport:
"""Tests for custom Sum support in DefaultQubit."""
- expected_grad = [-np.sin(1.3), np.cos(1.3)]
+ @staticmethod
+ def expected_grad(is_state_batched):
+ if is_state_batched:
+ return [[-np.sin(1.3), -np.sin(0.4)], [np.cos(1.3), np.cos(0.4)]]
+ return [-np.sin(1.3), np.cos(1.3)]
@staticmethod
- def circuit(y, z):
- qml.RX(1.3, 0)
+ def circuit(y, z, is_state_batched):
+ rx_param = [1.3, 0.4] if is_state_batched else 1.3
+ qml.RX(rx_param, 0)
return qml.expval(
qml.sum(
qml.s_prod(y, qml.PauliY(0)),
@@ -2376,7 +2383,7 @@ def circuit(y, z):
)
)
- def test_super_expval_not_called(self, mocker):
+ def test_super_expval_not_called(self, is_state_batched, mocker):
"""Tests basic expval result, and ensures QubitDevice.expval is not called."""
dev = qml.device("default.qubit", wires=1)
spy = mocker.spy(qml.QubitDevice, "expval")
@@ -2385,28 +2392,30 @@ def test_super_expval_not_called(self, mocker):
spy.assert_not_called()
@pytest.mark.autograd
- def test_trainable_autograd(self):
+ def test_trainable_autograd(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with autograd."""
+ if is_state_batched:
+ pytest.skip(msg="Broadcasting, qml.jacobian and new return types do not work together")
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="autograd")
y, z = np.array([1.1, 2.2])
- actual = qml.grad(qnode)(y, z)
- assert np.allclose(actual, self.expected_grad)
+ actual = qml.grad(qnode, argnum=[0, 1])(y, z, is_state_batched)
+ assert np.allclose(actual, self.expected_grad(is_state_batched))
@pytest.mark.torch
- def test_trainable_torch(self):
+ def test_trainable_torch(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with torch."""
import torch
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="torch")
y, z = torch.tensor(1.1, requires_grad=True), torch.tensor(2.2, requires_grad=True)
- qnode(y, z).backward()
- actual = [y.grad, z.grad]
- assert np.allclose(actual, self.expected_grad)
+ _qnode = partial(qnode, is_state_batched=is_state_batched)
+ actual = torch.stack(torch.autograd.functional.jacobian(_qnode, (y, z)))
+ assert np.allclose(actual, self.expected_grad(is_state_batched))
@pytest.mark.tf
- def test_trainable_tf(self):
+ def test_trainable_tf(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with tf."""
import tensorflow as tf
@@ -2414,20 +2423,20 @@ def test_trainable_tf(self):
qnode = qml.QNode(self.circuit, dev, interface="tensorflow")
y, z = tf.Variable(1.1, dtype=tf.float64), tf.Variable(2.2, dtype=tf.float64)
with tf.GradientTape() as tape:
- res = qnode(y, z)
- actual = tape.gradient(res, [y, z])
- assert np.allclose(actual, self.expected_grad)
+ res = qnode(y, z, is_state_batched)
+ actual = tape.jacobian(res, [y, z])
+ assert np.allclose(actual, self.expected_grad(is_state_batched))
@pytest.mark.jax
- def test_trainable_jax(self):
+ def test_trainable_jax(self, is_state_batched):
"""Tests that coeffs passed to a sum are trainable with jax."""
import jax
dev = qml.device("default.qubit", wires=1)
qnode = qml.QNode(self.circuit, dev, interface="jax")
y, z = jax.numpy.array([1.1, 2.2])
- actual = jax.grad(qnode, argnums=[0, 1])(y, z)
- assert np.allclose(actual, self.expected_grad)
+ actual = jax.jacobian(qnode, argnums=[0, 1])(y, z, is_state_batched)
+ assert np.allclose(actual, self.expected_grad(is_state_batched))
class TestGetBatchSize: