Skip to content

Commit

Permalink
[Engines, Backends] Add JAX Engine and Horqrux Backend (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikandreasseitz authored Dec 11, 2023
1 parent 4ba3a9c commit 5a7c8dd
Show file tree
Hide file tree
Showing 42 changed files with 1,460 additions and 103 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ Qadence is available on [PyPI](https://pypi.org/project/qadence/) and can be ins
pip install qadence
```

The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation. It is possible to install additional backends and the circuit visualization library using the following extras:
The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation based on `PyTorch`. It is possible to install additional, `PyTorch` -based backends and the circuit visualization library using the following extras:

* `pulser`: The [Pulser](https://github.com/pasqal-io/Pulser) backend for composing, simulating and executing pulse sequences for neutral-atom quantum devices.
* `braket`: The [Braket](https://github.com/amazon-braket/amazon-braket-sdk-python) backend, an open source library that provides a framework for interacting with quantum computing hardware devices through Amazon Braket.
* `visualization`: A visualization library to display quantum circuit diagrams.

Qadence also supports a `JAX` engine which is currently supporting the [Horqrux](https://github.com/pasqal-io/horqrux) backend. `horqrux` is currently only available via the [low-level API](examples/backends/low_level/horqrux_backend.py).


To install individual extras, use the following syntax (**IMPORTANT** Make sure to use quotes):

```bash
Expand Down
2 changes: 1 addition & 1 deletion docs/advanced_tutorials/differentiability.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ print(docsutils.fig_to_html(plt.gcf())) # markdown-exec: hide
In order to get a finer control over the GPSR differentiation engine we can use the low-level Qadence API to define a `DifferentiableBackend`.

```python exec="on" source="material-block" session="differentiability"
from qadence import DifferentiableBackend
from qadence.engines.torch import DifferentiableBackend
from qadence.backends.pyqtorch import Backend as PyQBackend

# define differentiable quantum backend
Expand Down
3 changes: 2 additions & 1 deletion docs/backends/differentiable.md
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
### ::: qadence.backends.pytorch_wrapper
### ::: qadence.engines.torch.differentiable_backend
### ::: qadence.engines.jax.differentiable_backend
9 changes: 4 additions & 5 deletions docs/development/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ In Qadence there are 4 main objects spread across 3 different levels of abstract
* **Differentiation layer**: Intermediate layer has the purpose of integrating quantum
computation with a given automatic differentiation engine. It is meant to be purely stateless and
contains one object:
* [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend]:
* [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend]:
An abstract class whose concrete implementation wraps a quantum backend and make it
automatically differentiable using different engines (e.g. PyTorch or Jax).
Note, that today only PyTorch is supported but there is plan to add also a Jax
Expand Down Expand Up @@ -57,7 +57,7 @@ and outputs.

### `DifferentiableBackend`

The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and make the circuit execution routines (expectation value, overalap, etc.) differentiable. Currently, the only implemented differentiation engine is PyTorch but it is easy to add support to another one like Jax.
The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and make the circuit execution routines (expectation value, overalap, etc.) differentiable. Qadence offers both a PyTorch and Jax differentiation engine.

### Quantum `Backend`

Expand Down Expand Up @@ -104,8 +104,7 @@ You can see the logic for choosing the parameter identifier in [`get_param_name`

## Differentiation with parameter shift rules (PSR)

In Qadence, parameter shift rules are implemented by extending the PyTorch autograd engine using custom `Function`
objects. The implementation is based on this PyTorch [guide](https://pytorch.org/docs/stable/notes/extending.html).
In Qadence, parameter shift rules are applied by implementing a custom `torch.autograd.Function` class for PyTorch and the `custom_vjp` in the Jax Engine, respectively.

A custom PyTorch `Function` looks like this:

Expand All @@ -130,7 +129,7 @@ class CustomFunction(Function):
...
```

The class [`PSRExpectation`][qadence.backends.pytorch_wrapper.PSRExpectation] implements parameter shift rules for all parameters using
The class `PSRExpectation` under `qadence.engines.torch.differentiable_expectation` implements parameter shift rules for all parameters using
a custom function as the one above. There are a few implementation details to keep in mind if you want
to modify the PSR code:

Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ For more enquiries, please contact: [`[email protected]`](mailto:[email protected]).

## Differentiation backend

The [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend] class enables different differentiation modes
The [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend] class enables different differentiation modes
for the given backend. This can be chosen from two types:

- Automatic differentiation (AD): available for PyTorch based backends (PyQTorch).
Expand Down
2 changes: 1 addition & 1 deletion examples/backends/differentiable_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
CNOT,
RX,
RY,
DifferentiableBackend,
Parameter,
QuantumCircuit,
chain,
total_magnetization,
)
from qadence.backends.pyqtorch.backend import Backend as PyQTorchBackend
from qadence.engines.torch.differentiable_backend import DifferentiableBackend

torch.manual_seed(42)

Expand Down
72 changes: 72 additions & 0 deletions examples/backends/low_level/horqrux_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from typing import Callable

import jax.numpy as jnp
import optax
from jax import Array, jit, value_and_grad
from numpy.typing import ArrayLike

from qadence.backends import backend_factory
from qadence.blocks.utils import chain
from qadence.circuit import QuantumCircuit
from qadence.constructors import feature_map, hea, total_magnetization
from qadence.types import BackendName, DiffMode

backend = BackendName.HORQRUX

num_epochs = 10
n_qubits = 4
depth = 1

fm = feature_map(n_qubits)
circ = QuantumCircuit(n_qubits, chain(fm, hea(n_qubits, depth=depth)))
obs = total_magnetization(n_qubits)

for diff_mode in [DiffMode.AD, DiffMode.GPSR]:
bknd = backend_factory(backend, diff_mode)
conv_circ, conv_obs, embedding_fn, vparams = bknd.convert(circ, obs)
init_params = vparams.copy()
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(vparams)

loss: Array
grads: dict[str, Array] # 'grads' is the same datatype as 'params'
inputs: dict[str, Array] = {"phi": jnp.array(1.0)}

def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state

def exp_fn(params: dict[str, Array], inputs: dict[str, Array] = inputs) -> ArrayLike:
return bknd.expectation(conv_circ, conv_obs, embedding_fn(params, inputs))

init_pred = exp_fn(vparams)

def mse_loss(params: dict[str, Array], y_true: Array) -> Array:
expval = exp_fn(params)
return (expval - y_true) ** 2

@jit
def train_step(
params: dict,
opt_state: Array,
y_true: Array = jnp.array(1.0, dtype=jnp.float64),
loss_fn: Callable = mse_loss,
) -> tuple:
loss, grads = value_and_grad(loss_fn)(params, y_true)
params, opt_state = optimize_step(params, opt_state, grads)
return loss, params, opt_state

for epoch in range(num_epochs):
loss, vparams, opt_state = train_step(vparams, opt_state)
print(f"epoch {epoch} loss:{loss}")

final_pred = exp_fn(vparams)

print(
f"diff_mode '{diff_mode}: Initial prediction: {init_pred}, initial vparams: {init_params}"
)
print(f"Final prediction: {final_pred}, final vparams: {vparams}")
print("----------")
26 changes: 22 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ authors = [
]
requires-python = ">=3.9,<3.12"
license = {text = "Apache 2.0"}
version = "1.2.0"
version = "1.2.1"
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
Expand All @@ -31,8 +31,9 @@ classifiers=[
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"openfermion",
"numpy",
"torch",
"openfermion",
"sympytorch>=0.1.2",
"rich",
"tensorboard>=2.12.0",
Expand All @@ -41,7 +42,7 @@ dependencies = [
"nevergrad",
"scipy",
"pyqtorch==1.0.3",
"matplotlib"
"matplotlib",
]

[tool.hatch.metadata]
Expand All @@ -57,13 +58,29 @@ visualization = [
# "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg",
# "scour",
]
horqrux = [
"horqrux==0.3.0",
"jax",
"flax",
"optax",
"jaxopt",
"einops",
"sympy2jax"]

all = [
"pulser>=0.15.2",
"amazon-braket-sdk",
"graphviz",
# FIXME: will be needed once we support latex labels
# "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg",
# "scour",
"horqrux==0.3.0",
"jax",
"flax",
"optax",
"jaxopt",
"einops",
"sympy2jax"
]

[tool.hatch.envs.default]
Expand Down Expand Up @@ -120,7 +137,7 @@ dependencies = [
"markdown-exec",
"mike",
]
features = ["pulser", "braket", "visualization"]
features = ["pulser", "braket", "horqrux", "visualization"]

[tool.hatch.envs.docs.scripts]
build = "mkdocs build --clean --strict"
Expand Down Expand Up @@ -167,6 +184,7 @@ required-imports = ["from __future__ import annotations"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"operations.py" = ["E742"] # Avoid ambiguous class name warning for identity.
"qadence/backends/horqrux/convert_ops.py" = ["E741"] # Avoid ambiguous class name warning for 0.

[tool.ruff.mccabe]
max-complexity = 15
Expand Down
1 change: 1 addition & 0 deletions qadence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .blocks import *
from .circuit import *
from .constructors import *
from .engines import *
from .exceptions import *
from .execution import *
from .measurements import *
Expand Down
32 changes: 18 additions & 14 deletions qadence/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from qadence.mitigations import Mitigations
from qadence.noise import Noise
from qadence.parameters import stringify
from qadence.types import BackendName, DiffMode, Endianness
from qadence.types import ArrayLike, BackendName, DiffMode, Endianness, Engine, ParamDictType
from qadence.utils import validate_values_and_state

logger = get_logger(__file__)
Expand Down Expand Up @@ -100,11 +100,14 @@ class Backend(ABC):
name: backend unique string identifier
supports_ad: whether or not the backend has a native autograd
supports_bp: whether or not the backend has a native backprop
supports_adjoint: Does the backend support native adjoint differentation.
is_remote: whether computations are executed locally or remotely on this
backend, useful when using cloud platforms where credentials are
needed for example.
with_measurements: whether it supports counts or not
with_noise: whether to add realistic noise or not
native_endianness: The native endianness of the backend
engine: The underlying (native) automatic differentiation engine of the backend.
"""

name: BackendName
Expand All @@ -114,6 +117,7 @@ class Backend(ABC):
is_remote: bool
with_measurements: bool
native_endianness: Endianness
engine: Engine

# FIXME: should this also go into the configuration?
with_noise: bool
Expand Down Expand Up @@ -199,7 +203,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock:

conv_circ = self.circuit(circuit)
circ_params, circ_embedding_fn = embedding(
conv_circ.abstract.block, self.config._use_gate_params
conv_circ.abstract.block, self.config._use_gate_params, self.engine
)
params = circ_params
if observable is not None:
Expand All @@ -211,7 +215,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock:
obs = check_observable(obs)
c_obs = self.observable(obs, max(circuit.n_qubits, obs.n_qubits))
obs_params, obs_embedding_fn = embedding(
c_obs.abstract, self.config._use_gate_params
c_obs.abstract, self.config._use_gate_params, self.engine
)
params.update(obs_params)
obs_embedding_fn_list.append(obs_embedding_fn)
Expand All @@ -236,7 +240,7 @@ def sample(
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
n_shots: int = 1000,
state: Tensor | None = None,
state: ArrayLike | None = None,
noise: Noise | None = None,
mitigation: Mitigations | None = None,
endianness: Endianness = Endianness.BIG,
Expand All @@ -259,10 +263,10 @@ def sample(
def _run(
self,
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
state: Tensor | None = None,
param_values: dict[str, ArrayLike] = {},
state: ArrayLike | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
) -> ArrayLike:
"""Run a circuit and return the resulting wave function.
Arguments:
Expand All @@ -281,12 +285,12 @@ def _run(
def run(
self,
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
param_values: dict[str, ArrayLike] = {},
state: Tensor | None = None,
endianness: Endianness = Endianness.BIG,
*args: Any,
**kwargs: Any,
) -> Tensor:
) -> ArrayLike:
"""Run a circuit and return the resulting wave function.
Arguments:
Expand All @@ -308,7 +312,7 @@ def run_dm(
self,
circuit: ConvertedCircuit,
noise: Noise,
param_values: dict[str, Tensor] = {},
param_values: dict[str, ArrayLike] = {},
state: Tensor | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
Expand All @@ -335,13 +339,13 @@ def expectation(
self,
circuit: ConvertedCircuit,
observable: list[ConvertedObservable] | ConvertedObservable,
param_values: dict[str, Tensor] = {},
state: Tensor | None = None,
param_values: ParamDictType = {},
state: ArrayLike | None = None,
measurement: Measurements | None = None,
noise: Noise | None = None,
mitigation: Mitigations | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
) -> ArrayLike:
"""Compute the expectation value of the `circuit` with the given `observable`.
Arguments:
Expand Down Expand Up @@ -398,7 +402,7 @@ class Converted:
circuit: ConvertedCircuit
observable: list[ConvertedObservable] | ConvertedObservable | None
embedding_fn: Callable
params: dict[str, Tensor]
params: ParamDictType

def __iter__(self) -> Iterator:
yield self.circuit
Expand Down
3 changes: 1 addition & 2 deletions qadence/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

from .api import backend_factory, config_factory
from .pytorch_wrapper import DifferentiableBackend

# Modules to be automatically added to the qadence namespace
__all__ = ["backend_factory", "config_factory", "DifferentiableBackend"]
__all__ = ["backend_factory", "config_factory"]
Loading

0 comments on commit 5a7c8dd

Please sign in to comment.