Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legacy device API handles Prod observables #5475

Merged
merged 20 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)

from pennylane.operation import Observable, Operation, Tensor, Operator, StatePrepBase
from pennylane.ops import Hamiltonian, Sum, LinearCombination
from pennylane.ops import Hamiltonian, Sum, LinearCombination, Prod
from pennylane.tape import QuantumScript, QuantumTape, expand_tape_state_prep
from pennylane.wires import WireError, Wires
from pennylane.queuing import QueuingManager
Expand Down Expand Up @@ -744,7 +744,6 @@ def batch_transform(self, circuit: QuantumTape):
to be applied to the list of evaluated circuit results.
"""
supports_hamiltonian = self.supports_observable("Hamiltonian")

supports_sum = self.supports_observable("Sum")
finite_shots = self.shots is not None
grouping_known = all(
Expand All @@ -759,7 +758,12 @@ def batch_transform(self, circuit: QuantumTape):
isinstance(obs, (Hamiltonian, LinearCombination)) for obs in circuit.observables
)
expval_sum_in_obs = any(
isinstance(m.obs, Sum) and isinstance(m, ExpectationMP) for m in circuit.measurements
(
isinstance(m.obs, Sum)
or (isinstance(m.obs, Prod) and isinstance(m.obs.simplify(), Sum))
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
)
and isinstance(m, ExpectationMP)
for m in circuit.measurements
astralcai marked this conversation as resolved.
Show resolved Hide resolved
)

is_shadow = any(isinstance(m, ShadowExpvalMP) for m in circuit.measurements)
Expand Down Expand Up @@ -1007,6 +1011,21 @@ def check_validity(self, queue, observables):
raise DeviceError(
f"Observable {i.name} not supported on device {self.short_name}"
)

elif isinstance(o, qml.ops.Prod):

supports_prod = self.supports_observable(o.name)
if not supports_prod:
raise DeviceError(f"Observable Prod not supported on device {self.short_name}")
astralcai marked this conversation as resolved.
Show resolved Hide resolved

simplified_op = o.simplify()
if isinstance(simplified_op, qml.ops.Prod):
for i in o.simplify().operands:
if not self.supports_observable(i.name):
raise DeviceError(
f"Observable {i.name} not supported on device {self.short_name}"
astralcai marked this conversation as resolved.
Show resolved Hide resolved
)

else:
observable_name = o.name

Expand Down
2 changes: 2 additions & 0 deletions pennylane/devices/tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ class TestHamiltonianSupport:
"""Separate test to ensure that the device can differentiate Hamiltonian observables."""

@pytest.mark.parametrize("ham_constructor", [qml.ops.Hamiltonian, qml.ops.LinearCombination])
@pytest.mark.filterwarnings("ignore::pennylane.PennyLaneDeprecationWarning")
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
def test_hamiltonian_diff(self, ham_constructor, device_kwargs, tol):
"""Tests a simple VQE gradient using parameter-shift rules."""

device_kwargs["wires"] = 1
dev = qml.device(**device_kwargs)
coeffs = np.array([-0.05, 0.17])
Expand Down
4 changes: 3 additions & 1 deletion pennylane/transforms/hamiltonian_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pennylane as qml
from pennylane.measurements import ExpectationMP, MeasurementProcess
from pennylane.ops import SProd, Sum
from pennylane.ops import SProd, Sum, Prod
from pennylane.tape import QuantumScript, QuantumTape
from pennylane.transforms import transform

Expand Down Expand Up @@ -341,6 +341,8 @@ def sum_expand(tape: QuantumTape, group: bool = True) -> (Sequence[QuantumTape],
idxs_coeffs_dict = {} # {m_hash: [(location_idx, coeff)]}
for idx, m in enumerate(tape.measurements):
obs = m.obs
if isinstance(obs, Prod) and isinstance(m, ExpectationMP):
obs = obs.simplify()
astralcai marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(obs, Sum) and isinstance(m, ExpectationMP):
for summand in obs.operands:
coeff = 1
Expand Down
117 changes: 55 additions & 62 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,6 @@
# pylint: disable=abstract-class-instantiated, no-self-use, redefined-outer-name, invalid-name, missing-function-docstring


@pytest.fixture(scope="function")
def mock_device_with_operations(monkeypatch):
"""A function to create a mock device with non-empty operations"""
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "operations", mock_device_paulis)
m.setattr(Device, "observables", mock_device_paulis)
m.setattr(Device, "short_name", "MockDevice")

def get_device(wires=1):
return Device(wires=wires)

yield get_device


@pytest.fixture(scope="function")
def mock_device_with_observables(monkeypatch):
"""A function to create a mock device with non-empty observables"""
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "operations", mock_device_paulis)
m.setattr(Device, "observables", mock_device_paulis)
m.setattr(Device, "short_name", "MockDevice")

def get_device(wires=1):
return Device(wires=wires)

yield get_device


astralcai marked this conversation as resolved.
Show resolved Hide resolved
@pytest.fixture(scope="function")
def mock_device_with_identity(monkeypatch):
"""A function to create a mock device with non-empty observables"""
Expand Down Expand Up @@ -203,19 +173,15 @@ def get_device(wires=1):


@pytest.fixture(scope="function")
def mock_device_arbitrary_wires(monkeypatch):
def mock_device_supporting_prod(monkeypatch):
with monkeypatch.context() as m:
m.setattr(Device, "__abstractmethods__", frozenset())
m.setattr(Device, "_capabilities", mock_device_capabilities)
m.setattr(Device, "operations", ["PauliY", "RX", "Rot"])
m.setattr(Device, "observables", ["PauliZ"])
m.setattr(Device, "operations", ["PauliX", "PauliZ"])
m.setattr(Device, "observables", ["PauliX", "PauliZ", "Prod"])
m.setattr(Device, "short_name", "MockDevice")
m.setattr(Device, "expval", lambda self, x, y, z: 0)
m.setattr(Device, "var", lambda self, x, y, z: 0)
m.setattr(Device, "sample", lambda self, x, y, z: 0)
m.setattr(Device, "apply", lambda self, x, y, z: None)

def get_device(wires):
def get_device(wires=1):
return Device(wires=wires)

yield get_device
Expand Down Expand Up @@ -245,22 +211,22 @@ class TestDeviceSupportedLogic:

# pylint: disable=no-self-use, redefined-outer-name

def test_supports_operation_argument_types(self, mock_device_with_operations):
def test_supports_operation_argument_types(self, mock_device_supporting_paulis):
"""Checks that device.supports_operations returns the correct result
when passed both string and Operation class arguments"""

dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()

assert dev.supports_operation("PauliX")
assert dev.supports_operation(qml.PauliX)

assert not dev.supports_operation("S")
assert not dev.supports_operation(qml.CNOT)

def test_supports_observable_argument_types(self, mock_device_with_observables):
def test_supports_observable_argument_types(self, mock_device_supporting_paulis):
"""Checks that device.supports_observable returns the correct result
when passed both string and Operation class arguments"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()

assert dev.supports_observable("PauliX")
assert dev.supports_observable(qml.PauliX)
Expand Down Expand Up @@ -309,14 +275,14 @@ class TestInternalFunctions: # pylint:disable=too-many-public-methods
"""Test the internal functions of the abstract Device class"""

# pylint: disable=unnecessary-dunder-call
def test_repr(self, mock_device_with_operations):
def test_repr(self, mock_device_supporting_paulis):
"""Tests the __repr__ function"""
dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()
assert "<Device device (wires=1, shots=1000) at " in dev.__repr__()

def test_str(self, mock_device_with_operations):
def test_str(self, mock_device_supporting_paulis):
"""Tests the __str__ function"""
dev = mock_device_with_operations()
dev = mock_device_supporting_paulis()
string = str(dev)
assert "Short name: MockDevice" in string
assert "Package: pennylane" in string
Expand All @@ -340,6 +306,33 @@ def test_check_validity_on_valid_queue(self, mock_device_supporting_paulis):
# Raises an error if queue or observables are invalid
dev.check_validity(queue, observables)

def test_check_validity_containing_prod(self, mock_device_supporting_prod):
astralcai marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that the function Device.check_validity works with Prod"""

dev = mock_device_supporting_prod()

queue = [
qml.PauliX(wires=0),
qml.PauliZ(wires=1),
]

observables = [
qml.expval(qml.PauliX(0) @ qml.PauliZ(1)),
qml.expval(qml.PauliZ(0) @ (qml.PauliX(1) @ qml.PauliZ(2))),
]

dev.check_validity(queue, observables)

unsupported_nested_observables = [
qml.expval(qml.PauliZ(0) @ (qml.PauliX(1) @ qml.PauliY(2)))
Qottmann marked this conversation as resolved.
Show resolved Hide resolved
]

with pytest.raises(
DeviceError,
match="Observable PauliY not supported",
):
dev.check_validity(queue, unsupported_nested_observables)

@pytest.mark.usefixtures("use_legacy_opmath")
def test_check_validity_on_tensor_support_legacy_opmath(self, mock_device_supporting_paulis):
"""Tests the function Device.check_validity with tensor support capability"""
Expand Down Expand Up @@ -429,9 +422,9 @@ def test_check_validity_on_invalid_observable(self, mock_device_supporting_pauli
with pytest.raises(DeviceError, match="Observable Hadamard not supported on device"):
dev.check_validity(queue, observables)

def test_check_validity_on_projector_as_operation(self, mock_device_with_operations):
def test_check_validity_on_projector_as_operation(self, mock_device_supporting_paulis):
"""Test that an error is raised if the operation queue contains qml.Projector"""
dev = mock_device_with_operations(wires=1)
dev = mock_device_supporting_paulis(wires=1)

queue = [qml.PauliX(0), qml.Projector([0], wires=0), qml.PauliZ(0)]
observables = []
Expand Down Expand Up @@ -592,8 +585,8 @@ def test_conditional_ops_unsupported_error(self, mock_device_with_paulis_and_met
(Wires([0]), Wires([0]), Wires([0])),
],
)
def test_order_wires(self, wires, subset, expected_subset, mock_device_arbitrary_wires):
dev = mock_device_arbitrary_wires(wires=wires)
def test_order_wires(self, wires, subset, expected_subset, mock_device):
dev = mock_device(wires=wires)
ordered_subset = dev.order_wires(subset_wires=subset)
assert ordered_subset == expected_subset

Expand All @@ -606,8 +599,8 @@ def test_order_wires(self, wires, subset, expected_subset, mock_device_arbitrary
(Wires([0]), Wires([2])),
],
)
def test_order_wires_raises_value_error(self, wires, subset, mock_device_arbitrary_wires):
dev = mock_device_arbitrary_wires(wires=wires)
def test_order_wires_raises_value_error(self, wires, subset, mock_device):
dev = mock_device(wires=wires)
with pytest.raises(ValueError, match="Could not find some or all subset wires"):
_ = dev.order_wires(subset_wires=subset)

Expand Down Expand Up @@ -658,11 +651,11 @@ def test_default_expand_with_initial_state(self, op, decomp):
assert new_tape.batch_size == tape.batch_size
assert new_tape.output_dim == tape.output_dim

def test_default_expand_fn_with_invalid_op(self, mock_device_with_operations, recwarn):
def test_default_expand_fn_with_invalid_op(self, mock_device_supporting_paulis, recwarn):
"""Test that default_expand_fn works with an invalid op and some measurement."""
invalid_tape = qml.tape.QuantumScript([qml.S(0)], [qml.expval(qml.PauliZ(0))])
expected_tape = qml.tape.QuantumScript([qml.RZ(np.pi / 2, 0)], [qml.expval(qml.PauliZ(0))])
dev = mock_device_with_operations(wires=1)
dev = mock_device_supporting_paulis(wires=1)
expanded_tape = dev.expand_fn(invalid_tape, max_expansion=3)
assert qml.equal(expanded_tape, expected_tape)
assert len(recwarn) == 0
Expand Down Expand Up @@ -799,33 +792,33 @@ def test_unsupported_operations_raise_error(self, mock_device_with_paulis_and_me
with pytest.raises(DeviceError, match="Gate Hadamard not supported on device"):
dev.execute(queue, observables)

def test_execute_obs_probs(self, mock_device_with_observables):
def test_execute_obs_probs(self, mock_device_supporting_paulis):
"""Tests that the execute function raises an error if probabilities are
not supported by the device"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
obs = qml.probs(op=qml.PauliZ(0))
with pytest.raises(NotImplementedError):
dev.execute([], [obs])

def test_var(self, mock_device_with_observables):
def test_var(self, mock_device_supporting_paulis):
"""Tests that the variance method are not implemented by the device by
default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.var(qml.PauliZ, 0, [])

def test_sample(self, mock_device_with_observables):
def test_sample(self, mock_device_supporting_paulis):
"""Tests that the sample method are not implemented by the device by
default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.sample(qml.PauliZ, 0, [])

@pytest.mark.parametrize("wires", [None, []])
def test_probability(self, mock_device_with_observables, wires):
def test_probability(self, mock_device_supporting_paulis, wires):
"""Tests that the probability method are not implemented by the device
by default"""
dev = mock_device_with_observables()
dev = mock_device_supporting_paulis()
with pytest.raises(NotImplementedError):
dev.probability(wires=wires)

Expand Down
Loading