[Engines, Backends] Add JAX Engine and Horqrux Backend #111

Merged: 114 commits into main from ds/jax on Dec 11, 2023
Changes from 112 commits

Commits (114)
75c856a
Add skeleton
dominikandreasseitz Oct 17, 2023
2d7a1c2
add embedding
dominikandreasseitz Oct 17, 2023
2d354d3
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 17, 2023
926317e
:int
dominikandreasseitz Nov 17, 2023
7d0c06d
Add horqrux files
dominikandreasseitz Nov 17, 2023
0790acd
Fix bunch of stuff
dominikandreasseitz Nov 17, 2023
5a64563
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 20, 2023
85e10a2
Add engines, new Diff backend
dominikandreasseitz Nov 20, 2023
b31c4aa
Fix mypy
dominikandreasseitz Nov 20, 2023
d565711
Adjust horq
dominikandreasseitz Nov 20, 2023
fb88885
Major refac, add psr proto
dominikandreasseitz Nov 21, 2023
b25b27d
Jax PSR prototype
dominikandreasseitz Nov 21, 2023
940d0b3
JaxExpectation
dominikandreasseitz Nov 21, 2023
871616d
Restructure
dominikandreasseitz Nov 22, 2023
602898a
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 22, 2023
5c124e3
Reorder
dominikandreasseitz Nov 22, 2023
0899c1c
new folder for engine tests
dominikandreasseitz Nov 22, 2023
82557a0
cleanup convert_ops
dominikandreasseitz Nov 22, 2023
27f0763
Some changes
dominikandreasseitz Nov 22, 2023
d235c7d
Add example, fix embedding
dominikandreasseitz Nov 22, 2023
9adbdb9
Sign wrong PSR
dominikandreasseitz Nov 23, 2023
ae07435
PSR correct
dominikandreasseitz Nov 23, 2023
38aefb2
Test hea grads
dominikandreasseitz Nov 23, 2023
3f610e5
PSR higher order test with possible autodiff use; to be investigated
dominikandreasseitz Nov 23, 2023
0d95f28
recursive PSR
dominikandreasseitz Nov 23, 2023
35df507
3rd order PSR test
dominikandreasseitz Nov 23, 2023
d5f9f1d
Some cleanup
dominikandreasseitz Nov 23, 2023
8f6ae64
JAX PSR, fix missing parameters
dominikandreasseitz Nov 24, 2023
de8ac03
Add engine module inits
dominikandreasseitz Nov 24, 2023
e5a13ba
Invert endianness tests
dominikandreasseitz Nov 24, 2023
5b07247
Fix endianness tests
dominikandreasseitz Nov 24, 2023
84d31a0
Fix docs diff backend import
dominikandreasseitz Nov 24, 2023
14bcaef
Fix docs
dominikandreasseitz Nov 26, 2023
48b9262
Convert mat in embedding
dominikandreasseitz Nov 26, 2023
72eafe6
Typo
dominikandreasseitz Nov 27, 2023
f2370aa
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 27, 2023
cbdc20c
use numpy to avoid tracing
dominikandreasseitz Nov 27, 2023
7e3cd73
Use corresponding dtype
dominikandreasseitz Nov 27, 2023
1f1a3bc
Add horqrux transpilation steps
dominikandreasseitz Nov 27, 2023
8226a2e
Wrap krons in cirs
dominikandreasseitz Nov 28, 2023
200ad65
Fix all test_backends tests
dominikandreasseitz Nov 28, 2023
f52fa9f
Lint
dominikandreasseitz Nov 28, 2023
a2df392
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 28, 2023
db47a7c
Reorder inits
dominikandreasseitz Nov 28, 2023
1be74fe
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 28, 2023
3fcd8bf
Add heaviside
dominikandreasseitz Nov 28, 2023
760f702
Move qadence to separate engines
dominikandreasseitz Nov 29, 2023
86ad64c
Proper separation of jax and torch
dominikandreasseitz Nov 29, 2023
320f97a
Move fn
dominikandreasseitz Nov 29, 2023
ed9196b
Refac obs
dominikandreasseitz Nov 29, 2023
c1ad9c1
Horqrux, add batching of params and multiple observables
dominikandreasseitz Nov 29, 2023
aeed96e
Remove invert_endianness overload for jax arrays
dominikandreasseitz Nov 29, 2023
174fe8b
No batching in run
dominikandreasseitz Nov 29, 2023
25496ff
Vmap circ
dominikandreasseitz Nov 29, 2023
01ce128
Vmap circ
dominikandreasseitz Nov 29, 2023
be71da8
Use python map for observables
dominikandreasseitz Nov 29, 2023
590d069
Lint
dominikandreasseitz Nov 29, 2023
ff4847a
Squeeze expvals if batch_size=1
dominikandreasseitz Nov 30, 2023
176df68
Add jax to docs envs
dominikandreasseitz Nov 30, 2023
967eaa3
Merge branch 'main' into ds/jax
dominikandreasseitz Nov 30, 2023
74b410f
Update docs with jax
dominikandreasseitz Nov 30, 2023
902305d
Simplify single gap PSR
dominikandreasseitz Nov 30, 2023
d0b70b7
Fix jix for single gap compute
dominikandreasseitz Nov 30, 2023
bf01584
Small refac
dominikandreasseitz Nov 30, 2023
5f75060
Trying to fix lax jit issues with spectral gaps
dominikandreasseitz Nov 30, 2023
b67f816
Comment on jit/lax tensor/array conversions
dominikandreasseitz Nov 30, 2023
aae98a5
Rm pytreenode decorators
dominikandreasseitz Nov 30, 2023
296d96e
Cleanup typing
dominikandreasseitz Nov 30, 2023
45c34a6
Fix phrasing
dominikandreasseitz Dec 1, 2023
94b1acd
Cleanup dependencies
dominikandreasseitz Dec 1, 2023
5eb68ed
Add correct type
dominikandreasseitz Dec 1, 2023
019ff13
Dont pin scipy
dominikandreasseitz Dec 1, 2023
81a3b87
Dont use alias
dominikandreasseitz Dec 1, 2023
c0d5f95
Arraylike type
dominikandreasseitz Dec 1, 2023
4915d3e
Differentiate between import error and backend not available
dominikandreasseitz Dec 1, 2023
c691607
Add available_engines, Rename engines to DifferentiableBackend
dominikandreasseitz Dec 1, 2023
976a46b
Simplify DiffBackend
dominikandreasseitz Dec 1, 2023
081b6d3
Lint
dominikandreasseitz Dec 1, 2023
a7c07ee
Typo
dominikandreasseitz Dec 1, 2023
68ae708
Add docstrings
dominikandreasseitz Dec 1, 2023
0872c6e
Update loop
dominikandreasseitz Dec 1, 2023
2f4f5e7
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 1, 2023
a5fb541
Test multi_observables, multiparameter expression and batch_size > 1
dominikandreasseitz Dec 4, 2023
783c964
Use list comprehension
dominikandreasseitz Dec 4, 2023
f397cf3
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 4, 2023
2dfda0c
Lint
dominikandreasseitz Dec 4, 2023
0836e60
Adjust horqrux to main
dominikandreasseitz Dec 4, 2023
005eb6c
Typo
dominikandreasseitz Dec 4, 2023
522a08a
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 5, 2023
2164366
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 5, 2023
578e377
Lint
dominikandreasseitz Dec 5, 2023
ff7bacc
Add jax to readme
dominikandreasseitz Dec 5, 2023
314ea30
Better example, enforce float64
dominikandreasseitz Dec 5, 2023
28d037d
Typo
dominikandreasseitz Dec 5, 2023
d9e6728
Remove numpy
dominikandreasseitz Dec 5, 2023
6a21b06
Add logger info about backends and engines found
dominikandreasseitz Dec 6, 2023
1c8dbb8
Improve JAX PSR, add tests with batched higher order grads
dominikandreasseitz Dec 6, 2023
d1fa5d4
Add missing backend docstrings
dominikandreasseitz Dec 6, 2023
50fce1d
Vmap param shifting in PSR
dominikandreasseitz Dec 6, 2023
febe3cb
Experimental full vmap PSR
dominikandreasseitz Dec 6, 2023
08dca18
lint
dominikandreasseitz Dec 6, 2023
0d00551
Some comments for reviewers
dominikandreasseitz Dec 6, 2023
3bca06a
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 8, 2023
decdeb4
Use FM
dominikandreasseitz Dec 8, 2023
b42e9c0
Add tests for mixed 2nd order derivatives
dominikandreasseitz Dec 8, 2023
3171a84
Remove sympy2jax import
dominikandreasseitz Dec 8, 2023
1beb3fc
Correct import
dominikandreasseitz Dec 8, 2023
a58cedc
Fix import alias
dominikandreasseitz Dec 8, 2023
0d05000
Docstrings
dominikandreasseitz Dec 8, 2023
5a29419
Improve engine docstrings
dominikandreasseitz Dec 8, 2023
1c98eb9
Adjust to qd paradigm
dominikandreasseitz Dec 8, 2023
1e3c780
Rm experimental PSR for now
dominikandreasseitz Dec 8, 2023
5a2fa31
Merge branch 'main' into ds/jax
dominikandreasseitz Dec 11, 2023
be6e9a4
Bump version, slight refactor PSR, mark hessian test as slow
dominikandreasseitz Dec 11, 2023
README.md (5 changes: 4 additions & 1 deletion)
@@ -34,12 +34,15 @@ Qadence is available on [PyPI](https://pypi.org/project/qadence/) and can be ins
pip install qadence
```

The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation. It is possible to install additional backends and the circuit visualization library using the following extras:
The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation based on `PyTorch`. It is possible to install additional, `PyTorch`-based backends and the circuit visualization library using the following extras:

* `pulser`: The [Pulser](https://github.com/pasqal-io/Pulser) backend for composing, simulating and executing pulse sequences for neutral-atom quantum devices.
* `braket`: The [Braket](https://github.com/amazon-braket/amazon-braket-sdk-python) backend, an open source library that provides a framework for interacting with quantum computing hardware devices through Amazon Braket.
* `visualization`: A visualization library to display quantum circuit diagrams.

Qadence also supports a `JAX` engine, which currently supports the [Horqrux](https://github.com/pasqal-io/horqrux) backend. `horqrux` is for now only available via the [low-level API](examples/backends/low_level/horqrux_backend.py).
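
A minimal sketch of constructing the `horqrux` backend through that low-level API (hedged: it simply reuses `backend_factory`, `BackendName` and `DiffMode` exactly as they appear in the example script linked above):

```python
from qadence.backends import backend_factory
from qadence.types import BackendName, DiffMode

# Build the horqrux backend, differentiated through the JAX engine.
horqrux_backend = backend_factory(BackendName.HORQRUX, DiffMode.AD)
```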


To install individual extras, use the following syntax (**IMPORTANT** Make sure to use quotes):

```bash
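pip install "qadence[pulser,visualization]"  # hypothetical example of the quoted-extras syntax
```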
docs/advanced_tutorials/differentiability.md (2 changes: 1 addition & 1 deletion)
Expand Up @@ -139,7 +139,7 @@ print(docsutils.fig_to_html(plt.gcf())) # markdown-exec: hide
In order to get finer control over the GPSR differentiation engine, we can use the low-level Qadence API to define a `DifferentiableBackend`.

```python exec="on" source="material-block" session="differentiability"
from qadence import DifferentiableBackend
from qadence.engines.torch import DifferentiableBackend
from qadence.backends.pyqtorch import Backend as PyQBackend

# define differentiable quantum backend
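# Hedged sketch of the remaining step (assumed signature; DiffMode is
# imported from qadence.types earlier in this documentation session):
quantum_backend = PyQBackend()
diff_backend = DifferentiableBackend(quantum_backend, diff_mode=DiffMode.GPSR)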
```
docs/backends/differentiable.md (3 changes: 2 additions & 1 deletion)
@@ -1 +1,2 @@
### ::: qadence.backends.pytorch_wrapper
### ::: qadence.engines.torch.differentiable_backend
### ::: qadence.engines.jax.differentiable_backend
docs/development/architecture.md (9 changes: 4 additions & 5 deletions)
@@ -21,7 +21,7 @@ In Qadence there are 4 main objects spread across 3 different levels of abstract
* **Differentiation layer**: Intermediate layer has the purpose of integrating quantum
computation with a given automatic differentiation engine. It is meant to be purely stateless and
contains one object:
* [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend]:
* [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend]:
An abstract class whose concrete implementation wraps a quantum backend and makes it
automatically differentiable using different engines (e.g. PyTorch or Jax).
Note that today only PyTorch is supported, but there is a plan to also add a Jax
@@ -57,7 +57,7 @@ and outputs.

### `DifferentiableBackend`

The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and makes the circuit execution routines (expectation value, overlap, etc.) differentiable. Currently, the only implemented differentiation engine is PyTorch but it is easy to add support for another one like Jax.
The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and makes the circuit execution routines (expectation value, overlap, etc.) differentiable. Qadence offers both a PyTorch and a Jax differentiation engine.

### Quantum `Backend`

@@ -104,8 +104,7 @@ You can see the logic for choosing the parameter identifier in [`get_param_name`

## Differentiation with parameter shift rules (PSR)

In Qadence, parameter shift rules are implemented by extending the PyTorch autograd engine using custom `Function`
objects. The implementation is based on this PyTorch [guide](https://pytorch.org/docs/stable/notes/extending.html).
In Qadence, parameter shift rules are applied by implementing a custom `torch.autograd.Function` class in the PyTorch engine and a `custom_vjp` in the Jax engine, respectively.

A custom PyTorch `Function` looks like this:

@@ -130,7 +129,7 @@ class CustomFunction(Function):
...
```
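
For intuition, here is a minimal, self-contained sketch of the PyTorch pattern (a hypothetical `ShiftRuleExpectation`, not Qadence's actual `PSRExpectation`; it assumes a generator with a single spectral gap of 2, for which the shift rule is exact):

```python
import torch
from torch.autograd import Function

class ShiftRuleExpectation(Function):
    # Forward evaluates an expectation-like function f(theta); backward
    # replaces autograd with the single-gap parameter shift rule.
    @staticmethod
    def forward(ctx, f, theta):
        ctx.f = f
        ctx.save_for_backward(theta)
        return f(theta.detach())

    @staticmethod
    def backward(ctx, grad_out):
        (theta,) = ctx.saved_tensors
        shift = torch.pi / 2
        grad = (ctx.f(theta + shift) - ctx.f(theta - shift)) / 2
        return None, grad_out * grad  # no gradient w.r.t. f itself

theta = torch.tensor(0.3, requires_grad=True)
expval = ShiftRuleExpectation.apply(torch.cos, theta)  # cos stands in for a circuit expectation
expval.backward()
print(theta.grad)  # -sin(0.3), recovered purely from shifted evaluations
```

The JAX-engine counterpart uses `jax.custom_vjp` for the same purpose; a similarly hedged sketch:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def psr_expectation(theta):
    return jnp.cos(theta)  # stand-in for a circuit expectation

def psr_fwd(theta):
    return psr_expectation(theta), theta  # residual: the parameter value

def psr_bwd(theta, cotangent):
    shift = jnp.pi / 2
    grad = (jnp.cos(theta + shift) - jnp.cos(theta - shift)) / 2
    return (cotangent * grad,)

psr_expectation.defvjp(psr_fwd, psr_bwd)

print(jax.grad(psr_expectation)(0.3))  # -sin(0.3)
```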

The class [`PSRExpectation`][qadence.backends.pytorch_wrapper.PSRExpectation] implements parameter shift rules for all parameters using
The class `PSRExpectation` under `qadence.engines.torch.differentiable_expectation` implements parameter shift rules for all parameters using
a custom function such as the one above. There are a few implementation details to keep in mind if you want
to modify the PSR code:

docs/tutorials/backends.md (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ For more enquiries, please contact: [`[email protected]`](mailto:[email protected]).

## Differentiation backend

The [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend] class enables different differentiation modes
The [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend] class enables different differentiation modes
for the given backend. This can be chosen from two types:

- Automatic differentiation (AD): available for PyTorch based backends (PyQTorch).
examples/backends/differentiable_backend.py (2 changes: 1 addition & 1 deletion)
@@ -8,13 +8,13 @@
CNOT,
RX,
RY,
DifferentiableBackend,
Parameter,
QuantumCircuit,
chain,
total_magnetization,
)
from qadence.backends.pyqtorch.backend import Backend as PyQTorchBackend
from qadence.engines.torch.differentiable_backend import DifferentiableBackend

torch.manual_seed(42)

examples/backends/low_level/horqrux_backend.py (72 changes: 72 additions & 0 deletions)
@@ -0,0 +1,72 @@
from __future__ import annotations

from typing import Callable

import jax.numpy as jnp
import optax
from jax import Array, jit, value_and_grad
from numpy.typing import ArrayLike

from qadence.backends import backend_factory
from qadence.blocks.utils import chain
from qadence.circuit import QuantumCircuit
from qadence.constructors import feature_map, hea, total_magnetization
from qadence.types import BackendName, DiffMode

backend = BackendName.HORQRUX

num_epochs = 10
n_qubits = 4
depth = 1

fm = feature_map(n_qubits)
circ = QuantumCircuit(n_qubits, chain(fm, hea(n_qubits, depth=depth)))
obs = total_magnetization(n_qubits)

for diff_mode in [DiffMode.AD, DiffMode.GPSR]:
    bknd = backend_factory(backend, diff_mode)
    conv_circ, conv_obs, embedding_fn, vparams = bknd.convert(circ, obs)
    init_params = vparams.copy()
    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(vparams)

    loss: Array
    grads: dict[str, Array]  # 'grads' is the same datatype as 'params'
    inputs: dict[str, Array] = {"phi": jnp.array(1.0)}

    def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    def exp_fn(params: dict[str, Array], inputs: dict[str, Array] = inputs) -> ArrayLike:
        return bknd.expectation(conv_circ, conv_obs, embedding_fn(params, inputs))

    init_pred = exp_fn(vparams)

    def mse_loss(params: dict[str, Array], y_true: Array) -> Array:
        expval = exp_fn(params)
        return (expval - y_true) ** 2

    @jit
    def train_step(
        params: dict,
        opt_state: Array,
        y_true: Array = jnp.array(1.0, dtype=jnp.float64),
        loss_fn: Callable = mse_loss,
    ) -> tuple:
        loss, grads = value_and_grad(loss_fn)(params, y_true)
        params, opt_state = optimize_step(params, opt_state, grads)
        return loss, params, opt_state

    for epoch in range(num_epochs):
        loss, vparams, opt_state = train_step(vparams, opt_state)
        print(f"epoch {epoch} loss:{loss}")

    final_pred = exp_fn(vparams)

    print(
        f"diff_mode '{diff_mode}': Initial prediction: {init_pred}, initial vparams: {init_params}"
    )
    print(f"Final prediction: {final_pred}, final vparams: {vparams}")
    print("----------")
pyproject.toml (24 changes: 21 additions & 3 deletions)
@@ -31,8 +31,9 @@ classifiers=[
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"openfermion",
"numpy",
"torch",
"openfermion",
"sympytorch>=0.1.2",
"rich",
"tensorboard>=2.12.0",
@@ -41,7 +42,7 @@ dependencies = [
"nevergrad",
"scipy",
"pyqtorch==1.0.3",
"matplotlib"
"matplotlib",
]

[tool.hatch.metadata]
@@ -57,13 +58,29 @@ visualization = [
# "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg",
# "scour",
]
horqrux = [
"horqrux==0.3.0",
"jax",
"flax",
"optax",
"jaxopt",
"einops",
"sympy2jax"]

all = [
"pulser>=0.15.2",
"amazon-braket-sdk",
"graphviz",
# FIXME: will be needed once we support latex labels
# "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg",
# "scour",
"horqrux==0.3.0",
"jax",
"flax",
"optax",
"jaxopt",
"einops",
"sympy2jax"
]

[tool.hatch.envs.default]
@@ -120,7 +137,7 @@ dependencies = [
"markdown-exec",
"mike",
]
features = ["pulser", "braket", "visualization"]
features = ["pulser", "braket", "horqrux", "visualization"]

[tool.hatch.envs.docs.scripts]
build = "mkdocs build --clean --strict"
@@ -167,6 +184,7 @@ required-imports = ["from __future__ import annotations"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"operations.py" = ["E742"] # Avoid ambiguous class name warning for identity.
"qadence/backends/horqrux/convert_ops.py" = ["E741"] # Avoid ambiguous class name warning for 0.

[tool.ruff.mccabe]
max-complexity = 15
qadence/__init__.py (1 change: 1 addition & 0 deletions)
@@ -11,6 +11,7 @@
from .blocks import *
from .circuit import *
from .constructors import *
from .engines import *
from .exceptions import *
from .execution import *
from .measurements import *
qadence/backend.py (32 changes: 18 additions & 14 deletions)
@@ -25,7 +25,7 @@
from qadence.mitigations import Mitigations
from qadence.noise import Noise
from qadence.parameters import stringify
from qadence.types import BackendName, DiffMode, Endianness
from qadence.types import ArrayLike, BackendName, DiffMode, Endianness, Engine, ParamDictType
from qadence.utils import validate_values_and_state

logger = get_logger(__file__)
@@ -100,11 +100,14 @@ class Backend(ABC):
name: backend unique string identifier
supports_ad: whether or not the backend has a native autograd
supports_bp: whether or not the backend has a native backprop
supports_adjoint: whether or not the backend supports native adjoint differentiation
is_remote: whether computations are executed locally or remotely on this
backend, useful when using cloud platforms where credentials are
needed for example.
with_measurements: whether it supports counts or not
with_noise: whether to add realistic noise or not
native_endianness: The native endianness of the backend
engine: The underlying (native) automatic differentiation engine of the backend.
"""

name: BackendName
@@ -114,6 +117,7 @@
is_remote: bool
with_measurements: bool
native_endianness: Endianness
engine: Engine

# FIXME: should this also go into the configuration?
with_noise: bool
@@ -199,7 +203,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock:

conv_circ = self.circuit(circuit)
circ_params, circ_embedding_fn = embedding(
conv_circ.abstract.block, self.config._use_gate_params
conv_circ.abstract.block, self.config._use_gate_params, self.engine
)
params = circ_params
if observable is not None:
@@ -211,7 +215,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock:
obs = check_observable(obs)
c_obs = self.observable(obs, max(circuit.n_qubits, obs.n_qubits))
obs_params, obs_embedding_fn = embedding(
c_obs.abstract, self.config._use_gate_params
c_obs.abstract, self.config._use_gate_params, self.engine
)
params.update(obs_params)
obs_embedding_fn_list.append(obs_embedding_fn)
@@ -236,7 +240,7 @@ def sample(
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
n_shots: int = 1000,
state: Tensor | None = None,
state: ArrayLike | None = None,
noise: Noise | None = None,
mitigation: Mitigations | None = None,
endianness: Endianness = Endianness.BIG,
@@ -259,10 +263,10 @@ def sample(
def _run(
self,
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
state: Tensor | None = None,
param_values: dict[str, ArrayLike] = {},
state: ArrayLike | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
) -> ArrayLike:
"""Run a circuit and return the resulting wave function.

Arguments:
@@ -281,12 +285,12 @@ def _run(
def run(
self,
circuit: ConvertedCircuit,
param_values: dict[str, Tensor] = {},
param_values: dict[str, ArrayLike] = {},
state: Tensor | None = None,
endianness: Endianness = Endianness.BIG,
*args: Any,
**kwargs: Any,
) -> Tensor:
) -> ArrayLike:
"""Run a circuit and return the resulting wave function.

Arguments:
@@ -308,7 +312,7 @@ def run_dm(
self,
circuit: ConvertedCircuit,
noise: Noise,
param_values: dict[str, Tensor] = {},
param_values: dict[str, ArrayLike] = {},
state: Tensor | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
@@ -335,13 +339,13 @@ def expectation(
self,
circuit: ConvertedCircuit,
observable: list[ConvertedObservable] | ConvertedObservable,
param_values: dict[str, Tensor] = {},
state: Tensor | None = None,
param_values: ParamDictType = {},
state: ArrayLike | None = None,
measurement: Measurements | None = None,
noise: Noise | None = None,
mitigation: Mitigations | None = None,
endianness: Endianness = Endianness.BIG,
) -> Tensor:
) -> ArrayLike:
"""Compute the expectation value of the `circuit` with the given `observable`.

Arguments:
@@ -398,7 +402,7 @@ class Converted:
circuit: ConvertedCircuit
observable: list[ConvertedObservable] | ConvertedObservable | None
embedding_fn: Callable
params: dict[str, Tensor]
params: ParamDictType

def __iter__(self) -> Iterator:
yield self.circuit
qadence/backends/__init__.py (3 changes: 1 addition & 2 deletions)
@@ -2,7 +2,6 @@
from __future__ import annotations

from .api import backend_factory, config_factory
from .pytorch_wrapper import DifferentiableBackend

# Modules to be automatically added to the qadence namespace
__all__ = ["backend_factory", "config_factory", "DifferentiableBackend"]
__all__ = ["backend_factory", "config_factory"]