Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate TransformProgram with QNode #4404

Merged
merged 52 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4fa2a93
Draft structure
rmoyard Jul 10, 2023
d39731f
draf exec
rmoyard Jul 10, 2023
7ec3116
Simple execute
rmoyard Jul 10, 2023
5505c76
Update
rmoyard Jul 10, 2023
e34a26e
Merge branch 'master' into execute_transforms
rmoyard Jul 10, 2023
3813df9
Merge branch 'master' into execute_transforms
rmoyard Jul 11, 2023
c4dbf5e
More tests
rmoyard Jul 11, 2023
47ba75e
Merge branch 'execute_transforms' of https://github.com/PennyLaneAI/p…
rmoyard Jul 11, 2023
c69826f
Update
rmoyard Jul 12, 2023
509a799
Update exec
rmoyard Jul 14, 2023
ffa8c88
Pylint and black
rmoyard Jul 14, 2023
ce1c7f0
Update tests
rmoyard Jul 14, 2023
1fa3fe1
Update more tests
rmoyard Jul 14, 2023
950dfd5
Merge branch 'master' into execute_transforms
rmoyard Jul 14, 2023
b1e39cc
More tests
rmoyard Jul 14, 2023
d3da8dd
Merge branch 'execute_transforms' of https://github.com/PennyLaneAI/p…
rmoyard Jul 14, 2023
d13a9a4
changelog
rmoyard Jul 14, 2023
510a7f5
Coverage
rmoyard Jul 14, 2023
d202372
Cover fix
rmoyard Jul 14, 2023
b8caa0d
pylint
rmoyard Jul 14, 2023
95bbe1e
Pylint
rmoyard Jul 14, 2023
8e52b99
Pylint tests
rmoyard Jul 14, 2023
39838af
proposed changes to transform program integration
albi3ro Jul 17, 2023
64e1fff
oops
albi3ro Jul 17, 2023
dfee78f
add to legacy, remove cotransform support
albi3ro Jul 18, 2023
f00ecde
Merge branch 'master' into execute_transform_v2
albi3ro Jul 18, 2023
3e16a98
just transform program call component
albi3ro Jul 18, 2023
506be14
just transform program call component
albi3ro Jul 18, 2023
7d75350
Merge branch 'master' into execute_transform_v2
albi3ro Jul 18, 2023
f817aa2
no longer support cotransforms, fix _batch_postprocessing
albi3ro Jul 24, 2023
9b09692
some more testing
albi3ro Jul 24, 2023
763c377
Merge branch 'master' into execute_transform_v2
albi3ro Jul 24, 2023
6ac0752
test null postprocessing function
albi3ro Jul 25, 2023
deb551d
docstring, rename batch_slices to slices, black
albi3ro Jul 27, 2023
6ef4364
Apply suggestions from code review
albi3ro Jul 27, 2023
7ff1148
Merge branch 'master' into execute_transform_v2
albi3ro Jul 27, 2023
a2b202d
integrate transform program with qnode
albi3ro Jul 27, 2023
a7024b5
adding integration tests
albi3ro Jul 28, 2023
b485451
test modifications
albi3ro Jul 28, 2023
4ca30a8
Merge branch 'master' into qnode-transform-program
albi3ro Jul 28, 2023
69390d1
[skip ci] fiddling
albi3ro Jul 28, 2023
68d2d20
more testing
albi3ro Jul 28, 2023
dd85e26
Merge branch 'master' into qnode-transform-program
albi3ro Jul 28, 2023
7e5efe1
changelog entry
albi3ro Jul 28, 2023
b04c8db
merging
albi3ro Jul 28, 2023
778a670
add to execute, start on testing
albi3ro Jul 31, 2023
9cd4f45
add qml.execute tests
albi3ro Aug 2, 2023
00f4228
Merge branch 'master' into qnode-transform-program
albi3ro Aug 2, 2023
7c284cc
Update doc/releases/changelog-dev.md
albi3ro Aug 2, 2023
e2dbe4b
fix test
albi3ro Aug 2, 2023
316a6e1
Merge branch 'qnode-transform-program' of https://github.com/PennyLan…
albi3ro Aug 2, 2023
5a3cbc7
Merge branch 'master' into qnode-transform-program
rmoyard Aug 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@
issue, say using JAX, TensorFlow, Torch, try setting `max_workers` to `None`.
[(#4319)](https://github.com/PennyLaneAI/pennylane/pull/4319)

* Transform Programs are now integrated with the `QNode`.
[(#4404)](https://github.com/PennyLaneAI/pennylane/pull)
albi3ro marked this conversation as resolved.
Show resolved Hide resolved

```
def null_postprocessing(results: qml.typing.ResultBatch) -> qml.typing.Result:
return results[0]

@qml.transforms.core.transform
def scale_shots(tape: qml.tape.QuantumTape, shot_scaling) -> (Tuple[qml.tape.QuantumTape], Callable):
new_shots = tape.shots.total_shots * shot_scaling
new_tape = qml.tape.QuantumScript(tape.operations, tape.measurements, shots=new_shots)
return (new_tape, ), null_postprocessing

dev = qml.devices.experimental.DefaultQubit2()

@partial(scale_shots, shot_scaling=2)
@qml.qnode(dev, interface=None)
def circuit():
return qml.sample(wires=0)

```

>>> circuit(shots=1)
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
array([False, False])

<h3>Improvements 🛠</h3>

* Transform Programs, `qml.transforms.core.TransformProgram`, can now be called on a batch of circuits
Expand Down
10 changes: 8 additions & 2 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,16 +964,19 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
if qml.active_return():
if "mode" in self.execute_kwargs:
self.execute_kwargs.pop("mode")

batch, post_processing = self.transform_program([self.tape])
# pylint: disable=unexpected-keyword-arg
res = qml.execute(
[self.tape],
batch,
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
)
res = post_processing(res)

res = res[0]

Expand Down Expand Up @@ -1015,15 +1018,18 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
grad_on_execution = "best"
self.execute_kwargs["grad_on_execution"] = grad_on_execution
# pylint: disable=unexpected-keyword-arg

batch, postprocessing = self.transform_program([self.tape])
res = qml.execute(
[self.tape],
batch,
device=self.device,
gradient_fn=self.gradient_fn,
interface=self.interface,
gradient_kwargs=self.gradient_kwargs,
override_shots=override_shots,
**self.execute_kwargs,
)
res = postprocessing(res)

if old_interface == "auto":
self.interface = "auto"
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
This module contains the transform function to make your custom transforms compatible with qfunc and QNodes.
"""
from typing import get_type_hints, Sequence, Callable, List, Tuple
from typing import get_type_hints, Sequence, List, Tuple, Callable
import pennylane as qml
from .transform_dispatcher import TransformDispatcher, TransformError

Expand Down Expand Up @@ -156,7 +156,7 @@ def _transform_signature_check(signature):
"pennylane.tape.tape.QuantumTape], <built-in function callable>)"
)

if not ret[0] in (
if ret[0] not in (
rmoyard marked this conversation as resolved.
Show resolved Hide resolved
Sequence[qml.tape.QuantumTape],
List[qml.tape.QuantumTape],
Tuple[qml.tape.QuantumTape],
Expand Down
4 changes: 2 additions & 2 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def __init__(
self, transform, args=None, kwargs=None, classical_cotransform=None, is_informative=False
): # pylint:disable=redefined-outer-name,too-many-arguments
self._transform = transform
self._args = args if args else []
self._kwargs = kwargs if kwargs else {}
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative

Expand Down
204 changes: 157 additions & 47 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
"""Unit tests for the QNode"""
# pylint: disable=import-outside-toplevel, protected-access, no-member
import warnings
from collections import defaultdict
import copy

from functools import partial
from typing import Callable, Tuple

import numpy as np
import pytest
from scipy.sparse import csr_matrix


import pennylane as qml
from pennylane import QNode
from pennylane.devices import experimental
Expand Down Expand Up @@ -1268,7 +1272,7 @@ def circuit():
with qml.queuing.AnnotatedQueue() as q:
circuit()

assert q.queue == []
assert q.queue == [] # pylint: disable=use-implicit-booleaness-not-comparison
assert len(circuit.tape.operations) == 1


Expand Down Expand Up @@ -1508,67 +1512,173 @@ def qn2(x, y):
assert qn2.tape.shots.shot_vector == shot_vector


@pytest.mark.xfail
class TestSpecs:
"""Tests for the qnode property specs"""
class TestTransformProgramIntegration:
def test_transform_program_modifies_circuit(self):
"""Test qnode integration with a transform that turns the circuit into just a pauli x."""
dev = qml.device("default.qubit", wires=1)

# pylint: disable=pointless-statement
def test_specs_error(self):
"""Tests an error is raised if the tape is not constructed."""
def null_postprocessing(results):
return results[0]

dev = qml.device("default.qubit", wires=4)
@qml.transforms.core.transform
def just_pauli_x_out(
tape: qml.tape.QuantumTape,
) -> (Tuple[qml.tape.QuantumTape], Callable):
return (
qml.tape.QuantumScript([qml.PauliX(0)], tape.measurements),
), null_postprocessing

@qnode(dev)
def circuit():
@just_pauli_x_out
@qml.qnode(dev, interface=None, diff_method=None)
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

with pytest.raises(qml.QuantumFunctionError, match=r"The QNode specifications"):
circuit.specs # pylint: disable=pointless-statement
assert circuit.transform_program[0].transform == just_pauli_x_out.transform

@pytest.mark.parametrize(
"diff_method, len_info", [("backprop", 10), ("parameter-shift", 12), ("adjoint", 11)]
)
def test_specs(self, diff_method, len_info):
"""Tests the specs property with backprop, parameter-shift and adjoint diff_method"""
assert qml.math.allclose(circuit(0.1), -1)

dev = qml.device("default.qubit", wires=4)
with circuit.device.tracker as tracker:
circuit(0.1)

@qnode(dev, diff_method=diff_method)
def circuit(x, y):
qml.RX(x[0], wires=0)
qml.Toffoli(wires=(0, 1, 2))
qml.CRY(x[1], wires=(0, 1))
qml.Rot(x[2], x[3], y, wires=2)
return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(1))
assert tracker.totals["executions"] == 1
assert tracker.history["resources"][0].gate_types["PauliX"] == 1
assert tracker.history["resources"][0].gate_types["RX"] == 0

def tet_transform_program_modifies_results(self):
"""Test integration with a transform that modifies the result output."""

dev = qml.device("default.qubit", wires=2)

x = pnp.array([0.05, 0.1, 0.2, 0.3], requires_grad=True)
y = pnp.array(0.1, requires_grad=False)
@qml.transforms.core.transform
def pin_result(
tape: qml.tape.QuantumTape, requested_result
) -> (Tuple[qml.tape.QuantumTape], Callable):
def postprocessing(_: qml.typing.ResultBatch) -> qml.typing.Result:
return requested_result

_ = circuit(x, y)
return (tape,), postprocessing

info = circuit.specs
@partial(pin_result, requested_result=3.0)
@qml.qnode(dev, interface=None, diff_method=None)
def circuit(x):
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

assert len(info) == len_info
assert circuit.transform_program[0].transform == pin_result.transform
assert circuit.transform_program[0].kwargs == {"requested_result": 3.0}

assert info["gate_sizes"] == defaultdict(int, {1: 2, 3: 1, 2: 1})
assert info["gate_types"] == defaultdict(int, {"RX": 1, "Toffoli": 1, "CRY": 1, "Rot": 1})
assert info["num_operations"] == 4
assert info["num_observables"] == 2
assert info["num_diagonalizing_gates"] == 1
assert info["num_used_wires"] == 3
assert info["depth"] == 3
assert info["num_device_wires"] == 4
assert qml.math.allclose(circuit(0.1), 3.0)

assert info["diff_method"] == diff_method
def test_transform_order_circuit_processing(self):
"""Test that transforms are applied in the correct order in integration."""

if diff_method == "parameter-shift":
assert info["num_parameter_shift_executions"] == 7
dev = qml.device("default.qubit", wires=2)

if diff_method != "backprop":
assert info["device_name"] == "default.qubit"
assert info["num_trainable_params"] == 4
else:
assert info["device_name"] == "default.qubit.autograd"
def null_postprocessing(results):
return results[0]

@qml.transforms.core.transform
def just_pauli_x_out(tape: qml.tape.QuantumTape) -> (Tuple[qml.tape.QuantumTape], Callable):
return (
qml.tape.QuantumScript([qml.PauliX(0)], tape.measurements),
), null_postprocessing

@qml.transforms.core.transform
def repeat_operations(
tape: qml.tape.QuantumTape,
) -> (Tuple[qml.tape.QuantumTape], Callable):
new_tape = qml.tape.QuantumScript(
tape.operations + copy.deepcopy(tape.operations), tape.measurements
)
return (new_tape,), null_postprocessing

@repeat_operations
@just_pauli_x_out
@qml.qnode(dev, interface=None, diff_method=None)
def circuit1(x):
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

with circuit1.device.tracker as tracker:
assert qml.math.allclose(circuit1(0.1), 1.0)

assert tracker.history["resources"][0].gate_types["PauliX"] == 2

@just_pauli_x_out
@repeat_operations
@qml.qnode(dev, interface=None, diff_method=None)
def circuit2(x):
qml.RX(x, 0)
return qml.expval(qml.PauliZ(0))

with circuit2.device.tracker as tracker:
assert qml.math.allclose(circuit2(0.1), -1.0)

assert tracker.history["resources"][0].gate_types["PauliX"] == 1

def test_transform_order_postprocessing(self):
"""Test that transform postprocessing is called in the right order."""

dev = qml.device("default.qubit", wires=2)

def scale_by_factor(results, factor):
return results[0] * factor

def add_shift(results, shift):
return results[0] + shift

@qml.transforms.core.transform
def scale_output(
tape: qml.tape.QuantumTape, factor
) -> (Tuple[qml.tape.QuantumTape], Callable):
return (tape,), partial(scale_by_factor, factor=factor)

@qml.transforms.core.transform
def shift_output(
tape: qml.tape.QuantumTape, shift
) -> (Tuple[qml.tape.QuantumTape], Callable):
return (tape,), partial(add_shift, shift=shift)

@partial(shift_output, shift=1.0)
@partial(scale_output, factor=2.0)
@qml.qnode(dev, interface=None, diff_method=None)
def circuit1():
return qml.expval(qml.PauliZ(0))

# first add one, then scale by 2.0. Outer postprocessing transforms are applied first
assert qml.math.allclose(circuit1(), 4.0)

@partial(scale_output, factor=2.0)
@partial(shift_output, shift=1.0)
@qml.qnode(dev, interface=None, diff_method=None)
def circuit2():
return qml.expval(qml.PauliZ(0))

# first scale by 2, then add one. Outer postprocessing transforms are applied first
assert qml.math.allclose(circuit2(), 3.0)

def test_scaling_shots_transform(self):
"""Test a transform that scales the number of shots used in an execution."""

# note that this won't work with the old device interface :(
dev = qml.devices.experimental.DefaultQubit2()

def num_of_shots_from_sample(results):
return len(results[0])

@qml.transforms.core.transform
def use_n_shots(tape: qml.tape.QuantumTape, n) -> (Tuple[qml.tape.QuantumTape], Callable):
return (
qml.tape.QuantumScript(tape.operations, tape.measurements, shots=n),
), num_of_shots_from_sample

@partial(use_n_shots, n=100)
@qml.qnode(dev, interface=None, diff_method=None)
def circuit():
return qml.sample(wires=0)

assert circuit() == 100


# pylint: disable=unused-argument
Expand Down
Loading