Skip to content

Commit

Permalink
[Feature] Pyqtorch - First Order Adjoint Differentiation (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz authored Nov 17, 2023
1 parent 14ee2d1 commit e78a711
Show file tree
Hide file tree
Showing 20 changed files with 869 additions and 243 deletions.
20 changes: 20 additions & 0 deletions examples/models/quantum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,23 @@ def circuit(n_qubits):
print("Gradient of inputs: \n")
print(torch.autograd.grad(torch.mean(model.expectation(values)), nx))
print(torch.autograd.grad(torch.mean(model.expectation(values)), ny))

# Finally, let's repeat the experiment with adjoint differentiation.
model = QuantumModel(
    circuit(n_qubits),
    observable=observable,
    backend=BackendName.PYQTORCH,
    diff_mode=DiffMode.ADJOINT,
)
# Clear any gradients held by the model's parameters before backpropagating.
model.zero_grad()
loss = torch.mean(model.expectation(values))
loss.backward()

print("Gradients using ADJOINT: \n")
print("Gradient in model: \n")
for name, parameter in model.named_parameters():
    print(f"{name}: {parameter.grad}")

print("Gradient of inputs: \n")
# Each input gets its own fresh forward pass before differentiating.
for wrt in (nx, ny):
    print(torch.autograd.grad(torch.mean(model.expectation(values)), wrt))
1 change: 1 addition & 0 deletions qadence/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class Backend(ABC):
name: BackendName
supports_ad: bool
support_bp: bool
supports_adjoint: bool
is_remote: bool
with_measurements: bool
native_endianness: Endianness
Expand Down
159 changes: 159 additions & 0 deletions qadence/backends/adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations

from typing import Any

from pyqtorch.apply import apply_operator
from pyqtorch.circuit import QuantumCircuit as PyQCircuit
from pyqtorch.parametric import Parametric as PyQParametric
from pyqtorch.primitive import Primitive as PyQPrimitive
from pyqtorch.utils import overlap, param_dict
from torch import Tensor, no_grad, tensor
from torch.autograd import Function
from torch.nn import Module

from qadence.backends.pyqtorch.convert_ops import PyQHamiltonianEvolution, ScalePyQOperation
from qadence.blocks.abstract import AbstractBlock


class AdjointExpectation(Function):
    """
    Expectation value of an observable with first-order adjoint gradients.

    The adjoint differentiation method (https://arxiv.org/pdf/2009.02823.pdf)
    is able to perform a backward pass in O(P) time while maintaining
    at most 3 states, where P is the number of parameters in a variational circuit.

    Pseudo-code of the algorithm:

        c: a variational circuit
        c.gates = gate0,gate1,..gateN, where N denotes the last gate in c
        o: an observable
        state: an initial state

        1. Forward pass.
        for gate in c.gates:  # We apply gate0, gate1 to gateN.
            state = gate(state)
        projected_state = o(state)  # Apply the observable to the state.
        expval = overlap(state, projected_state)  # Compute the expected value.

        2. Backward pass:
        grads = []
        for gate in reversed(c.gates):  # Iterate in "reverse", so gateN, gateN-1 etc.
            state = dagger(gate)(state)  # 'Undo' the gate by applying its dagger.
            if gate is Parametric:
                mu = jacobian(gate)(state)  # Jacobian of the gate w.r.t its parameter.
                grads.append(2 * overlap(mu, projected_state))  # Compute the gradient.
            projected_state = dagger(gate)(projected_state)  # 'Undo' the gate here too.

    Current Limitations:

    (1) The adjoint method is only available in the pyqtorch backend.
    (2) Parametric observables are not supported.
    (3) Multiple observables are not supported.
    (4) Higher order derivatives are not natively supported.
    (5) Only expectation values can be differentiated, not wave functions.
    """

    @staticmethod
    @no_grad()
    def forward(
        ctx: Any,
        circuit: PyQCircuit,
        observable: PyQCircuit,
        state: Tensor,
        param_names: list[str],
        *param_values: Tensor,
    ) -> Tensor:
        """
        Run the circuit, apply the observable and return their overlap.

        Arguments:
            ctx: Autograd context; stores the circuit, observable, parameter names
                and the two states needed by the backward pass.
            circuit: Circuit that is run on `state`.
            observable: Observable that is run on the circuit's output state.
            state: Initial state handed to `circuit.run`.
            param_names: Names paired with `param_values` through `param_dict`.
            *param_values: Parameter tensors, in the same order as `param_names`.

        Returns:
            The overlap between the output state and the projected state.
        """
        # No per-parameter detach is needed: the whole pass runs under `no_grad`.
        ctx.circuit = circuit
        ctx.observable = observable
        ctx.param_names = param_names
        values = param_dict(param_names, param_values)
        ctx.out_state = circuit.run(state, values)
        ctx.projected_state = observable.run(ctx.out_state, values)
        ctx.save_for_backward(*param_values)
        return overlap(ctx.out_state, ctx.projected_state)

    @staticmethod
    @no_grad()
    def backward(ctx: Any, grad_out: Tensor) -> tuple:
        """
        Compute the gradients by walking the circuit in reverse.

        `ctx.out_state` and `ctx.projected_state` are overwritten while the gates
        are undone one by one.
        NOTE(review): because of that, a second backward pass through the same
        graph would presumably start from already rewound states -- confirm
        before relying on `retain_graph=True`.

        Arguments:
            ctx: Autograd context populated by `forward`.
            grad_out: Incoming gradient; every computed gradient is scaled by it.

        Returns:
            A tuple with `None` for `circuit`, `observable`, `state` and
            `param_names`, followed by one entry per saved parameter value.

        Raises:
            TypeError: If the circuit contains an operation of an unsupported type.
        """
        param_values = ctx.saved_tensors
        values = param_dict(ctx.param_names, param_values)

        def _apply_adjoint(ctx: Any, op: Module) -> list:
            """Undo `op` on both states and return the gradients it contributes."""
            grads: list = []
            if isinstance(op, PyQHamiltonianEvolution):
                generator = op.block.generator
                time_param = values[op.param_names[0]]
                ctx.out_state = apply_operator(ctx.out_state, op.dagger(values), op.qubit_support)
                # A HamEvo can have a parametrized (1) time evolution and/or (2) generator.
                if (
                    isinstance(generator, AbstractBlock)
                    and generator.is_parametric
                    and values[op.param_names[1]].requires_grad
                ):
                    # If the generator contains a trainable parameter, we compute its gradient.
                    mu = apply_operator(
                        ctx.out_state, op.jacobian_generator(values), op.qubit_support
                    )
                    grads.append(2 * overlap(ctx.projected_state, mu))
                if time_param.requires_grad:
                    # If the time evolution is trainable, we compute its gradient.
                    mu = apply_operator(ctx.out_state, op.jacobian_time(values), op.qubit_support)
                    grads.append(2 * overlap(ctx.projected_state, mu))
                ctx.projected_state = apply_operator(
                    ctx.projected_state, op.dagger(values), op.qubit_support
                )
            elif isinstance(op, ScalePyQOperation):
                ctx.out_state = apply_operator(ctx.out_state, op.dagger(values), op.qubit_support)
                scaled_pyq_op = op.operations[0]
                if (
                    isinstance(scaled_pyq_op, PyQParametric)
                    and values[scaled_pyq_op.param_name].requires_grad
                ):
                    # Gradient of the wrapped parametric operation.
                    mu = apply_operator(
                        ctx.out_state,
                        scaled_pyq_op.jacobian(values),
                        scaled_pyq_op.qubit_support,
                    )
                    grads.append(2 * overlap(ctx.projected_state, mu))

                if values[op.param_name].requires_grad:
                    # NOTE(review): this term depends only on the scale value, not on
                    # the states -- confirm it is the intended derivative.
                    grads.append(2 * -values[op.param_name])
                ctx.projected_state = apply_operator(
                    ctx.projected_state, op.dagger(values), op.qubit_support
                )
            elif isinstance(op, PyQCircuit):
                # Nested circuit: recurse over its operations in reverse order.
                grads = [g for sub_op in op.reverse() for g in _apply_adjoint(ctx, sub_op)]
            elif isinstance(op, PyQPrimitive):
                ctx.out_state = apply_operator(ctx.out_state, op.dagger(values), op.qubit_support)
                if isinstance(op, PyQParametric) and values[op.param_name].requires_grad:
                    mu = apply_operator(
                        ctx.out_state,
                        op.jacobian(values),
                        op.qubit_support,
                    )
                    grads.append(2 * overlap(ctx.projected_state, mu))
                ctx.projected_state = apply_operator(
                    ctx.projected_state, op.dagger(values), op.qubit_support
                )
            else:
                raise TypeError(
                    f"AdjointExpectation does not support a backward pass for type {type(op)}."
                )

            return grads

        # Gradients are collected while walking the circuit backwards, so the
        # list is reversed once more to restore forward order.
        grads = list(
            reversed(
                [grad_out * g for op in ctx.circuit.reverse() for g in _apply_adjoint(ctx, op)]
            )
        )
        # Pad with zeros so that every saved parameter value gets an entry.
        # Presumably the values without a gradient belong to the observable -- TODO confirm.
        diff = len(param_values) - len(grads)
        grads = grads + [tensor([0]) for _ in range(diff)]
        return (None, None, None, None, *grads)
2 changes: 2 additions & 0 deletions qadence/backends/braket/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def promote_parameters(parameters: dict[str, Tensor | float]) -> dict[str, float
class Backend(BackendInterface):
name: BackendName = BackendName.BRAKET
supports_ad: bool = False
supports_adjoint: bool = False
# TODO Use native braket adjoint differentiation.
support_bp: bool = False
is_remote: bool = False
with_measurements: bool = True
Expand Down
1 change: 1 addition & 0 deletions qadence/backends/pulser/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class Backend(BackendInterface):

name: BackendName = BackendName.PULSER
supports_ad: bool = False
supports_adjoint: bool = False
support_bp: bool = False
is_remote: bool = False
with_measurements: bool = True
Expand Down
68 changes: 27 additions & 41 deletions qadence/backends/pyqtorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,37 @@

from collections import Counter
from dataclasses import dataclass, field
from math import prod
from typing import Any

import pyqtorch as pyq
import torch
from torch import Tensor

from qadence.backend import Backend as BackendInterface
from qadence.backend import BackendName, ConvertedCircuit, ConvertedObservable
from qadence.backends.utils import to_list_of_dicts
from qadence.backend import ConvertedCircuit, ConvertedObservable
from qadence.backends.utils import (
infer_batchsize,
pyqify,
to_list_of_dicts,
unpyqify,
validate_state,
)
from qadence.blocks import AbstractBlock
from qadence.circuit import QuantumCircuit
from qadence.logger import get_logger
from qadence.measurements import Measurements
from qadence.mitigations.protocols import Mitigations, apply_mitigation
from qadence.noise import Noise
from qadence.noise.protocols import apply_noise
from qadence.overlap import overlap_exact
from qadence.states import zero_state
from qadence.transpile import (
chain_single_qubit_ops,
flatten,
invert_endianness,
scale_primitive_blocks_only,
transpile,
)
from qadence.utils import Endianness, int_to_basis
from qadence.types import BackendName, Endianness
from qadence.utils import int_to_basis

from .config import Configuration, default_passes
from .convert_ops import convert_block, convert_observable
Expand All @@ -42,6 +47,7 @@ class Backend(BackendInterface):
name: BackendName = BackendName.PYQTORCH
supports_ad: bool = True
support_bp: bool = True
supports_adjoint: bool = True
is_remote: bool = False
with_measurements: bool = True
with_noise: bool = False
Expand Down Expand Up @@ -84,42 +90,17 @@ def run(
unpyqify_state: bool = True,
) -> Tensor:
n_qubits = circuit.abstract.n_qubits

if state is not None:
if pyqify_state:
if (state.ndim != 2) or (state.size(1) != 2**n_qubits):
raise ValueError(
"The initial state must be composed of tensors of size "
f"(batch_size, 2**n_qubits). Found: {state.size() = }."
)

# PyQ expects a column vector for the initial state
# where each element is of dim=2.
state = state.T.reshape([2] * n_qubits + [state.size(0)])
else:
if prod(state.size()[:-1]) != 2**n_qubits:
raise ValueError(
"A pyqified initial state must be composed of tensors of size "
f"(2, 2, ..., batch_size). Found: {state.size() = }."
)
if state is None:
# If no state is passed, we infer the batch_size through the length
# of the individual parameter value tensors.
state = circuit.native.init_state(batch_size=infer_batchsize(param_values))
else:
# infer batch_size without state
if len(param_values) == 0:
batch_size = 1
else:
batch_size = max([len(tensor) for tensor in param_values.values()])
state = circuit.native.init_state(batch_size=batch_size)
validate_state(state, n_qubits)
# pyqtorch expects input shape [2] * n_qubits + [batch_size]
state = pyqify(state, n_qubits) if pyqify_state else state
state = circuit.native.run(state, param_values)

# make sure that the batch dimension is the first one, as standard
# for PyTorch, and not the last one as done in PyQ
if unpyqify_state:
state = torch.flatten(state, start_dim=0, end_dim=-2).t()

if endianness != self.native_endianness:
from qadence.transpile import invert_endianness

state = invert_endianness(state)
state = unpyqify(state) if unpyqify_state else state
state = invert_endianness(state) if endianness != self.native_endianness else state
return state

def _batched_expectation(
Expand Down Expand Up @@ -157,7 +138,10 @@ def _looped_expectation(
noise: Noise | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
state = zero_state(circuit.abstract.n_qubits, batch_size=1) if state is None else state
if state is None:
from qadence.states import zero_state

state = zero_state(circuit.abstract.n_qubits, batch_size=1)
if state.size(0) != 1:
raise ValueError(
"Looping expectation does not make sense with batched initial state. "
Expand Down Expand Up @@ -253,6 +237,8 @@ def assign_parameters(self, circuit: ConvertedCircuit, param_values: dict[str, T

@staticmethod
def _overlap(bras: Tensor, kets: Tensor) -> Tensor:
    """Delegate the overlap of `bras` and `kets` to `qadence.overlap.overlap_exact`."""
    # Imported at call time; presumably this avoids a circular import -- TODO confirm.
    from qadence.overlap import overlap_exact as exact_overlap

    result = exact_overlap(bras, kets)
    return result

@staticmethod
Expand Down
Loading

0 comments on commit e78a711

Please sign in to comment.