Skip to content

Commit

Permalink
Fix split_non_commuting when tape contains both expval and var
Browse files Browse the repository at this point in the history
…measurements (#4426)

* fix split non commuting

* changelog

* pylint
  • Loading branch information
eddddddy authored Aug 2, 2023
1 parent 1baad70 commit 52ccdba
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@
not really have a decomposition.
[(#4407)](https://github.com/PennyLaneAI/pennylane/pull/4407)

* `qml.transforms.split_non_commuting` now correctly works on tapes containing both `expval`
and `var` measurements.
[(#4426)](https://github.com/PennyLaneAI/pennylane/pull/4426)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def circuit0(x):
if len(groups) > 1:
# make one tape per commuting group
tapes = []
for group in groups:
for group, indices in zip(groups, group_coeffs):
new_tape = tape.__class__(
tape._ops,
(m.__class__(obs=o) for m, o in zip(tape.measurements, group)),
(tape.measurements[i].__class__(obs=o) for o, i in zip(group, indices)),
tape._prep,
)

Expand Down
30 changes: 30 additions & 0 deletions tests/transforms/test_split_non_commuting.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,36 @@ def test_different_measurement_types(self, meas_type):
for meas in new_tape.measurements:
assert meas.return_type == the_return_type

def test_mixed_measurement_types(self):
"""Test that mixing expval and var works correctly."""

with qml.queuing.AnnotatedQueue() as q:
qml.Hadamard(0)
qml.Hadamard(1)
qml.expval(qml.PauliX(0))
qml.expval(qml.PauliZ(1))
qml.var(qml.PauliZ(0))

tape = qml.tape.QuantumScript.from_queue(q)
split, _ = split_non_commuting(tape)

assert len(split) == 2

with qml.queuing.AnnotatedQueue() as q:
qml.Hadamard(0)
qml.Hadamard(1)
qml.expval(qml.PauliX(0))
qml.var(qml.PauliZ(0))
qml.expval(qml.PauliZ(1))

tape = qml.tape.QuantumScript.from_queue(q)
split, _ = split_non_commuting(tape)

assert len(split) == 2
assert qml.equal(split[0].measurements[0], qml.expval(qml.PauliX(0)))
assert qml.equal(split[0].measurements[1], qml.expval(qml.PauliZ(1)))
assert qml.equal(split[1].measurements[0], qml.var(qml.PauliZ(0)))

def test_raise_not_supported(self):
"""Test that NotImplementedError is raised when probabilities or samples are called"""
with qml.queuing.AnnotatedQueue() as q:
Expand Down

0 comments on commit 52ccdba

Please sign in to comment.