-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Engines, Backends] Add JAX Engine and Horqrux Backend (#111)
- Loading branch information
1 parent
4ba3a9c
commit 5a7c8dd
Showing
42 changed files
with
1,460 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
### ::: qadence.backends.pytorch_wrapper | ||
### ::: qadence.engines.torch.differentiable_backend | ||
### ::: qadence.engines.jax.differentiable_backend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ For more enquiries, please contact: [`[email protected]`](mailto:[email protected]). | |
|
||
## Differentiation backend | ||
|
||
The [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend] class enables different differentiation modes | ||
The [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend] class enables different differentiation modes | ||
for the given backend. This can be chosen from two types: | ||
|
||
- Automatic differentiation (AD): available for PyTorch based backends (PyQTorch). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Callable | ||
|
||
import jax.numpy as jnp | ||
import optax | ||
from jax import Array, jit, value_and_grad | ||
from numpy.typing import ArrayLike | ||
|
||
from qadence.backends import backend_factory | ||
from qadence.blocks.utils import chain | ||
from qadence.circuit import QuantumCircuit | ||
from qadence.constructors import feature_map, hea, total_magnetization | ||
from qadence.types import BackendName, DiffMode | ||
|
||
backend = BackendName.HORQRUX | ||
|
||
num_epochs = 10 | ||
n_qubits = 4 | ||
depth = 1 | ||
|
||
fm = feature_map(n_qubits) | ||
circ = QuantumCircuit(n_qubits, chain(fm, hea(n_qubits, depth=depth))) | ||
obs = total_magnetization(n_qubits) | ||
|
||
for diff_mode in [DiffMode.AD, DiffMode.GPSR]: | ||
bknd = backend_factory(backend, diff_mode) | ||
conv_circ, conv_obs, embedding_fn, vparams = bknd.convert(circ, obs) | ||
init_params = vparams.copy() | ||
optimizer = optax.adam(learning_rate=0.001) | ||
opt_state = optimizer.init(vparams) | ||
|
||
loss: Array | ||
grads: dict[str, Array] # 'grads' is the same datatype as 'params' | ||
inputs: dict[str, Array] = {"phi": jnp.array(1.0)} | ||
|
||
def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple: | ||
updates, opt_state = optimizer.update(grads, opt_state, params) | ||
params = optax.apply_updates(params, updates) | ||
return params, opt_state | ||
|
||
def exp_fn(params: dict[str, Array], inputs: dict[str, Array] = inputs) -> ArrayLike: | ||
return bknd.expectation(conv_circ, conv_obs, embedding_fn(params, inputs)) | ||
|
||
init_pred = exp_fn(vparams) | ||
|
||
def mse_loss(params: dict[str, Array], y_true: Array) -> Array: | ||
expval = exp_fn(params) | ||
return (expval - y_true) ** 2 | ||
|
||
@jit | ||
def train_step( | ||
params: dict, | ||
opt_state: Array, | ||
y_true: Array = jnp.array(1.0, dtype=jnp.float64), | ||
loss_fn: Callable = mse_loss, | ||
) -> tuple: | ||
loss, grads = value_and_grad(loss_fn)(params, y_true) | ||
params, opt_state = optimize_step(params, opt_state, grads) | ||
return loss, params, opt_state | ||
|
||
for epoch in range(num_epochs): | ||
loss, vparams, opt_state = train_step(vparams, opt_state) | ||
print(f"epoch {epoch} loss:{loss}") | ||
|
||
final_pred = exp_fn(vparams) | ||
|
||
print( | ||
f"diff_mode '{diff_mode}: Initial prediction: {init_pred}, initial vparams: {init_params}" | ||
) | ||
print(f"Final prediction: {final_pred}, final vparams: {vparams}") | ||
print("----------") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.