From 52ccdba73a5b02103b05d7827e7e09af18e3e961 Mon Sep 17 00:00:00 2001 From: Edward Jiang <34989448+eddddddy@users.noreply.github.com> Date: Wed, 2 Aug 2023 11:22:15 -0400 Subject: [PATCH] Fix `split_non_commuting` when tape contains both `expval` and `var` measurements (#4426) * fix split non commuting * changelog * pylint --- doc/releases/changelog-dev.md | 4 +++ pennylane/transforms/split_non_commuting.py | 4 +-- tests/transforms/test_split_non_commuting.py | 30 ++++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b6c600dabd4..6b9301ca6b7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -226,6 +226,10 @@ not really have a decomposition. [(#4407)](https://github.com/PennyLaneAI/pennylane/pull/4407) +* `qml.transforms.split_non_commuting` now correctly works on tapes containing both `expval` + and `var` measurements. + [(#4426)](https://github.com/PennyLaneAI/pennylane/pull/4426) +

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/transforms/split_non_commuting.py b/pennylane/transforms/split_non_commuting.py index 463fd90a073..1e05efdde28 100644 --- a/pennylane/transforms/split_non_commuting.py +++ b/pennylane/transforms/split_non_commuting.py @@ -164,10 +164,10 @@ def circuit0(x): if len(groups) > 1: # make one tape per commuting group tapes = [] - for group in groups: + for group, indices in zip(groups, group_coeffs): new_tape = tape.__class__( tape._ops, - (m.__class__(obs=o) for m, o in zip(tape.measurements, group)), + (tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)), tape._prep, ) diff --git a/tests/transforms/test_split_non_commuting.py b/tests/transforms/test_split_non_commuting.py index fec09baa760..3cbc08fc197 100644 --- a/tests/transforms/test_split_non_commuting.py +++ b/tests/transforms/test_split_non_commuting.py @@ -127,6 +127,36 @@ def test_different_measurement_types(self, meas_type): for meas in new_tape.measurements: assert meas.return_type == the_return_type + def test_mixed_measurement_types(self): + """Test that mixing expval and var works correctly.""" + + with qml.queuing.AnnotatedQueue() as q: + qml.Hadamard(0) + qml.Hadamard(1) + qml.expval(qml.PauliX(0)) + qml.expval(qml.PauliZ(1)) + qml.var(qml.PauliZ(0)) + + tape = qml.tape.QuantumScript.from_queue(q) + split, _ = split_non_commuting(tape) + + assert len(split) == 2 + + with qml.queuing.AnnotatedQueue() as q: + qml.Hadamard(0) + qml.Hadamard(1) + qml.expval(qml.PauliX(0)) + qml.var(qml.PauliZ(0)) + qml.expval(qml.PauliZ(1)) + + tape = qml.tape.QuantumScript.from_queue(q) + split, _ = split_non_commuting(tape) + + assert len(split) == 2 + assert qml.equal(split[0].measurements[0], qml.expval(qml.PauliX(0))) + assert qml.equal(split[0].measurements[1], qml.expval(qml.PauliZ(1))) + assert qml.equal(split[1].measurements[0], qml.var(qml.PauliZ(0))) + def test_raise_not_supported(self): """Test that NotImplementedError is raised when probabilities or samples are called""" with qml.queuing.AnnotatedQueue() as q: