Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds shots to experimental device interface and integrate with QNode #4388

Merged
merged 22 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
01df258
default shots on new device interface
albi3ro Jul 24, 2023
0d07d06
make fewer changes and clean stuff up later
albi3ro Jul 25, 2023
399e492
fix tests, lint, and sphinx
albi3ro Jul 25, 2023
1426fc1
Merge branch 'master' into shots-on-new-device
albi3ro Jul 25, 2023
80fd531
Update doc/releases/changelog-dev.md
albi3ro Jul 31, 2023
6a567d8
Merge branch 'master' into shots-on-new-device
albi3ro Jul 31, 2023
49e856e
Merge branch 'master' into shots-on-new-device
albi3ro Jul 31, 2023
cf02d83
Update tests/interfaces/test_set_shots.py
albi3ro Jul 31, 2023
2e776ac
Update tests/devices/experimental/test_default_qubit_2.py
albi3ro Jul 31, 2023
a4f4c25
make set_shots error with new device, shots type hinting
albi3ro Aug 1, 2023
63a89ff
merge conflicts
albi3ro Aug 1, 2023
18a43cf
Update doc/releases/changelog-dev.md
albi3ro Aug 1, 2023
7fd30ad
Update pennylane/interfaces/set_shots.py
albi3ro Aug 1, 2023
7641dbc
Apply suggestions from code review
albi3ro Aug 1, 2023
b3a4600
Merge branch 'master' into shots-on-new-device
albi3ro Aug 1, 2023
51dbceb
Update tests/devices/experimental/test_default_qubit_2.py
albi3ro Aug 1, 2023
fc61700
merge problem
albi3ro Aug 1, 2023
560ffbc
black
albi3ro Aug 1, 2023
a049522
Merge branch 'shots-on-new-device' of https://github.com/PennyLaneAI/…
albi3ro Aug 1, 2023
6cf23c8
Update doc/releases/changelog-dev.md
albi3ro Aug 1, 2023
0bd6cf7
Update pennylane/qnode.py
albi3ro Aug 2, 2023
8b5f1b2
Merge branch 'master' into shots-on-new-device
timmysilv Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@
* When given a callable, `qml.ctrl` now does its custom pre-processing on all queued operators from the callable.
[(#4370)](https://github.com/PennyLaneAI/pennylane/pull/4370)

* `qml.interfaces.set_shots` can now accepts shots specified as a `qml.measurements.Shots` object.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
If a `pennylane.devices.experimental.Device` is provided, the `set_shots` will no longer raise an error, but will simply return without doing anything.
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
[(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388)

* `pennylane.devices.experimental.Device` now accepts a shots keyword argument and has a `shots`
property. This property is merely used to set defaults for a workflow, and does not directly influence the number of shots used in executions or derivatives.
[(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388)

* PennyLane no longer directly relies on `Operator.__eq__`.
[(#4398)](https://github.com/PennyLaneAI/pennylane/pull/4398)

Expand Down
7 changes: 4 additions & 3 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ class DefaultQubit2(Device):
"""A PennyLane device written in Python and capable of backpropagation derivatives.

Keyword Args:
shots (int, Tuple[int], Tuple[Tuple[int]]): The default number of shots to use in executions involving
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
this device.
seed="global" (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng`` or
a request to seed from numpy's global random number generator.
The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None``
will pull a seed from the OS entropy.

max_workers=None (int): A ``ProcessPoolExecutor`` executes tapes asynchronously
using a pool of at most ``max_workers`` processes. If ``max_workers`` is ``None``,
only the current process executes tapes. If you experience any
Expand Down Expand Up @@ -137,8 +138,8 @@ def name(self):
"""The name of the device."""
return "default.qubit.2"

def __init__(self, seed="global", max_workers=None) -> None:
super().__init__()
def __init__(self, shots=None, seed="global", max_workers=None) -> None:
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(shots=shots)
self._max_workers = max_workers
seed = np.random.randint(0, high=10000000) if seed == "global" else seed
self._rng = np.random.default_rng(seed)
Expand Down
14 changes: 13 additions & 1 deletion pennylane/devices/experimental/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from numbers import Number
from typing import Callable, Union, Sequence, Tuple, Optional

from pennylane.measurements import Shots
from pennylane.tape import QuantumTape, QuantumScript
from pennylane.typing import Result, ResultBatch
from pennylane import Tracker
Expand Down Expand Up @@ -152,9 +153,20 @@ def name(self) -> str:
self.tracker.record()
"""

def __init__(self) -> None:
def __init__(self, shots=None) -> None:
# each instance should have its own Tracker.
self.tracker = Tracker()
self._shots = Shots(shots)

@property
def shots(self) -> Shots:
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Default shots for execution workflows containing this device.

Note that the device itself should **always** pull shots from the provided :class:`~.QuantumTape` and its
:attr:`~.QuantumTape.shots`, not from this property. This property is used to provide a default at the start of a workflow.

"""
return self._shots

def preprocess(
self,
Expand Down
20 changes: 1 addition & 19 deletions pennylane/interfaces/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
# pylint: disable=unused-argument,unnecessary-lambda-assignment,inconsistent-return-statements,
# pylint: disable=too-many-statements, invalid-unary-operand-type, function-redefined

import inspect
import warnings
from contextlib import _GeneratorContextManager
from functools import wraps, partial
from typing import Callable, Sequence, Optional, Union, Tuple

Expand Down Expand Up @@ -244,23 +242,7 @@ def wrapper(tapes: Sequence[QuantumTape], **kwargs):
# Tape exists within the cache, store the cached result
cached_results[i] = cache[hashes[i]]

# Introspect the set_shots decorator of the input function:
# warn the user in case of finite shots with cached results
finite_shots = False

closure = inspect.getclosurevars(fn).nonlocals
if "original_fn" in closure: # deal with expand_fn wrapper above
closure = inspect.getclosurevars(closure["original_fn"]).nonlocals

# retrieve the captured context manager instance (for set_shots)
if "self" in closure and isinstance(closure["self"], _GeneratorContextManager):
# retrieve the shots from the arguments or device instance
if closure["self"].func.__name__ == "set_shots":
dev, shots = closure["self"].args
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
shots = dev.shots if shots is False else shots
finite_shots = isinstance(shots, int)

if finite_shots and getattr(cache, "_persistent_cache", True):
if tape.shots and getattr(cache, "_persistent_cache", True):
warnings.warn(
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
Expand Down
8 changes: 8 additions & 0 deletions pennylane/interfaces/set_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
# pylint: disable=protected-access
import contextlib

import pennylane as qml
from pennylane.measurements import Shots


@contextlib.contextmanager
def set_shots(device, shots):
Expand All @@ -40,6 +43,11 @@ def set_shots(device, shots):
>>> set_shots(dev, shots=100)(lambda: dev.shots)()
100
"""
if isinstance(device, qml.devices.experimental.Device):
timmysilv marked this conversation as resolved.
Show resolved Hide resolved
yield
return
if isinstance(shots, Shots):
shots = shots.shot_vector if shots.has_partitioned_shots else shots.total_shots
if shots == device.shots:
yield
return
Expand Down
6 changes: 5 additions & 1 deletion pennylane/measurements/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,11 @@ def __repr__(self):

def __eq__(self, other):
"""Equality between Shot instances."""
return self.total_shots == other.total_shots and self.shot_vector == other.shot_vector
return (
isinstance(other, Shots)
and self.total_shots == other.total_shots
and self.shot_vector == other.shot_vector
)

def __hash__(self):
"""Hash for a given Shot instance."""
Expand Down
40 changes: 19 additions & 21 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pennylane as qml
from pennylane import Device
from pennylane.interfaces import INTERFACE_MAP, SUPPORTED_INTERFACES, set_shots
from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP
from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP, Shots
from pennylane.tape import QuantumTape, make_qscript


Expand All @@ -49,6 +49,15 @@ def _convert_to_interface(res, interface):
return qml.math.asarray(res, like=interface if interface != "tf" else "tensorflow")


# pylint: disable=protected-access
def _get_device_shots(device) -> Shots:
if isinstance(device, Device):
if device._shot_vector:
return Shots(device._raw_shot_sequence)
return Shots(device.shots)
return device.shots


class QNode:
"""Represents a quantum node in the hybrid computational graph.

Expand Down Expand Up @@ -693,7 +702,8 @@ def best_method_str(device, interface):

@staticmethod
def _validate_backprop_method(device, interface, shots=None):
if shots is not None or getattr(device, "shots", None) is not None:
if shots is not None or _get_device_shots(device):
print(shots, _get_device_shots(device))
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
raise qml.QuantumFunctionError("Backpropagation is only supported when shots=None.")

if isinstance(device, qml.devices.experimental.Device):
Expand Down Expand Up @@ -830,14 +840,10 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
"""Call the quantum function with a tape context, ensuring the operations get queued."""
old_interface = self.interface

if not self._qfunc_uses_shots_arg:
shots = kwargs.pop("shots", None)
if self._qfunc_uses_shots_arg:
shots = _get_device_shots(self._original_device)
else:
shots = (
self._original_device._raw_shot_sequence
if self._original_device._shot_vector
else self._original_device.shots
)
shots = kwargs.pop("shots", _get_device_shots(self._original_device))
eddddddy marked this conversation as resolved.
Show resolved Hide resolved

if old_interface == "auto":
self.interface = qml.math.get_interface(*args, *list(kwargs.values()))
Expand Down Expand Up @@ -935,20 +941,12 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:

# pylint: disable=not-callable
# update the gradient function
if isinstance(self._original_device, Device):
set_shots(self._original_device, override_shots)(self._update_gradient_fn)()
else:
self._update_gradient_fn(shots=override_shots)
set_shots(self._original_device, override_shots)(self._update_gradient_fn)(
shots=override_shots
)

else:
if isinstance(self._original_device, Device):
kwargs["shots"] = (
self._original_device._raw_shot_sequence
if self._original_device._shot_vector
else self._original_device.shots
)
else:
kwargs["shots"] = None
kwargs["shots"] = _get_device_shots(self._original_device)

# construct the tape
self.construct(args, kwargs)
Expand Down
46 changes: 46 additions & 0 deletions tests/devices/experimental/test_default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,52 @@ def test_name():
assert DefaultQubit2().name == "default.qubit.2"


def test_shots():
"""Test the shots property of DefaultQubit2."""
assert DefaultQubit2().shots == qml.measurements.Shots(None)
assert DefaultQubit2(shots=100).shots == qml.measurements.Shots(100)

with pytest.raises(AttributeError):
DefaultQubit2().shots = 10


def test_no_jvp_functionality():
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
"""Test that jvp is not supported on DefaultQubit2."""
dev = DefaultQubit2()

assert not dev.supports_jvp(ExecutionConfig())

with pytest.raises(NotImplementedError):
dev.compute_jvp(qml.tape.QuantumScript(), (10, 10))

with pytest.raises(NotImplementedError):
dev.execute_and_compute_jvp(qml.tape.QuantumScript(), (10, 10))


def test_no_vjp_functionality():
"""Test that vjp is not supported on DefaultQubit2."""
dev = DefaultQubit2()

assert not dev.supports_vjp(ExecutionConfig())

with pytest.raises(NotImplementedError):
dev.compute_vjp(qml.tape.QuantumScript(), (10.0, 10.0))

with pytest.raises(NotImplementedError):
dev.execute_and_compute_vjp(qml.tape.QuantumScript(), (10.0, 10.0))
albi3ro marked this conversation as resolved.
Show resolved Hide resolved


def test_no_device_derivatives():
"""Test that DefaultQubit2 currently doesn't support device derivatives."""
dev = DefaultQubit2()

with pytest.raises(NotImplementedError):
dev.compute_derivatives(qml.tape.QuantumScript())

with pytest.raises(NotImplementedError):
dev.execute_and_compute_derivatives(qml.tape.QuantumScript())


albi3ro marked this conversation as resolved.
Show resolved Hide resolved
def test_debugger_attribute():
"""Test that DefaultQubit2 has a debugger attribute and that it is `None`"""
# pylint: disable=protected-access
Expand Down
11 changes: 11 additions & 0 deletions tests/devices/experimental/test_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_device_name(self):
"""Test the default name is the name of the class"""
assert self.dev.name == "MinimalDevice"

def test_shots(self):
"""Test default behavior for shots."""

assert self.dev.shots == qml.measurements.Shots(None)

shots_dev = self.MinimalDevice(shots=100)
assert shots_dev.shots == qml.measurements.Shots(100)

with pytest.raises(AttributeError):
self.dev.shots = 100 # pylint: disable=attribute-defined-outside-init

def test_tracker_set_on_initialization(self):
"""Test that a new tracker instance is initialized with the class."""
assert isinstance(self.dev.tracker, qml.Tracker)
Expand Down
55 changes: 55 additions & 0 deletions tests/interfaces/test_set_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for interfaces.set_shots
"""

import pennylane as qml
from pennylane.interfaces import set_shots
from pennylane.measurements import Shots


def test_shots_new_device_interface():
"""Test that calling set shots on something in new new device interface leaves it
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
untouched.
"""
dev = qml.devices.experimental.DefaultQubit2()
with set_shots(dev, 10):
assert dev.shots == Shots(None)


def test_set_with_shots_class():
"""Test that shots can be set on the old device interface with a Shots class."""

dev = qml.devices.DefaultQubit(wires=1)
with set_shots(dev, Shots(10)):
assert dev.shots == 10

assert dev.shots is None

shot_tuples = Shots((10, 10))
with set_shots(dev, shot_tuples):
assert dev.shots == 20
assert dev.shot_vector == list(shot_tuples.shot_vector)

assert dev.shots is None


def test_shots_not_altered_if_False():
"""Test a value of False can be passed to shots, indicating to not override
shots on the device."""

dev = qml.devices.DefaultQubit(wires=1)
with set_shots(dev, False):
assert dev.shots is None
12 changes: 1 addition & 11 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ def test_warning_finite_shots_tape(self):
qml.RZ(0.3, wires=0)
qml.expval(qml.PauliZ(0))

tape = QuantumScript.from_queue(q)
tape = QuantumScript.from_queue(q, shots=5)
# no warning on the first execution
cache = {}
qml.execute([tape], dev, None, cache=cache)
Expand Down Expand Up @@ -1696,16 +1696,6 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
assert not kwargs
assert new_dev is dev

def test_shots_not_set_on_device(self):
"""Test that shots are not set on the device when override shots are passed on a call."""

def f():
return qml.expval(qml.PauliZ(0))

qn = QNode(f, self.dev)
qn(shots=10)
assert getattr(self.dev, "shots", "not here") == "not here"

def test_shots_integration(self):
"""Test that shots provided at call time are passed through the workflow."""

Expand Down
Loading