Skip to content

Commit

Permalink
Adds shots to experimental device interface and integrate with QNode (#…
Browse files Browse the repository at this point in the history
…4388)

* default shots on new device interface

* make fewer changes and clean stuff up later

* fix tests, lint, and sphinx

* Update doc/releases/changelog-dev.md

* Update tests/interfaces/test_set_shots.py

Co-authored-by: Matthew Silverman <[email protected]>

* Update tests/devices/experimental/test_default_qubit_2.py

* make set_shots error with new device, shots type hinting

* Update doc/releases/changelog-dev.md

Co-authored-by: Tom Bromley <[email protected]>

* Update pennylane/interfaces/set_shots.py

Co-authored-by: Matthew Silverman <[email protected]>

* Apply suggestions from code review

* Update tests/devices/experimental/test_default_qubit_2.py

* merge problem

* black

* Update doc/releases/changelog-dev.md

* Update pennylane/qnode.py

Co-authored-by: Edward Jiang <[email protected]>

---------

Co-authored-by: Matthew Silverman <[email protected]>
Co-authored-by: Tom Bromley <[email protected]>
Co-authored-by: Edward Jiang <[email protected]>
  • Loading branch information
4 people authored Aug 2, 2023
1 parent 52ccdba commit e9fb43b
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 56 deletions.
10 changes: 10 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@
* When given a callable, `qml.ctrl` now does its custom pre-processing on all queued operators from the callable.
[(#4370)](https://github.com/PennyLaneAI/pennylane/pull/4370)


* `qml.interfaces.set_shots` accepts `Shots` object as well as `int`'s and tuples of `int`'s.
[(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388)


* `pennylane.devices.experimental.Device` now accepts a shots keyword argument and has a `shots`
property. This property is merely used to set defaults for a workflow, and does not directly
influence the number of shots used in executions or derivatives.
[(#4388)](https://github.com/PennyLaneAI/pennylane/pull/4388)

* PennyLane no longer directly relies on `Operator.__eq__`.
[(#4398)](https://github.com/PennyLaneAI/pennylane/pull/4398)

Expand Down
9 changes: 5 additions & 4 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@
class DefaultQubit2(Device):
"""A PennyLane device written in Python and capable of backpropagation derivatives.
Keyword Args:
Args:
shots (int, Sequence[int], Sequence[Union[int, Sequence[int]]]): The default number of shots to use in executions involving
this device.
seed="global" (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng`` or
a request to seed from numpy's global random number generator.
The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None``
will pull a seed from the OS entropy.
max_workers=None (int): A ``ProcessPoolExecutor`` executes tapes asynchronously
using a pool of at most ``max_workers`` processes. If ``max_workers`` is ``None``,
only the current process executes tapes. If you experience any
Expand Down Expand Up @@ -137,8 +138,8 @@ def name(self):
"""The name of the device."""
return "default.qubit.2"

def __init__(self, seed="global", max_workers=None) -> None:
super().__init__()
def __init__(self, shots=None, seed="global", max_workers=None) -> None:
super().__init__(shots=shots)
self._max_workers = max_workers
seed = np.random.randint(0, high=10000000) if seed == "global" else seed
self._rng = np.random.default_rng(seed)
Expand Down
14 changes: 13 additions & 1 deletion pennylane/devices/experimental/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from numbers import Number
from typing import Callable, Union, Sequence, Tuple, Optional

from pennylane.measurements import Shots
from pennylane.tape import QuantumTape, QuantumScript
from pennylane.typing import Result, ResultBatch
from pennylane import Tracker
Expand Down Expand Up @@ -152,9 +153,20 @@ def name(self) -> str:
self.tracker.record()
"""

def __init__(self) -> None:
def __init__(self, shots=None) -> None:
# each instance should have its own Tracker.
self.tracker = Tracker()
self._shots = Shots(shots)

@property
def shots(self) -> Shots:
"""Default shots for execution workflows containing this device.
Note that the device itself should **always** pull shots from the provided :class:`~.QuantumTape` and its
:attr:`~.QuantumTape.shots`, not from this property. This property is used to provide a default at the start of a workflow.
"""
return self._shots

def preprocess(
self,
Expand Down
20 changes: 1 addition & 19 deletions pennylane/interfaces/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
# pylint: disable=unused-argument,unnecessary-lambda-assignment,inconsistent-return-statements,
# pylint: disable=too-many-statements, invalid-unary-operand-type, function-redefined

import inspect
import warnings
from contextlib import _GeneratorContextManager
from functools import wraps, partial
from typing import Callable, Sequence, Optional, Union, Tuple

Expand Down Expand Up @@ -244,23 +242,7 @@ def wrapper(tapes: Sequence[QuantumTape], **kwargs):
# Tape exists within the cache, store the cached result
cached_results[i] = cache[hashes[i]]

# Introspect the set_shots decorator of the input function:
# warn the user in case of finite shots with cached results
finite_shots = False

closure = inspect.getclosurevars(fn).nonlocals
if "original_fn" in closure: # deal with expand_fn wrapper above
closure = inspect.getclosurevars(closure["original_fn"]).nonlocals

# retrieve the captured context manager instance (for set_shots)
if "self" in closure and isinstance(closure["self"], _GeneratorContextManager):
# retrieve the shots from the arguments or device instance
if closure["self"].func.__name__ == "set_shots":
dev, shots = closure["self"].args
shots = dev.shots if shots is False else shots
finite_shots = isinstance(shots, int)

if finite_shots and getattr(cache, "_persistent_cache", True):
if tape.shots and getattr(cache, "_persistent_cache", True):
warnings.warn(
"Cached execution with finite shots detected!\n"
"Note that samples as well as all noisy quantities computed via sampling "
Expand Down
10 changes: 10 additions & 0 deletions pennylane/interfaces/set_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
# pylint: disable=protected-access
import contextlib

import pennylane as qml
from pennylane.measurements import Shots


@contextlib.contextmanager
def set_shots(device, shots):
Expand All @@ -40,6 +43,13 @@ def set_shots(device, shots):
>>> set_shots(dev, shots=100)(lambda: dev.shots)()
100
"""
if isinstance(device, qml.devices.experimental.Device):
raise ValueError(
"The new device interface is not compatible with `set_shots`. "
"Set shots when calling the qnode or put the shots on the QuantumTape."
)
if isinstance(shots, Shots):
shots = shots.shot_vector if shots.has_partitioned_shots else shots.total_shots
if shots == device.shots:
yield
return
Expand Down
6 changes: 5 additions & 1 deletion pennylane/measurements/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,11 @@ def __repr__(self):

def __eq__(self, other):
"""Equality between Shot instances."""
return self.total_shots == other.total_shots and self.shot_vector == other.shot_vector
return (
isinstance(other, Shots)
and self.total_shots == other.total_shots
and self.shot_vector == other.shot_vector
)

def __hash__(self):
"""Hash for a given Shot instance."""
Expand Down
39 changes: 19 additions & 20 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pennylane as qml
from pennylane import Device
from pennylane.interfaces import INTERFACE_MAP, SUPPORTED_INTERFACES, set_shots
from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP
from pennylane.measurements import ClassicalShadowMP, CountsMP, MidMeasureMP, Shots
from pennylane.tape import QuantumTape, make_qscript


Expand All @@ -49,6 +49,15 @@ def _convert_to_interface(res, interface):
return qml.math.asarray(res, like=interface if interface != "tf" else "tensorflow")


# pylint: disable=protected-access
def _get_device_shots(device) -> Shots:
if isinstance(device, Device):
if device._shot_vector:
return Shots(device._raw_shot_sequence)
return Shots(device.shots)
return device.shots


class QNode:
"""Represents a quantum node in the hybrid computational graph.
Expand Down Expand Up @@ -693,7 +702,7 @@ def best_method_str(device, interface):

@staticmethod
def _validate_backprop_method(device, interface, shots=None):
if shots is not None or getattr(device, "shots", None) is not None:
if shots is not None or _get_device_shots(device):
raise qml.QuantumFunctionError("Backpropagation is only supported when shots=None.")

if isinstance(device, qml.devices.experimental.Device):
Expand Down Expand Up @@ -830,14 +839,10 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
"""Call the quantum function with a tape context, ensuring the operations get queued."""
old_interface = self.interface

if not self._qfunc_uses_shots_arg:
shots = kwargs.pop("shots", None)
if self._qfunc_uses_shots_arg:
shots = _get_device_shots(self._original_device)
else:
shots = (
self._original_device._raw_shot_sequence
if self._original_device._shot_vector
else self._original_device.shots
)
shots = kwargs.pop("shots", _get_device_shots(self._original_device))

if old_interface == "auto":
self.interface = qml.math.get_interface(*args, *list(kwargs.values()))
Expand Down Expand Up @@ -935,20 +940,14 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:

# pylint: disable=not-callable
# update the gradient function
if isinstance(self._original_device, Device):
set_shots(self._original_device, override_shots)(self._update_gradient_fn)()
if isinstance(self._original_device, qml.Device):
set_shots(self._original_device, override_shots)(self._update_gradient_fn)(
shots=override_shots
)
else:
self._update_gradient_fn(shots=override_shots)

else:
if isinstance(self._original_device, Device):
kwargs["shots"] = (
self._original_device._raw_shot_sequence
if self._original_device._shot_vector
else self._original_device.shots
)
else:
kwargs["shots"] = None
kwargs["shots"] = _get_device_shots(self._original_device)

# construct the tape
self.construct(args, kwargs)
Expand Down
9 changes: 9 additions & 0 deletions tests/devices/experimental/test_default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def test_name():
assert DefaultQubit2().name == "default.qubit.2"


def test_shots():
"""Test the shots property of DefaultQubit2."""
assert DefaultQubit2().shots == qml.measurements.Shots(None)
assert DefaultQubit2(shots=100).shots == qml.measurements.Shots(100)

with pytest.raises(AttributeError):
DefaultQubit2().shots = 10


def test_debugger_attribute():
"""Test that DefaultQubit2 has a debugger attribute and that it is `None`"""
# pylint: disable=protected-access
Expand Down
11 changes: 11 additions & 0 deletions tests/devices/experimental/test_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_device_name(self):
"""Test the default name is the name of the class"""
assert self.dev.name == "MinimalDevice"

def test_shots(self):
"""Test default behavior for shots."""

assert self.dev.shots == qml.measurements.Shots(None)

shots_dev = self.MinimalDevice(shots=100)
assert shots_dev.shots == qml.measurements.Shots(100)

with pytest.raises(AttributeError):
self.dev.shots = 100 # pylint: disable=attribute-defined-outside-init

def test_tracker_set_on_initialization(self):
"""Test that a new tracker instance is initialized with the class."""
assert isinstance(self.dev.tracker, qml.Tracker)
Expand Down
58 changes: 58 additions & 0 deletions tests/interfaces/test_set_shots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for interfaces.set_shots
"""

import pytest

import pennylane as qml
from pennylane.interfaces import set_shots
from pennylane.measurements import Shots


def test_shots_new_device_interface():
"""Test that calling set_shots on a device implementing the new interface leaves it
untouched.
"""
dev = qml.devices.experimental.DefaultQubit2()
with pytest.raises(ValueError):
with set_shots(dev, 10):
pass


def test_set_with_shots_class():
"""Test that shots can be set on the old device interface with a Shots class."""

dev = qml.devices.DefaultQubit(wires=1)
with set_shots(dev, Shots(10)):
assert dev.shots == 10

assert dev.shots is None

shot_tuples = Shots((10, 10))
with set_shots(dev, shot_tuples):
assert dev.shots == 20
assert dev.shot_vector == list(shot_tuples.shot_vector)

assert dev.shots is None


def test_shots_not_altered_if_False():
"""Test a value of False can be passed to shots, indicating to not override
shots on the device."""

dev = qml.devices.DefaultQubit(wires=1)
with set_shots(dev, False):
assert dev.shots is None
12 changes: 1 addition & 11 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ def test_warning_finite_shots_tape(self):
qml.RZ(0.3, wires=0)
qml.expval(qml.PauliZ(0))

tape = QuantumScript.from_queue(q)
tape = QuantumScript.from_queue(q, shots=5)
# no warning on the first execution
cache = {}
qml.execute([tape], dev, None, cache=cache)
Expand Down Expand Up @@ -1696,16 +1696,6 @@ def supports_derivatives(self, execution_config=None, circuit=None) -> bool:
assert not kwargs
assert new_dev is dev

def test_shots_not_set_on_device(self):
"""Test that shots are not set on the device when override shots are passed on a call."""

def f():
return qml.expval(qml.PauliZ(0))

qn = QNode(f, self.dev)
qn(shots=10)
assert getattr(self.dev, "shots", "not here") == "not here"

def test_shots_integration(self):
"""Test that shots provided at call time are passed through the workflow."""

Expand Down

0 comments on commit e9fb43b

Please sign in to comment.