Skip to content

Commit

Permalink
[BUGFIX] Removing measurement process hash/equality warnings (#4498)
Browse files Browse the repository at this point in the history
* Adding pytest error

* Fixed errors

* Update changelog

* linting

* Fixed utils

* Fixed error

* Update tests/ops/functions/test_equal.py

Co-authored-by: Matthew Silverman <[email protected]>

* Change versions (#4499)

* Trigger CI

* Trigger CI

---------

Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
mudit2812 and timmysilv authored Aug 22, 2023
1 parent 75f906e commit 2366c0a
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-0.32.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ array([False, False])
[(#4144)](https://github.com/PennyLaneAI/pennylane/pull/4144)
[(#4454)](https://github.com/PennyLaneAI/pennylane/pull/4454)
[(#4489)](https://github.com/PennyLaneAI/pennylane/pull/4489)
[(#4498)](https://github.com/PennyLaneAI/pennylane/pull/4498)

* The `sampler_seed` argument of `qml.gradients.spsa_grad` has been deprecated, along with a bug
fix of the seed-setting behaviour.
Expand Down
30 changes: 25 additions & 5 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
import uuid
from typing import Generic, TypeVar, Optional
import warnings

import pennylane as qml
import pennylane.numpy as np
Expand Down Expand Up @@ -255,16 +256,35 @@ def _apply(self, fn):
def _merge(self, other: "MeasurementValue"):
"""Merge two measurement values"""

# create a new merged list with no duplicates and in lexical ordering
merged_measurements = list(set(self.measurements).union(set(other.measurements)))
with warnings.catch_warnings():
# Using a filter because the new behaviour of MP hash will be valid here
warnings.filterwarnings(
"ignore",
message="The behaviour of measurement process hashing",
category=UserWarning,
)
# create a new merged list with no duplicates and in lexical ordering
merged_measurements = list(set(self.measurements).union(set(other.measurements)))

merged_measurements.sort(key=lambda m: m.id)

# create a new function that selects the correct indices for each sub function
def merged_fn(*x):
sub_args_1 = (x[i] for i in [merged_measurements.index(m) for m in self.measurements])
out_1 = self.processing_fn(*sub_args_1)
with warnings.catch_warnings():
# Using a filter because the new behaviour of MP equality will be valid here
warnings.filterwarnings(
"ignore",
message="The behaviour of measurement process equality",
category=UserWarning,
)
sub_args_1 = (
x[i] for i in [merged_measurements.index(m) for m in self.measurements]
)
sub_args_2 = (
x[i] for i in [merged_measurements.index(m) for m in other.measurements]
)

sub_args_2 = (x[i] for i in [merged_measurements.index(m) for m in other.measurements])
out_1 = self.processing_fn(*sub_args_1)
out_2 = other.processing_fn(*sub_args_2)

return out_1, out_2
Expand Down
14 changes: 14 additions & 0 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pennylane as qml
from pennylane.measurements import MeasurementProcess
from pennylane.measurements.classical_shadow import ShadowExpvalMP
from pennylane.measurements.mid_measure import MidMeasureMP
from pennylane.measurements.mutual_info import MutualInfoMP
from pennylane.measurements.vn_entropy import VnEntropyMP
from pennylane.pulse.parametrized_evolution import ParametrizedEvolution
Expand Down Expand Up @@ -355,6 +356,19 @@ def _equal_measurements(
return False


@_equal.register
# pylint: disable=unused-argument
def _equal_mid_measure(
op1: MidMeasureMP,
op2: MidMeasureMP,
check_interface=True,
check_trainability=True,
rtol=1e-5,
atol=1e-9,
):
return op1.wires == op2.wires and op1.id == op2.id and op1.reset == op2.reset


@_equal.register
# pylint: disable=unused-argument
def _(op1: VnEntropyMP, op2: VnEntropyMP, **kwargs):
Expand Down
17 changes: 15 additions & 2 deletions pennylane/transforms/hamiltonian_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,28 @@ def sum_expand(tape: QuantumTape, group=True):
idxs_coeffs = list(idxs_coeffs_dict.values())

# Create the tapes, group observables if group==True
# pylint: disable=too-many-nested-blocks
if group:
m_groups = _group_measurements(measurements)
# Update ``idxs_coeffs`` list such that it tracks the new ``m_groups`` list of lists
tmp_idxs = []
for m_group in m_groups:
if len(m_group) == 1:
tmp_idxs.append(idxs_coeffs[measurements.index(m_group[0])])
# pylint: disable=undefined-loop-variable
for i, m in enumerate(measurements):
if m is m_group[0]:
break
tmp_idxs.append(idxs_coeffs[i])
else:
tmp_idxs.append([idxs_coeffs[measurements.index(m)] for m in m_group])
inds = []
for mp in m_group:
# pylint: disable=undefined-loop-variable
for i, m in enumerate(measurements):
if m is mp:
break
inds.append(idxs_coeffs[i])
tmp_idxs.append(inds)

idxs_coeffs = tmp_idxs
qscripts = [
QuantumScript(ops=tape.operations, measurements=m_group, shots=tape.shots)
Expand Down
20 changes: 20 additions & 0 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,26 @@ def test_observables_equal_but_wire_order_not(self):
o2 = qml.prod(z0, x1)
assert qml.equal(qml.expval(o1), qml.expval(o2))

def test_mid_measure(self):
"""Test that `MidMeasureMP`s are equal only if their wires
an id are equal and their `reset` attribute match."""
mp = qml.measurements.MidMeasureMP(wires=qml.wires.Wires([0, 1]), reset=True, id="test_id")

mp1 = qml.measurements.MidMeasureMP(wires=qml.wires.Wires([1, 0]), reset=True, id="test_id")
mp2 = qml.measurements.MidMeasureMP(
wires=qml.wires.Wires([0, 1]), reset=False, id="test_id"
)
mp3 = qml.measurements.MidMeasureMP(wires=qml.wires.Wires([0, 1]), reset=True, id="foo")

assert not qml.equal(mp, mp1)
assert not qml.equal(mp, mp2)
assert not qml.equal(mp, mp3)

assert qml.equal(
mp,
qml.measurements.MidMeasureMP(wires=qml.wires.Wires([0, 1]), reset=True, id="test_id"),
)


class TestObservablesComparisons:
"""Tests comparisons between Hamiltonians, Tensors and PauliX/Y/Z operators"""
Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ filterwarnings =
ignore:Call to deprecated create function:DeprecationWarning
ignore:the imp module is deprecated:DeprecationWarning
error:The behaviour of operator:UserWarning
error:The behaviour of measurement process:UserWarning

0 comments on commit 2366c0a

Please sign in to comment.