From 52ccdba73a5b02103b05d7827e7e09af18e3e961 Mon Sep 17 00:00:00 2001
From: Edward Jiang <34989448+eddddddy@users.noreply.github.com>
Date: Wed, 2 Aug 2023 11:22:15 -0400
Subject: [PATCH] Fix `split_non_commuting` when tape contains both `expval`
and `var` measurements (#4426)
* fix split non commuting
* changelog
* pylint
---
doc/releases/changelog-dev.md | 4 +++
pennylane/transforms/split_non_commuting.py | 4 +--
tests/transforms/test_split_non_commuting.py | 30 ++++++++++++++++++++
3 files changed, 36 insertions(+), 2 deletions(-)
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index b6c600dabd4..6b9301ca6b7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -226,6 +226,10 @@
not really have a decomposition.
[(#4407)](https://github.com/PennyLaneAI/pennylane/pull/4407)
+* `qml.transforms.split_non_commuting` now correctly works on tapes containing both `expval`
+ and `var` measurements.
+ [(#4426)](https://github.com/PennyLaneAI/pennylane/pull/4426)
+
Contributors ✍️
This release contains contributions from (in alphabetical order):
diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py
index 463fd90a073..1e05efdde28 100644
--- a/pennylane/transforms/split_non_commuting.py
+++ b/pennylane/transforms/split_non_commuting.py
@@ -164,10 +164,10 @@ def circuit0(x):
if len(groups) > 1:
# make one tape per commuting group
tapes = []
- for group in groups:
+ for group, indices in zip(groups, group_coeffs):
new_tape = tape.__class__(
tape._ops,
- (m.__class__(obs=o) for m, o in zip(tape.measurements, group)),
+ (tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)),
tape._prep,
)
diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py
index fec09baa760..3cbc08fc197 100644
--- a/tests/transforms/test_split_non_commuting.py
+++ b/tests/transforms/test_split_non_commuting.py
@@ -127,6 +127,36 @@ def test_different_measurement_types(self, meas_type):
for meas in new_tape.measurements:
assert meas.return_type == the_return_type
+ def test_mixed_measurement_types(self):
+ """Test that mixing expval and var works correctly."""
+
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.Hadamard(0)
+ qml.Hadamard(1)
+ qml.expval(qml.PauliX(0))
+ qml.expval(qml.PauliZ(1))
+ qml.var(qml.PauliZ(0))
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ split, _ = split_non_commuting(tape)
+
+ assert len(split) == 2
+
+ with qml.queuing.AnnotatedQueue() as q:
+ qml.Hadamard(0)
+ qml.Hadamard(1)
+ qml.expval(qml.PauliX(0))
+ qml.var(qml.PauliZ(0))
+ qml.expval(qml.PauliZ(1))
+
+ tape = qml.tape.QuantumScript.from_queue(q)
+ split, _ = split_non_commuting(tape)
+
+ assert len(split) == 2
+ assert qml.equal(split[0].measurements[0], qml.expval(qml.PauliX(0)))
+ assert qml.equal(split[0].measurements[1], qml.expval(qml.PauliZ(1)))
+ assert qml.equal(split[1].measurements[0], qml.var(qml.PauliZ(0)))
+
def test_raise_not_supported(self):
"""Test that NotImplementedError is raised when probabilities or samples are called"""
with qml.queuing.AnnotatedQueue() as q: