-
Notifications
You must be signed in to change notification settings - Fork 604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move to specify shots on the QNode #4375
Changes from all commits
24b6b81
034a147
f0ba48a
37527d9
c25370a
4854bf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,6 +305,12 @@ def run_cnot(): | |
if name in plugin_devices: | ||
options = {} | ||
|
||
if "shots" in kwargs: | ||
warnings.warn( | ||
"In v0.33, the shots will always be determined by the QNode. Please specify shots there.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And maybe "You have provided a shots argument to a device. In v0.33, ..." |
||
UserWarning, | ||
) | ||
|
||
# load global configuration settings if available | ||
config = kwargs.get("config", default_config) | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -384,6 +384,7 @@ def __init__( | |||||||
device: Union[Device, "qml.devices.experimental.Device"], | ||||||||
interface="auto", | ||||||||
diff_method="best", | ||||||||
shots=None, | ||||||||
expansion_strategy="gradient", | ||||||||
max_expansion=10, | ||||||||
grad_on_execution="best", | ||||||||
|
@@ -443,13 +444,26 @@ def __init__( | |||||||
UserWarning, | ||||||||
) | ||||||||
|
||||||||
if ( | ||||||||
hasattr(device, "shots") | ||||||||
and device.shots != shots | ||||||||
and device.shots is not None | ||||||||
and shots is None | ||||||||
): | ||||||||
warnings.warn( | ||||||||
"Shots should now be specified on the qnode instead of on the device." | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
"Using shots from the device. QNode specified shots will be used in v0.33." | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
) | ||||||||
shots = device.shots | ||||||||
|
||||||||
# input arguments | ||||||||
self.func = func | ||||||||
self.device = device | ||||||||
self._interface = interface | ||||||||
self.diff_method = diff_method | ||||||||
self.expansion_strategy = expansion_strategy | ||||||||
self.max_expansion = max_expansion | ||||||||
self._shots = shots | ||||||||
|
||||||||
# execution keyword arguments | ||||||||
self.execute_kwargs = { | ||||||||
|
@@ -490,6 +504,15 @@ def __repr__(self): | |||||||
self.diff_method, | ||||||||
) | ||||||||
|
||||||||
@property | ||||||||
def default_shots(self) -> qml.measurements.Shots: | ||||||||
"""The default shots to use for an execution. | ||||||||
|
||||||||
Can be overridden on a per-call basis using the ``qnode(*args, shots=new_shots, **kwargs)`` syntax. | ||||||||
|
||||||||
""" | ||||||||
return qml.measurements.Shots(self._shots) | ||||||||
|
||||||||
@property | ||||||||
def interface(self): | ||||||||
"""The interface used by the QNode""" | ||||||||
|
@@ -736,9 +759,7 @@ def _validate_backprop_method(device, interface, shots=None): | |||||||
expand_fn = device.expand_fn | ||||||||
batch_transform = device.batch_transform | ||||||||
|
||||||||
device = qml.device( | ||||||||
backprop_devices[mapped_interface], wires=device.wires, shots=device.shots | ||||||||
) | ||||||||
device = qml.device(backprop_devices[mapped_interface], wires=device.wires) | ||||||||
device.expand_fn = expand_fn | ||||||||
device.batch_transform = batch_transform | ||||||||
|
||||||||
|
@@ -826,23 +847,16 @@ def tape(self) -> QuantumTape: | |||||||
|
||||||||
qtape = tape # for backwards compatibility | ||||||||
|
||||||||
def construct(self, args, kwargs): # pylint: disable=too-many-branches | ||||||||
def construct(self, args, kwargs, override_shots="unset"): # pylint: disable=too-many-branches | ||||||||
"""Call the quantum function with a tape context, ensuring the operations get queued.""" | ||||||||
old_interface = self.interface | ||||||||
|
||||||||
if not self._qfunc_uses_shots_arg: | ||||||||
shots = kwargs.pop("shots", None) | ||||||||
else: | ||||||||
shots = ( | ||||||||
self._original_device._raw_shot_sequence | ||||||||
if self._original_device._shot_vector | ||||||||
else self._original_device.shots | ||||||||
) | ||||||||
override_shots = self._shots if isinstance(override_shots, str) else override_shots | ||||||||
|
||||||||
if old_interface == "auto": | ||||||||
self.interface = qml.math.get_interface(*args, *list(kwargs.values())) | ||||||||
|
||||||||
self._tape = make_qscript(self.func, shots)(*args, **kwargs) | ||||||||
self._tape = make_qscript(self.func, override_shots)(*args, **kwargs) | ||||||||
self._qfunc_output = self.tape._qfunc_output | ||||||||
|
||||||||
params = self.tape.get_parameters(trainable_only=False) | ||||||||
|
@@ -920,38 +934,24 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: | |||||||
self.interface = qml.math.get_interface(*args, *list(kwargs.values())) | ||||||||
self.device.tracker = self._original_device.tracker | ||||||||
|
||||||||
original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device] | ||||||||
override_shots = self._shots | ||||||||
if not self._qfunc_uses_shots_arg: | ||||||||
# If shots specified in call but not in qfunc signature, | ||||||||
# interpret it as device shots value for this call. | ||||||||
override_shots = kwargs.get("shots", False) | ||||||||
|
||||||||
if override_shots is not False: | ||||||||
if "shots" in kwargs: | ||||||||
override_shots = kwargs.pop("shots") | ||||||||
# Since shots has changed, we need to update the preferred gradient function. | ||||||||
# This is because the gradient function chosen at initialization may | ||||||||
# no longer be applicable. | ||||||||
|
||||||||
# store the initialization gradient function | ||||||||
original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device] | ||||||||
|
||||||||
# pylint: disable=not-callable | ||||||||
# update the gradient function | ||||||||
if isinstance(self._original_device, Device): | ||||||||
set_shots(self._original_device, override_shots)(self._update_gradient_fn)() | ||||||||
else: | ||||||||
self._update_gradient_fn(shots=override_shots) | ||||||||
|
||||||||
else: | ||||||||
if isinstance(self._original_device, Device): | ||||||||
kwargs["shots"] = ( | ||||||||
self._original_device._raw_shot_sequence | ||||||||
if self._original_device._shot_vector | ||||||||
else self._original_device.shots | ||||||||
) | ||||||||
else: | ||||||||
kwargs["shots"] = None | ||||||||
|
||||||||
# construct the tape | ||||||||
self.construct(args, kwargs) | ||||||||
self.construct(args, kwargs, override_shots=override_shots) | ||||||||
|
||||||||
cache = self.execute_kwargs.get("cache", False) | ||||||||
using_custom_cache = ( | ||||||||
|
@@ -1002,9 +1002,8 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: | |||||||
else: | ||||||||
res = type(self.tape._qfunc_output)(res) | ||||||||
|
||||||||
if override_shots is not False: | ||||||||
# restore the initialization gradient function | ||||||||
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn | ||||||||
# restore the initialization gradient function | ||||||||
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn | ||||||||
|
||||||||
self._update_original_device() | ||||||||
|
||||||||
|
@@ -1063,9 +1062,7 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: | |||||||
qfunc_output_type = type(self._qfunc_output) | ||||||||
return qfunc_output_type(res) | ||||||||
|
||||||||
if override_shots is not False: | ||||||||
# restore the initialization gradient function | ||||||||
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn | ||||||||
self.gradient_fn, self.gradient_kwargs, self.device = original_grad_fn | ||||||||
|
||||||||
self._update_original_device() | ||||||||
|
||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we say "Please specify shots either at QNode instantiation or when calling your QNode."