Skip to content

Commit

Permalink
Integrate TransformProgram with QNode (#4404)
Browse files Browse the repository at this point in the history
* Draft structure

* draft exec

* Simple execute

* Update

* More tests

* Update

* Update exec

* Pylint and black

* Update tests

* Update more tests

* More tests

* changelog

* Coverage

* Cover fix

* pylint

* Pylint

* Pylint tests

* proposed changes to transform program integration

* oops

* add to legacy, remove cotransform support

* just transform program call component

* just transform program call component

* no longer support cotransforms, fix _batch_postprocessing

* some more testing

* test null postprocessing function

* docstring, rename batch_slices to slices, black

* Apply suggestions from code review

Co-authored-by: Matthew Silverman <[email protected]>

* integrate transform program with qnode

* adding integration tests

* test modifications

* [skip ci] fiddling

* more testing

* changelog entry

* add to execute, start on testing

* add qml.execute tests

* Update doc/releases/changelog-dev.md

Co-authored-by: Matthew Silverman <[email protected]>

* fix test

---------

Co-authored-by: rmoyard <[email protected]>
Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
3 people authored and mudit2812 committed Aug 3, 2023
1 parent a143a3b commit 6a8c9b1
Show file tree
Hide file tree
Showing 9 changed files with 475 additions and 62 deletions.
25 changes: 25 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,31 @@
issue, say using JAX, TensorFlow, Torch, try setting `max_workers` to `None`.
[(#4319)](https://github.com/PennyLaneAI/pennylane/pull/4319)

* Transform Programs are now integrated with the `QNode`.
[(#4404)](https://github.com/PennyLaneAI/pennylane/pull/4404)

```python
def null_postprocessing(results: qml.typing.ResultBatch) -> qml.typing.Result:
    return results[0]

@qml.transforms.core.transform
def scale_shots(tape: qml.tape.QuantumTape, shot_scaling) -> (Tuple[qml.tape.QuantumTape], Callable):
    new_shots = tape.shots.total_shots * shot_scaling
    new_tape = qml.tape.QuantumScript(tape.operations, tape.measurements, shots=new_shots)
    return (new_tape, ), null_postprocessing

dev = qml.devices.experimental.DefaultQubit2()

@partial(scale_shots, shot_scaling=2)
@qml.qnode(dev, interface=None)
def circuit():
    return qml.sample(wires=0)
```

>>> circuit(shots=1)
array([False, False])

<h3>Improvements 🛠</h3>

* Wires can now be reused after making a mid-circuit measurement on them.
Expand Down
31 changes: 23 additions & 8 deletions pennylane/interfaces/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def execute(
device: device_type,
gradient_fn: Optional[Union[Callable, str]] = None,
interface="auto",
transform_program=None,
grad_on_execution="best",
gradient_kwargs=None,
cache: Union[bool, dict, Cache] = True,
Expand Down Expand Up @@ -430,6 +431,7 @@ def cost_fn(params, x):
)

### Specifying and preprocessing variables ####
transform_program = transform_program or qml.transforms.core.TransformProgram()

if interface == "auto":
params = []
Expand Down Expand Up @@ -465,6 +467,7 @@ def cost_fn(params, x):

#### Executing the configured setup #####

tapes, program_post_processing = transform_program(tapes)
tapes, batch_fn, config = _batch_transform(
tapes, device, config, override_shots, device_batch_transform
)
Expand All @@ -491,7 +494,8 @@ def cost_fn(params, x):
pass_kwargs=new_device_interface,
)
results = cached_execute_fn(tapes, execution_config=config)
return batch_fn(results)
results = batch_fn(results)
return program_post_processing(results)

# the default execution function is batch_execute
# use qml.interfaces so that mocker can spy on it during testing
Expand Down Expand Up @@ -621,7 +625,7 @@ def gradient_fn(internal_tapes):
elif mapped_interface == "jax":
_execute = _get_jax_execute_fn(interface, tapes)

res = _execute(
results = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
)

Expand All @@ -631,14 +635,16 @@ def gradient_fn(internal_tapes):
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
) from e

return batch_fn(res)
results = batch_fn(results)
return program_post_processing(results)


def _execute_legacy(
tapes: Sequence[QuantumTape],
device: device_type,
gradient_fn: Callable = None,
interface="auto",
transform_program=None,
mode="best",
gradient_kwargs=None,
cache=True,
Expand Down Expand Up @@ -754,6 +760,9 @@ def cost_fn(params, x):
if isinstance(device, qml.devices.experimental.Device):
raise ValueError("New device interface only works with return types enabled.")

transform_program = transform_program or qml.transforms.core.TransformProgram()
tapes, program_post_processing = transform_program(tapes)

if interface == "auto":
params = []
for tape in tapes:
Expand Down Expand Up @@ -782,24 +791,27 @@ def cost_fn(params, x):
if gradient_fn is None:
# don't unwrap if it's an interface device
if "passthru_interface" in device.capabilities():
return batch_fn(
results = batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)
return program_post_processing(results)
unwrapped_tapes = tuple(qml.transforms.convert_to_numpy_parameters(t) for t in tapes)
res = qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(unwrapped_tapes)

return batch_fn(res)
results = batch_fn(res)
return program_post_processing(results)

if gradient_fn == "backprop" or interface is None:
return batch_fn(
results = batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)
return program_post_processing(results)

# the default execution function is batch_execute
execute_fn = qml.interfaces.cache_execute(batch_execute, cache, expand_fn=expand_fn)
Expand Down Expand Up @@ -873,9 +885,12 @@ def cost_fn(params, x):
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
) from e

res = _execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff)
results = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
)

return batch_fn(res)
results = batch_fn(results)
return program_post_processing(results)


def _get_jax_execute_fn(interface: str, tapes: Sequence[QuantumTape]):
Expand Down
8 changes: 6 additions & 2 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,12 +963,14 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
if qml.active_return():
if "mode" in self.execute_kwargs:
self.execute_kwargs.pop("mode")

# pylint: disable=unexpected-keyword-arg
res = qml.execute(
[self.tape],
(self._tape,),
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
transform_program=self.transform_program,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
Expand Down Expand Up @@ -1018,11 +1020,13 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
grad_on_execution = "best"
self.execute_kwargs["grad_on_execution"] = grad_on_execution
# pylint: disable=unexpected-keyword-arg

res = qml.execute(
[self.tape],
(self._tape,),
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
transform_program=self._transform_program,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
Expand Down
1 change: 1 addition & 0 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def hash(self):
fingerprint.extend(op.hash for op in self.operations)
fingerprint.extend(m.hash for m in self.measurements)
fingerprint.extend(self.trainable_params)
fingerprint.extend(self.shots)
return hash(tuple(fingerprint))

def __iter__(self):
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
This module contains the transform function to make your custom transforms compatible with qfunc and QNodes.
"""
from typing import get_type_hints, Sequence, Callable, List, Tuple
from typing import get_type_hints, Sequence, List, Tuple, Callable
import pennylane as qml
from .transform_dispatcher import TransformDispatcher, TransformError

Expand Down Expand Up @@ -156,7 +156,7 @@ def _transform_signature_check(signature):
"pennylane.tape.tape.QuantumTape], <built-in function callable>)"
)

if not ret[0] in (
if ret[0] not in (
Sequence[qml.tape.QuantumTape],
List[qml.tape.QuantumTape],
Tuple[qml.tape.QuantumTape],
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def __init__(
self, transform, args=None, kwargs=None, classical_cotransform=None, is_informative=False
): # pylint:disable=redefined-outer-name,too-many-arguments
self._transform = transform
self._args = args if args else []
self._kwargs = kwargs if kwargs else {}
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative

Expand Down
Loading

0 comments on commit 6a8c9b1

Please sign in to comment.