Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate TransformProgram with QNode #4404

Merged
merged 52 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4fa2a93
Draft structure
rmoyard Jul 10, 2023
d39731f
draf exec
rmoyard Jul 10, 2023
7ec3116
Simple execute
rmoyard Jul 10, 2023
5505c76
Update
rmoyard Jul 10, 2023
e34a26e
Merge branch 'master' into execute_transforms
rmoyard Jul 10, 2023
3813df9
Merge branch 'master' into execute_transforms
rmoyard Jul 11, 2023
c4dbf5e
More tests
rmoyard Jul 11, 2023
47ba75e
Merge branch 'execute_transforms' of https://github.com/PennyLaneAI/p…
rmoyard Jul 11, 2023
c69826f
Update
rmoyard Jul 12, 2023
509a799
Update exec
rmoyard Jul 14, 2023
ffa8c88
Pylint and black
rmoyard Jul 14, 2023
ce1c7f0
Update tests
rmoyard Jul 14, 2023
1fa3fe1
Update more tests
rmoyard Jul 14, 2023
950dfd5
Merge branch 'master' into execute_transforms
rmoyard Jul 14, 2023
b1e39cc
More tests
rmoyard Jul 14, 2023
d3da8dd
Merge branch 'execute_transforms' of https://github.com/PennyLaneAI/p…
rmoyard Jul 14, 2023
d13a9a4
changelog
rmoyard Jul 14, 2023
510a7f5
Coverage
rmoyard Jul 14, 2023
d202372
Cover fix
rmoyard Jul 14, 2023
b8caa0d
pylint
rmoyard Jul 14, 2023
95bbe1e
Pylint
rmoyard Jul 14, 2023
8e52b99
Pylint tests
rmoyard Jul 14, 2023
39838af
proposed changes to transform program integration
albi3ro Jul 17, 2023
64e1fff
oops
albi3ro Jul 17, 2023
dfee78f
add to legacy, remove cotransform support
albi3ro Jul 18, 2023
f00ecde
Merge branch 'master' into execute_transform_v2
albi3ro Jul 18, 2023
3e16a98
just transform program call component
albi3ro Jul 18, 2023
506be14
just transform program call component
albi3ro Jul 18, 2023
7d75350
Merge branch 'master' into execute_transform_v2
albi3ro Jul 18, 2023
f817aa2
no longer support cotransforms, fix _batch_postprocessing
albi3ro Jul 24, 2023
9b09692
some more testing
albi3ro Jul 24, 2023
763c377
Merge branch 'master' into execute_transform_v2
albi3ro Jul 24, 2023
6ac0752
test null postprocessing function
albi3ro Jul 25, 2023
deb551d
docstring, rename batch_slices to slices, black
albi3ro Jul 27, 2023
6ef4364
Apply suggestions from code review
albi3ro Jul 27, 2023
7ff1148
Merge branch 'master' into execute_transform_v2
albi3ro Jul 27, 2023
a2b202d
integrate transform program with qnode
albi3ro Jul 27, 2023
a7024b5
adding integration tests
albi3ro Jul 28, 2023
b485451
test modifications
albi3ro Jul 28, 2023
4ca30a8
Merge branch 'master' into qnode-transform-program
albi3ro Jul 28, 2023
69390d1
[skip ci] fiddling
albi3ro Jul 28, 2023
68d2d20
more testing
albi3ro Jul 28, 2023
dd85e26
Merge branch 'master' into qnode-transform-program
albi3ro Jul 28, 2023
7e5efe1
changelog entry
albi3ro Jul 28, 2023
b04c8db
merging
albi3ro Jul 28, 2023
778a670
add to execute, start on testing
albi3ro Jul 31, 2023
9cd4f45
add qml.execute tests
albi3ro Aug 2, 2023
00f4228
Merge branch 'master' into qnode-transform-program
albi3ro Aug 2, 2023
7c284cc
Update doc/releases/changelog-dev.md
albi3ro Aug 2, 2023
e2dbe4b
fix test
albi3ro Aug 2, 2023
316a6e1
Merge branch 'qnode-transform-program' of https://github.com/PennyLan…
albi3ro Aug 2, 2023
5a3cbc7
Merge branch 'master' into qnode-transform-program
rmoyard Aug 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@
issue, say using JAX, TensorFlow, Torch, try setting `max_workers` to `None`.
[(#4319)](https://github.com/PennyLaneAI/pennylane/pull/4319)

* Transform Programs are now integrated with the `QNode`.
[(#4404)](https://github.com/PennyLaneAI/pennylane/pull/4404)

```
def null_postprocessing(results: qml.typing.ResultBatch) -> qml.typing.Result:
return results[0]

@qml.transforms.core.transform
def scale_shots(tape: qml.tape.QuantumTape, shot_scaling) -> (Tuple[qml.tape.QuantumTape], Callable):
new_shots = tape.shots.total_shots * shot_scaling
new_tape = qml.tape.QuantumScript(tape.operations, tape.measurements, shots=new_shots)
return (new_tape, ), null_postprocessing

dev = qml.devices.experimental.DefaultQubit2()

@partial(scale_shots, shot_scaling=2)
@qml.qnode(dev, interface=None)
def circuit():
return qml.sample(wires=0)

```

>>> circuit(shots=1)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
array([False, False])

<h3>Improvements 🛠</h3>

* Transform Programs, `qml.transforms.core.TransformProgram`, can now be called on a batch of circuits
Expand Down
31 changes: 23 additions & 8 deletions pennylane/interfaces/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def execute(
device: device_type,
gradient_fn: Optional[Union[Callable, str]] = None,
interface="auto",
transform_program=None,
grad_on_execution="best",
gradient_kwargs=None,
cache: Union[bool, dict, Cache] = True,
Expand Down Expand Up @@ -430,6 +431,7 @@ def cost_fn(params, x):
)

### Specifying and preprocessing variables ####
transform_program = transform_program or qml.transforms.core.TransformProgram()

if interface == "auto":
params = []
Expand Down Expand Up @@ -465,6 +467,7 @@ def cost_fn(params, x):

#### Executing the configured setup #####

tapes, program_post_processing = transform_program(tapes)
tapes, batch_fn, config = _batch_transform(
tapes, device, config, override_shots, device_batch_transform
)
Expand All @@ -491,7 +494,8 @@ def cost_fn(params, x):
pass_kwargs=new_device_interface,
)
results = cached_execute_fn(tapes, execution_config=config)
return batch_fn(results)
results = batch_fn(results)
return program_post_processing(results)

# the default execution function is batch_execute
# use qml.interfaces so that mocker can spy on it during testing
Expand Down Expand Up @@ -621,7 +625,7 @@ def gradient_fn(internal_tapes):
elif mapped_interface == "jax":
_execute = _get_jax_execute_fn(interface, tapes)

res = _execute(
results = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
)

Expand All @@ -631,14 +635,16 @@ def gradient_fn(internal_tapes):
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
) from e

return batch_fn(res)
results = batch_fn(results)
return program_post_processing(results)


def _execute_legacy(
tapes: Sequence[QuantumTape],
device: device_type,
gradient_fn: Callable = None,
interface="auto",
transform_program=None,
mode="best",
gradient_kwargs=None,
cache=True,
Expand Down Expand Up @@ -754,6 +760,9 @@ def cost_fn(params, x):
if isinstance(device, qml.devices.experimental.Device):
raise ValueError("New device interface only works with return types enabled.")

transform_program = transform_program or qml.transforms.core.TransformProgram()
tapes, program_post_processing = transform_program(tapes)

if interface == "auto":
params = []
for tape in tapes:
Expand Down Expand Up @@ -782,24 +791,27 @@ def cost_fn(params, x):
if gradient_fn is None:
# don't unwrap if it's an interface device
if "passthru_interface" in device.capabilities():
return batch_fn(
results = batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)
return program_post_processing(results)
unwrapped_tapes = tuple(qml.transforms.convert_to_numpy_parameters(t) for t in tapes)
res = qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(unwrapped_tapes)

return batch_fn(res)
results = batch_fn(res)
return program_post_processing(results)

if gradient_fn == "backprop" or interface is None:
return batch_fn(
results = batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)
return program_post_processing(results)

# the default execution function is batch_execute
execute_fn = qml.interfaces.cache_execute(batch_execute, cache, expand_fn=expand_fn)
Expand Down Expand Up @@ -873,9 +885,12 @@ def cost_fn(params, x):
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
) from e

res = _execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff)
results = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
)

return batch_fn(res)
results = batch_fn(results)
return program_post_processing(results)


def _get_jax_execute_fn(interface: str, tapes: Sequence[QuantumTape]):
Expand Down
8 changes: 6 additions & 2 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,12 +963,14 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
if qml.active_return():
if "mode" in self.execute_kwargs:
self.execute_kwargs.pop("mode")

# pylint: disable=unexpected-keyword-arg
res = qml.execute(
[self.tape],
(self._tape,),
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
transform_program=self.transform_program,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
Expand Down Expand Up @@ -1018,11 +1020,13 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
grad_on_execution = "best"
self.execute_kwargs["grad_on_execution"] = grad_on_execution
# pylint: disable=unexpected-keyword-arg

res = qml.execute(
[self.tape],
(self._tape,),
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
transform_program=self._transform_program,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
Expand Down
1 change: 1 addition & 0 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def hash(self):
fingerprint.extend(op.hash for op in self.operations)
fingerprint.extend(m.hash for m in self.measurements)
fingerprint.extend(self.trainable_params)
fingerprint.extend(self.shots)
return hash(tuple(fingerprint))

def __iter__(self):
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
This module contains the transform function to make your custom transforms compatible with qfunc and QNodes.
"""
from typing import get_type_hints, Sequence, Callable, List, Tuple
from typing import get_type_hints, Sequence, List, Tuple, Callable
import pennylane as qml
from .transform_dispatcher import TransformDispatcher, TransformError

Expand Down Expand Up @@ -156,7 +156,7 @@ def _transform_signature_check(signature):
"pennylane.tape.tape.QuantumTape], <built-in function callable>)"
)

if not ret[0] in (
if ret[0] not in (
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
Sequence[qml.tape.QuantumTape],
List[qml.tape.QuantumTape],
Tuple[qml.tape.QuantumTape],
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def __init__(
self, transform, args=None, kwargs=None, classical_cotransform=None, is_informative=False
): # pylint:disable=redefined-outer-name,too-many-arguments
self._transform = transform
self._args = args if args else []
self._kwargs = kwargs if kwargs else {}
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative

Expand Down
Loading
Loading