diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6b9301ca6b7..bb4a9d4b84a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -87,6 +87,16 @@ * When given a callable, `qml.ctrl` now does its custom pre-processing on all queued operators from the callable. [(#4370)](https://github.com/PennyLaneAI/pennylane/pull/4370) + +* `qml.interfaces.set_shots` accepts `Shots` object as well as `int`'s and tuples of `int`'s. + [(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388) + + +* `pennylane.devices.experimental.Device` now accepts a shots keyword argument and has a `shots` + property. This property is merely used to set defaults for a workflow, and does not directly + influence the number of shots used in executions or derivatives. + [(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388) + * PennyLane no longer directly relies on `Operator.__eq__`. [(#4398)](https://github.com/PennyLaneAI/pennylane/pull/4398) diff --git a/pennylane/devices/experimental/default_qubit_2.py b/pennylane/devices/experimental/default_qubit_2.py index db074d82ce0..e072f64edbb 100644 --- a/pennylane/devices/experimental/default_qubit_2.py +++ b/pennylane/devices/experimental/default_qubit_2.py @@ -44,13 +44,14 @@ class DefaultQubit2(Device): """A PennyLane device written in Python and capable of backpropagation derivatives. - Keyword Args: + Args: + shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots to use in executions involving + this device. seed="global" (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng`` or a request to seed from numpy's global random number generator. The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None`` will pull a seed from the OS entropy. - max_workers=None (int): A ``ProcessPoolExecutor`` executes tapes asynchronously using a pool of at most ``max_workers`` processes. If ``max_workers`` is ``None``, only the current process executes tapes. If you experience any @@ -137,8 +138,8 @@ def name(self): """The name of the device.""" return "default.qubit.2" - def __init__(self, seed="global", max_workers=None) -> None: - super().__init__() + def __init__(self, shots=None, seed="global", max_workers=None) -> None: + super().__init__(shots=shots) self._max_workers = max_workers seed = np.random.randint(0, high=10000000) if seed == "global" else seed self._rng = np.random.default_rng(seed) diff --git a/pennylane/devices/experimental/device_api.py b/pennylane/devices/experimental/device_api.py index bed857ee0e2..cc1b0c8b0de 100644 --- a/pennylane/devices/experimental/device_api.py +++ b/pennylane/devices/experimental/device_api.py @@ -20,6 +20,7 @@ from numbers import Number from typing import Callable, Union, Sequence, Tuple, Optional +from pennylane.measurements import Shots from pennylane.tape import QuantumTape, QuantumScript from pennylane.typing import Result, ResultBatch from pennylane import Tracker @@ -152,9 +153,20 @@ def name(self) -> str: self.tracker.record() """ - def __init__(self) -> None: + def __init__(self, shots=None) -> None: # each instance should have its own Tracker. self.tracker = Tracker() + self._shots = Shots(shots) + + @property + def shots(self) -> Shots: + """Default shots for execution workflows containing this device. + + Note that the device itself should **always** pull shots from the provided :class:`~.QuantumTape` and its + :attr:`~.QuantumTape.shots`, not from this property. This property is used to provide a default at the start of a workflow. + + """ + return self._shots def preprocess( self, diff --git a/pennylane/interfaces/execution.py b/pennylane/interfaces/execution.py index 76922fb1567..25ee2769ebd 100644 --- a/pennylane/interfaces/execution.py +++ b/pennylane/interfaces/execution.py @@ -23,9 +23,7 @@ # pylint: disable=unused-argument,unnecessary-lambda-assignment,inconsistent-return-statements, # pylint: disable=too-many-statements, invalid-unary-operand-type, function-redefined -import inspect import warnings -from contextlib import _GeneratorContextManager from functools import wraps, partial from typing import Callable, Sequence, Optional, Union, Tuple @@ -244,23 +242,7 @@ def wrapper(tapes: Sequence[QuantumTape], **kwargs): # Tape exists within the cache, store the cached result cached_results[i] = cache[hashes[i]] - # Introspect the set_shots decorator of the input function: - # warn the user in case of finite shots with cached results - finite_shots = False - - closure = inspect.getclosurevars(fn).nonlocals - if "original_fn" in closure: # deal with expand_fn wrapper above - closure = inspect.getclosurevars(closure["original_fn"]).nonlocals - - # retrieve the captured context manager instance (for set_shots) - if "self" in closure and isinstance(closure["self"], _GeneratorContextManager): - # retrieve the shots from the arguments or device instance - if closure["self"].func.__name__ == "set_shots": - dev, shots = closure["self"].args - shots = dev.shots if shots is False else shots - finite_shots = isinstance(shots, int) - - if finite_shots and getattr(cache, "_persistent_cache", True): + if tape.shots and getattr(cache, "_persistent_cache", True): warnings.warn( "Cached execution with finite shots detected!\n" "Note that samples as well as all noisy quantities computed via sampling " diff --git a/pennylane/interfaces/set_shots.py b/pennylane/interfaces/set_shots.py index 3d3643dc1af..779cb3246c8 100644 --- a/pennylane/interfaces/set_shots.py +++ b/pennylane/interfaces/set_shots.py @@ -18,6 +18,9 @@ # pylint: disable=protected-access import contextlib +import pennylane as qml +from pennylane.measurements import Shots + @contextlib.contextmanager def set_shots(device, shots): @@ -40,6 +43,13 @@ def set_shots(device, shots): >>> set_shots(dev, shots=100)(lambda: dev.shots)() 100 """ + if isinstance(device, qml.devices.experimental.Device): + raise ValueError( + "The new device interface is not compatible with `set_shots`. " + "Set shots when calling the qnode or put the shots on the QuantumTape." + ) + if isinstance(shots, Shots): + shots = shots.shot_vector if shots.has_partitioned_shots else shots.total_shots if shots == device.shots: yield return diff --git a/pennylane/measurements/shots.py b/pennylane/measurements/shots.py index 88d6f473000..2d4a4bafe26 100644 --- a/pennylane/measurements/shots.py +++ b/pennylane/measurements/shots.py @@ -190,7 +190,11 @@ def __repr__(self): def __eq__(self, other): """Equality between Shot instances.""" - return self.total_shots == other.total_shots and self.shot_vector == other.shot_vector + return ( + isinstance(other, Shots) + and self.total_shots == other.total_shots + and self.shot_vector == other.shot_vector + ) def __hash__(self): """Hash for a given Shot instance.""" diff --git a/pennylane/qnode.py b/pennylane/qnode.py index 4fb65ff3e2f..4a02926e49b 100644 --- a/pennylane/qnode.py +++ b/pennylane/qnode.py @@ -27,7 +27,7 @@ import pennylane as qml from pennylane import Device from pennylane.interfaces import INTERFACE_MAP, SUPPORTED_INTERFACES, set_shots -from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP +from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP, Shots from pennylane.tape import QuantumTape, make_qscript @@ -49,6 +49,15 @@ def _convert_to_interface(res, interface): return qml.math.asarray(res, like=interface if interface != "tf" else "tensorflow") +# pylint: disable=protected-access +def _get_device_shots(device) -> Shots: + if isinstance(device, Device): + if device._shot_vector: + return Shots(device._raw_shot_sequence) + return Shots(device.shots) + return device.shots + + class QNode: """Represents a quantum node in the hybrid computational graph. @@ -693,7 +702,7 @@ def best_method_str(device, interface): @staticmethod def _validate_backprop_method(device, interface, shots=None): - if shots is not None or getattr(device, "shots", None) is not None: + if shots is not None or _get_device_shots(device): raise qml.QuantumFunctionError("Backpropagation is only supported when shots=None.") if isinstance(device, qml.devices.experimental.Device): @@ -830,14 +839,10 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches """Call the quantum function with a tape context, ensuring the operations get queued.""" old_interface = self.interface - if not self._qfunc_uses_shots_arg: - shots = kwargs.pop("shots", None) + if self._qfunc_uses_shots_arg: + shots = _get_device_shots(self._original_device) else: - shots = ( - self._original_device._raw_shot_sequence - if self._original_device._shot_vector - else self._original_device.shots - ) + shots = kwargs.pop("shots", _get_device_shots(self._original_device)) if old_interface == "auto": self.interface = qml.math.get_interface(*args, *list(kwargs.values())) @@ -935,20 +940,14 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: # pylint: disable=not-callable # update the gradient function - if isinstance(self._original_device, Device): - set_shots(self._original_device, override_shots)(self._update_gradient_fn)() + if isinstance(self._original_device, qml.Device): + set_shots(self._original_device, override_shots)(self._update_gradient_fn)( + shots=override_shots + ) else: self._update_gradient_fn(shots=override_shots) - else: - if isinstance(self._original_device, Device): - kwargs["shots"] = ( - self._original_device._raw_shot_sequence - if self._original_device._shot_vector - else self._original_device.shots - ) - else: - kwargs["shots"] = None + kwargs["shots"] = _get_device_shots(self._original_device) # construct the tape self.construct(args, kwargs) diff --git a/tests/devices/experimental/test_default_qubit_2.py b/tests/devices/experimental/test_default_qubit_2.py index f7106062f86..6e2e21d5f0d 100644 --- a/tests/devices/experimental/test_default_qubit_2.py +++ b/tests/devices/experimental/test_default_qubit_2.py @@ -29,6 +29,15 @@ def test_name(): assert DefaultQubit2().name == "default.qubit.2" +def test_shots(): + """Test the shots property of DefaultQubit2.""" + assert DefaultQubit2().shots == qml.measurements.Shots(None) + assert DefaultQubit2(shots=100).shots == qml.measurements.Shots(100) + + with pytest.raises(AttributeError): + DefaultQubit2().shots = 10 + + def test_debugger_attribute(): """Test that DefaultQubit2 has a debugger attribute and that it is `None`""" # pylint: disable=protected-access diff --git a/tests/devices/experimental/test_device_api.py b/tests/devices/experimental/test_device_api.py index fbfb6722d28..05c6eb449c3 100644 --- a/tests/devices/experimental/test_device_api.py +++ b/tests/devices/experimental/test_device_api.py @@ -49,6 +49,17 @@ def test_device_name(self): """Test the default name is the name of the class""" assert self.dev.name == "MinimalDevice" + def test_shots(self): + """Test default behavior for shots.""" + + assert self.dev.shots == qml.measurements.Shots(None) + + shots_dev = self.MinimalDevice(shots=100) + assert shots_dev.shots == qml.measurements.Shots(100) + + with pytest.raises(AttributeError): + self.dev.shots = 100 # pylint: disable=attribute-defined-outside-init + def test_tracker_set_on_initialization(self): """Test that a new tracker instance is initialized with the class.""" assert isinstance(self.dev.tracker, qml.Tracker) diff --git a/tests/interfaces/test_set_shots.py b/tests/interfaces/test_set_shots.py new file mode 100644 index 00000000000..f5233b4f565 --- /dev/null +++ b/tests/interfaces/test_set_shots.py @@ -0,0 +1,58 @@ +# Copyright 2018-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for interfaces.set_shots +""" + +import pytest + +import pennylane as qml +from pennylane.interfaces import set_shots +from pennylane.measurements import Shots + + +def test_shots_new_device_interface(): + """Test that calling set_shots on a device implementing the new interface leaves it + untouched. + """ + dev = qml.devices.experimental.DefaultQubit2() + with pytest.raises(ValueError): + with set_shots(dev, 10): + pass + + +def test_set_with_shots_class(): + """Test that shots can be set on the old device interface with a Shots class.""" + + dev = qml.devices.DefaultQubit(wires=1) + with set_shots(dev, Shots(10)): + assert dev.shots == 10 + + assert dev.shots is None + + shot_tuples = Shots((10, 10)) + with set_shots(dev, shot_tuples): + assert dev.shots == 20 + assert dev.shot_vector == list(shot_tuples.shot_vector) + + assert dev.shots is None + + +def test_shots_not_altered_if_False(): + """Test a value of False can be passed to shots, indicating to not override + shots on the device.""" + + dev = qml.devices.DefaultQubit(wires=1) + with set_shots(dev, False): + assert dev.shots is None diff --git a/tests/test_qnode.py b/tests/test_qnode.py index 2216d548c80..8489b48e29d 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -1426,7 +1426,7 @@ def test_warning_finite_shots_tape(self): qml.RZ(0.3, wires=0) qml.expval(qml.PauliZ(0)) - tape = QuantumScript.from_queue(q) + tape = QuantumScript.from_queue(q, shots=5) # no warning on the first execution cache = {} qml.execute([tape], dev, None, cache=cache) @@ -1696,16 +1696,6 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool: assert not kwargs assert new_dev is dev - def test_shots_not_set_on_device(self): - """Test that shots are not set on the device when override shots are passed on a call.""" - - def f(): - return qml.expval(qml.PauliZ(0)) - - qn = QNode(f, self.dev) - qn(shots=10) - assert getattr(self.dev, "shots", "not here") == "not here" - def test_shots_integration(self): """Test that shots provided at call time are passed through the workflow."""