diff --git a/README.md b/README.md
index 22c8a5e6..d11b06c1 100644
--- a/README.md
+++ b/README.md
@@ -34,12 +34,15 @@ Qadence is available on [PyPI](https://pypi.org/project/qadence/) and can be ins
 pip install qadence
 ```
 
-The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation. It is possible to install additional backends and the circuit visualization library using the following extras:
+The default, pre-installed backend for Qadence is [PyQTorch](https://github.com/pasqal-io/pyqtorch), a differentiable state vector simulator for digital-analog simulation based on `PyTorch`. It is possible to install additional `PyTorch`-based backends and the circuit visualization library using the following extras:
 
 * `pulser`: The [Pulser](https://github.com/pasqal-io/Pulser) backend for composing, simulating and executing pulse sequences for neutral-atom quantum devices.
 * `braket`: The [Braket](https://github.com/amazon-braket/amazon-braket-sdk-python) backend, an open source library that provides a framework for interacting with quantum computing hardware devices through Amazon Braket.
 * `visualization`: A visualization library to display quantum circuit diagrams.
 
+Qadence also supports a `JAX` engine, which currently supports the [Horqrux](https://github.com/pasqal-io/horqrux) backend. `horqrux` is currently only available via the [low-level API](examples/backends/low_level/horqrux_backend.py).
+
+
 To install individual extras, use the following syntax (**IMPORTANT** Make sure to use quotes):
 
 ```bash
diff --git a/docs/advanced_tutorials/differentiability.md b/docs/advanced_tutorials/differentiability.md
index bb6034fe..45b987d2 100644
--- a/docs/advanced_tutorials/differentiability.md
+++ b/docs/advanced_tutorials/differentiability.md
@@ -139,7 +139,7 @@ print(docsutils.fig_to_html(plt.gcf())) # markdown-exec: hide
 In order to get a finer control over the GPSR differentiation engine we can use the low-level Qadence API to define a `DifferentiableBackend`.
 
 ```python exec="on" source="material-block" session="differentiability"
-from qadence import DifferentiableBackend
+from qadence.engines.torch import DifferentiableBackend
 from qadence.backends.pyqtorch import Backend as PyQBackend
 
 # define differentiable quantum backend
diff --git a/docs/backends/differentiable.md b/docs/backends/differentiable.md
index 3341e5a2..d8baf667 100644
--- a/docs/backends/differentiable.md
+++ b/docs/backends/differentiable.md
@@ -1 +1,2 @@
-### ::: qadence.backends.pytorch_wrapper
+### ::: qadence.engines.torch.differentiable_backend
+### ::: qadence.engines.jax.differentiable_backend
diff --git a/docs/development/architecture.md b/docs/development/architecture.md
index 45243070..1db2e37c 100644
--- a/docs/development/architecture.md
+++ b/docs/development/architecture.md
@@ -21,7 +21,7 @@ In Qadence there are 4 main objects spread across 3 different levels of abstract
 * **Differentiation layer**: Intermediate layer has the purpose of integrating quantum computation with a given automatic differentiation engine. It is meant to be purely stateless and contains one object:
 
-    * [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend]:
+    * [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend]:
     An abstract class whose concrete implementation wraps a quantum backend and make it automatically differentiable using different engines (e.g. PyTorch or Jax).
 Note, that today only PyTorch is supported but there is plan to add also a Jax
@@ -57,7 +57,7 @@
 and outputs.
 
 ### `DifferentiableBackend`
 
-The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and make the circuit execution routines (expectation value, overalap, etc.) differentiable. Currently, the only implemented differentiation engine is PyTorch but it is easy to add support to another one like Jax.
+The differentiable backend is a thin wrapper which takes as input a `QuantumCircuit` instance and a chosen quantum backend and makes the circuit execution routines (expectation value, overlap, etc.) differentiable. Qadence offers both a PyTorch and a JAX differentiation engine.
 
 ### Quantum `Backend`
 
@@ -104,8 +104,7 @@ You can see the logic for choosing the parameter identifier in [`get_param_name`
 
 ## Differentiation with parameter shift rules (PSR)
 
-In Qadence, parameter shift rules are implemented by extending the PyTorch autograd engine using custom `Function`
-objects. The implementation is based on this PyTorch [guide](https://pytorch.org/docs/stable/notes/extending.html).
+In Qadence, parameter shift rules are applied via a custom `torch.autograd.Function` class in the PyTorch engine and via `custom_vjp` in the JAX engine.
 
 A custom PyTorch `Function` looks like this:
 
@@ -130,7 +129,7 @@ class CustomFunction(Function):
 
     ...
 ```
 
-The class [`PSRExpectation`][qadence.backends.pytorch_wrapper.PSRExpectation] implements parameter shift rules for all parameters using
+The class `PSRExpectation` under `qadence.engines.torch.differentiable_expectation` implements parameter shift rules for all parameters using
 a custom function as the one above.
 
 There are a few implementation details to keep in mind if you want to modify the PSR code:
diff --git a/docs/tutorials/backends.md b/docs/tutorials/backends.md
index 9dcc706a..9074110a 100644
--- a/docs/tutorials/backends.md
+++ b/docs/tutorials/backends.md
@@ -25,7 +25,7 @@ For more enquiries, please contact: [`info@pasqal.com`](mailto:info@pasqal.com).
 
 ## Differentiation backend
 
-The [`DifferentiableBackend`][qadence.backends.pytorch_wrapper.DifferentiableBackend] class enables different differentiation modes
+The [`DifferentiableBackend`][qadence.engines.torch.DifferentiableBackend] class enables different differentiation modes
 for the given backend. This can be chosen from two types:
 
 - Automatic differentiation (AD): available for PyTorch based backends (PyQTorch).
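To make the new import path concrete, here is a minimal sketch of the low-level differentiation API: it wraps the PyQTorch backend in the Torch `DifferentiableBackend` and differentiates an expectation value with GPSR. The one-qubit circuit, the `theta` parameter and the observable are illustrative placeholders, not part of this patch; the calls mirror the updated `examples/backends/differentiable_backend.py` shown below.

```python
import torch

from qadence import RX, Parameter, QuantumCircuit, total_magnetization
from qadence.backends.pyqtorch import Backend as PyQBackend
from qadence.engines.torch import DifferentiableBackend
from qadence.types import DiffMode

# a single trainable rotation angle on one qubit (illustrative circuit)
theta = Parameter("theta", trainable=True)
circ = QuantumCircuit(1, RX(0, theta))
obs = total_magnetization(1)

# wrap the quantum backend in the Torch differentiation layer
diff_backend = DifferentiableBackend(PyQBackend(), diff_mode=DiffMode.GPSR)
conv_circ, conv_obs, embedding_fn, params = diff_backend.convert(circ, obs)

# the expectation value stays differentiable w.r.t. the variational parameters
expval = diff_backend.expectation(conv_circ, conv_obs, embedding_fn(params, {}))
(grad_theta,) = torch.autograd.grad(expval, list(params.values()), torch.ones_like(expval))
```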
diff --git a/examples/backends/differentiable_backend.py b/examples/backends/differentiable_backend.py index 84e67ca0..45c7aa99 100644 --- a/examples/backends/differentiable_backend.py +++ b/examples/backends/differentiable_backend.py @@ -8,13 +8,13 @@ CNOT, RX, RY, - DifferentiableBackend, Parameter, QuantumCircuit, chain, total_magnetization, ) from qadence.backends.pyqtorch.backend import Backend as PyQTorchBackend +from qadence.engines.torch.differentiable_backend import DifferentiableBackend torch.manual_seed(42) diff --git a/examples/backends/low_level/horqrux_backend.py b/examples/backends/low_level/horqrux_backend.py new file mode 100644 index 00000000..5470cae3 --- /dev/null +++ b/examples/backends/low_level/horqrux_backend.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Callable + +import jax.numpy as jnp +import optax +from jax import Array, jit, value_and_grad +from numpy.typing import ArrayLike + +from qadence.backends import backend_factory +from qadence.blocks.utils import chain +from qadence.circuit import QuantumCircuit +from qadence.constructors import feature_map, hea, total_magnetization +from qadence.types import BackendName, DiffMode + +backend = BackendName.HORQRUX + +num_epochs = 10 +n_qubits = 4 +depth = 1 + +fm = feature_map(n_qubits) +circ = QuantumCircuit(n_qubits, chain(fm, hea(n_qubits, depth=depth))) +obs = total_magnetization(n_qubits) + +for diff_mode in [DiffMode.AD, DiffMode.GPSR]: + bknd = backend_factory(backend, diff_mode) + conv_circ, conv_obs, embedding_fn, vparams = bknd.convert(circ, obs) + init_params = vparams.copy() + optimizer = optax.adam(learning_rate=0.001) + opt_state = optimizer.init(vparams) + + loss: Array + grads: dict[str, Array] # 'grads' is the same datatype as 'params' + inputs: dict[str, Array] = {"phi": jnp.array(1.0)} + + def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple: + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state + + def exp_fn(params: dict[str, Array], inputs: dict[str, Array] = inputs) -> ArrayLike: + return bknd.expectation(conv_circ, conv_obs, embedding_fn(params, inputs)) + + init_pred = exp_fn(vparams) + + def mse_loss(params: dict[str, Array], y_true: Array) -> Array: + expval = exp_fn(params) + return (expval - y_true) ** 2 + + @jit + def train_step( + params: dict, + opt_state: Array, + y_true: Array = jnp.array(1.0, dtype=jnp.float64), + loss_fn: Callable = mse_loss, + ) -> tuple: + loss, grads = value_and_grad(loss_fn)(params, y_true) + params, opt_state = optimize_step(params, opt_state, grads) + return loss, params, opt_state + + for epoch in range(num_epochs): + loss, vparams, opt_state = train_step(vparams, opt_state) + print(f"epoch {epoch} loss:{loss}") + + final_pred = exp_fn(vparams) + + print( + f"diff_mode '{diff_mode}: Initial prediction: {init_pred}, initial vparams: {init_params}" + ) + print(f"Final prediction: {final_pred}, final vparams: {vparams}") + print("----------") diff --git a/pyproject.toml b/pyproject.toml index ae5dc848..ad468562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ authors = [ ] requires-python = ">=3.9,<3.12" license = {text = "Apache 2.0"} -version = "1.2.0" +version = "1.2.1" classifiers=[ "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", @@ -31,8 +31,9 @@ classifiers=[ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - 
"openfermion", + "numpy", "torch", + "openfermion", "sympytorch>=0.1.2", "rich", "tensorboard>=2.12.0", @@ -41,7 +42,7 @@ dependencies = [ "nevergrad", "scipy", "pyqtorch==1.0.3", - "matplotlib" + "matplotlib", ] [tool.hatch.metadata] @@ -57,6 +58,15 @@ visualization = [ # "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg", # "scour", ] +horqrux = [ + "horqrux==0.3.0", + "jax", + "flax", + "optax", + "jaxopt", + "einops", + "sympy2jax"] + all = [ "pulser>=0.15.2", "amazon-braket-sdk", @@ -64,6 +74,13 @@ all = [ # FIXME: will be needed once we support latex labels # "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg", # "scour", + "horqrux==0.3.0", + "jax", + "flax", + "optax", + "jaxopt", + "einops", + "sympy2jax" ] [tool.hatch.envs.default] @@ -120,7 +137,7 @@ dependencies = [ "markdown-exec", "mike", ] -features = ["pulser", "braket", "visualization"] +features = ["pulser", "braket", "horqrux", "visualization"] [tool.hatch.envs.docs.scripts] build = "mkdocs build --clean --strict" @@ -167,6 +184,7 @@ required-imports = ["from __future__ import annotations"] [tool.ruff.per-file-ignores] "__init__.py" = ["F401"] "operations.py" = ["E742"] # Avoid ambiguous class name warning for identity. +"qadence/backends/horqrux/convert_ops.py" = ["E741"] # Avoid ambiguous class name warning for 0. [tool.ruff.mccabe] max-complexity = 15 diff --git a/qadence/__init__.py b/qadence/__init__.py index a98837fa..33c529e9 100644 --- a/qadence/__init__.py +++ b/qadence/__init__.py @@ -11,6 +11,7 @@ from .blocks import * from .circuit import * from .constructors import * +from .engines import * from .exceptions import * from .execution import * from .measurements import * diff --git a/qadence/backend.py b/qadence/backend.py index aafb87ff..7ed3971b 100644 --- a/qadence/backend.py +++ b/qadence/backend.py @@ -25,7 +25,7 @@ from qadence.mitigations import Mitigations from qadence.noise import Noise from qadence.parameters import stringify -from qadence.types import BackendName, DiffMode, Endianness +from qadence.types import ArrayLike, BackendName, DiffMode, Endianness, Engine, ParamDictType from qadence.utils import validate_values_and_state logger = get_logger(__file__) @@ -100,11 +100,14 @@ class Backend(ABC): name: backend unique string identifier supports_ad: whether or not the backend has a native autograd supports_bp: whether or not the backend has a native backprop + supports_adjoint: Does the backend support native adjoint differentation. is_remote: whether computations are executed locally or remotely on this backend, useful when using cloud platforms where credentials are needed for example. with_measurements: whether it supports counts or not with_noise: whether to add realistic noise or not + native_endianness: The native endianness of the backend + engine: The underlying (native) automatic differentiation engine of the backend. """ name: BackendName @@ -114,6 +117,7 @@ class Backend(ABC): is_remote: bool with_measurements: bool native_endianness: Endianness + engine: Engine # FIXME: should this also go into the configuration? 
with_noise: bool @@ -199,7 +203,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock: conv_circ = self.circuit(circuit) circ_params, circ_embedding_fn = embedding( - conv_circ.abstract.block, self.config._use_gate_params + conv_circ.abstract.block, self.config._use_gate_params, self.engine ) params = circ_params if observable is not None: @@ -211,7 +215,7 @@ def check_observable(obs_obj: Any) -> AbstractBlock: obs = check_observable(obs) c_obs = self.observable(obs, max(circuit.n_qubits, obs.n_qubits)) obs_params, obs_embedding_fn = embedding( - c_obs.abstract, self.config._use_gate_params + c_obs.abstract, self.config._use_gate_params, self.engine ) params.update(obs_params) obs_embedding_fn_list.append(obs_embedding_fn) @@ -236,7 +240,7 @@ def sample( circuit: ConvertedCircuit, param_values: dict[str, Tensor] = {}, n_shots: int = 1000, - state: Tensor | None = None, + state: ArrayLike | None = None, noise: Noise | None = None, mitigation: Mitigations | None = None, endianness: Endianness = Endianness.BIG, @@ -259,10 +263,10 @@ def sample( def _run( self, circuit: ConvertedCircuit, - param_values: dict[str, Tensor] = {}, - state: Tensor | None = None, + param_values: dict[str, ArrayLike] = {}, + state: ArrayLike | None = None, endianness: Endianness = Endianness.BIG, - ) -> Tensor: + ) -> ArrayLike: """Run a circuit and return the resulting wave function. Arguments: @@ -281,12 +285,12 @@ def _run( def run( self, circuit: ConvertedCircuit, - param_values: dict[str, Tensor] = {}, + param_values: dict[str, ArrayLike] = {}, state: Tensor | None = None, endianness: Endianness = Endianness.BIG, *args: Any, **kwargs: Any, - ) -> Tensor: + ) -> ArrayLike: """Run a circuit and return the resulting wave function. Arguments: @@ -308,7 +312,7 @@ def run_dm( self, circuit: ConvertedCircuit, noise: Noise, - param_values: dict[str, Tensor] = {}, + param_values: dict[str, ArrayLike] = {}, state: Tensor | None = None, endianness: Endianness = Endianness.BIG, ) -> Tensor: @@ -335,13 +339,13 @@ def expectation( self, circuit: ConvertedCircuit, observable: list[ConvertedObservable] | ConvertedObservable, - param_values: dict[str, Tensor] = {}, - state: Tensor | None = None, + param_values: ParamDictType = {}, + state: ArrayLike | None = None, measurement: Measurements | None = None, noise: Noise | None = None, mitigation: Mitigations | None = None, endianness: Endianness = Endianness.BIG, - ) -> Tensor: + ) -> ArrayLike: """Compute the expectation value of the `circuit` with the given `observable`. 
Arguments: @@ -398,7 +402,7 @@ class Converted: circuit: ConvertedCircuit observable: list[ConvertedObservable] | ConvertedObservable | None embedding_fn: Callable - params: dict[str, Tensor] + params: ParamDictType def __iter__(self) -> Iterator: yield self.circuit diff --git a/qadence/backends/__init__.py b/qadence/backends/__init__.py index 7a4a1470..ca77a8d9 100644 --- a/qadence/backends/__init__.py +++ b/qadence/backends/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations from .api import backend_factory, config_factory -from .pytorch_wrapper import DifferentiableBackend # Modules to be automatically added to the qadence namespace -__all__ = ["backend_factory", "config_factory", "DifferentiableBackend"] +__all__ = ["backend_factory", "config_factory"] diff --git a/qadence/backends/api.py b/qadence/backends/api.py index 348a06b4..35d7a04c 100644 --- a/qadence/backends/api.py +++ b/qadence/backends/api.py @@ -1,9 +1,9 @@ from __future__ import annotations from qadence.backend import Backend, BackendConfiguration -from qadence.backends.pytorch_wrapper import DifferentiableBackend -from qadence.extensions import available_backends, set_backend_config -from qadence.types import BackendName, DiffMode +from qadence.engines.differentiable_backend import DifferentiableBackend +from qadence.extensions import available_backends, available_engines, set_backend_config +from qadence.types import BackendName, DiffMode, Engine __all__ = ["backend_factory", "config_factory"] @@ -14,13 +14,18 @@ def backend_factory( configuration: BackendConfiguration | dict | None = None, ) -> Backend | DifferentiableBackend: backend_inst: Backend | DifferentiableBackend - backend_name = BackendName(backend) backends = available_backends() - + try: + backend_name = BackendName(backend) + except ValueError: + raise NotImplementedError(f"The requested backend '{backend}' is not implemented.") try: BackendCls = backends[backend_name] - except (KeyError, ValueError): - raise NotImplementedError(f"The requested backend '{backend_name}' is not implemented.") + except Exception as e: + raise ImportError( + f"The requested backend '{backend_name}' is either not installed\ + or could not be imported due to {e}." + ) default_config = BackendCls.default_configuration() if configuration is None: @@ -44,9 +49,22 @@ def backend_factory( # Set backend configurations which depend on the differentiation mode set_backend_config(backend_inst, diff_mode) - + # Wrap the quantum Backend in a DifferentiableBackend if a diff_mode is passed. if diff_mode is not None: - backend_inst = DifferentiableBackend(backend_inst, DiffMode(diff_mode)) + try: + engine_name = Engine(backend_inst.engine) + except ValueError: + raise NotImplementedError( + f"The requested engine '{backend_inst.engine}' is not implemented." + ) + try: + diff_backend_cls = available_engines()[engine_name] + backend_inst = diff_backend_cls(backend=backend_inst, diff_mode=DiffMode(diff_mode)) # type: ignore[arg-type] + except Exception as e: + raise ImportError( + f"The requested engine '{engine_name}' is either not installed\ + or could not be imported due to {e}." 
+ ) return backend_inst diff --git a/qadence/backends/braket/backend.py b/qadence/backends/braket/backend.py index 4be5ac60..aab8d6f8 100644 --- a/qadence/backends/braket/backend.py +++ b/qadence/backends/braket/backend.py @@ -23,7 +23,7 @@ from qadence.noise.protocols import apply_noise from qadence.overlap import overlap_exact from qadence.transpile import transpile -from qadence.types import BackendName +from qadence.types import BackendName, Engine from qadence.utils import Endianness from .config import Configuration, default_passes @@ -55,6 +55,7 @@ class Backend(BackendInterface): with_noise: bool = False native_endianness: Endianness = Endianness.BIG config: Configuration = field(default_factory=Configuration) + engine: Engine = Engine.TORCH # braket specifics # TODO: include it in the configuration? diff --git a/qadence/backends/horqrux/__init__.py b/qadence/backends/horqrux/__init__.py new file mode 100644 index 00000000..59cb7d4d --- /dev/null +++ b/qadence/backends/horqrux/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .backend import Backend +from .config import Configuration +from .convert_ops import supported_gates diff --git a/qadence/backends/horqrux/backend.py b/qadence/backends/horqrux/backend.py new file mode 100644 index 00000000..e4dbe141 --- /dev/null +++ b/qadence/backends/horqrux/backend.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass, field +from typing import Any + +import jax +import jax.numpy as jnp +from horqrux.utils import prepare_state +from jax.typing import ArrayLike + +from qadence.backend import Backend as BackendInterface +from qadence.backend import ConvertedCircuit, ConvertedObservable +from qadence.backends.jax_utils import ( + tensor_to_jnp, + unhorqify, + uniform_batchsize, +) +from qadence.backends.utils import pyqify +from qadence.blocks import AbstractBlock +from qadence.circuit import QuantumCircuit +from qadence.measurements import Measurements +from qadence.mitigations import Mitigations +from qadence.noise import Noise +from qadence.transpile import flatten, scale_primitive_blocks_only, transpile +from qadence.types import BackendName, Endianness, Engine, ParamDictType +from qadence.utils import int_to_basis + +from .config import Configuration, default_passes +from .convert_ops import HorqruxCircuit, convert_block, convert_observable + + +@dataclass(frozen=True, eq=True) +class Backend(BackendInterface): + # set standard interface parameters + name: BackendName = BackendName.HORQRUX # type: ignore[assignment] + supports_ad: bool = True + supports_adjoint: bool = False + support_bp: bool = True + supports_native_psr: bool = False + is_remote: bool = False + with_measurements: bool = True + with_noise: bool = False + native_endianness: Endianness = Endianness.BIG + config: Configuration = field(default_factory=Configuration) + engine: Engine = Engine.JAX + + def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit: + passes = self.config.transpilation_passes + if passes is None: + passes = default_passes(self.config) + + original_circ = circuit + if len(passes) > 0: + circuit = transpile(*passes)(circuit) + ops = convert_block(circuit.block, n_qubits=circuit.n_qubits, config=self.config) + return ConvertedCircuit( + native=HorqruxCircuit(ops), abstract=circuit, original=original_circ + ) + + def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable: + transpilations = [ + flatten, + 
scale_primitive_blocks_only,
+        ]
+        block = transpile(*transpilations)(observable)  # type: ignore[call-overload]
+        hq_obs = convert_observable(block, n_qubits=n_qubits, config=self.config)
+        return ConvertedObservable(native=hq_obs, abstract=block, original=observable)
+
+    def _run(
+        self,
+        circuit: ConvertedCircuit,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        endianness: Endianness = Endianness.BIG,
+        horqify_state: bool = True,
+        unhorqify_state: bool = True,
+    ) -> ArrayLike:
+        n_qubits = circuit.abstract.n_qubits
+        if state is None:
+            state = prepare_state(n_qubits, "0" * n_qubits)
+        else:
+            state = tensor_to_jnp(pyqify(state)) if horqify_state else state
+        state = circuit.native.forward(state, param_values)
+        if endianness != self.native_endianness:
+            state = jnp.reshape(state, (1, 2**n_qubits))  # batch_size is always 1
+            ls = list(range(2**n_qubits))
+            permute_ind = jnp.array([int(f"{num:0{n_qubits}b}"[::-1], 2) for num in ls])
+            state = state[:, permute_ind]
+        if unhorqify_state:
+            state = unhorqify(state)
+        return state
+
+    def run_dm(
+        self,
+        circuit: ConvertedCircuit,
+        noise: Noise,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> ArrayLike:
+        raise NotImplementedError
+
+    def expectation(
+        self,
+        circuit: ConvertedCircuit,
+        observable: list[ConvertedObservable] | ConvertedObservable,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        measurement: Measurements | None = None,
+        noise: Noise | None = None,
+        mitigation: Mitigations | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> ArrayLike:
+        observable = observable if isinstance(observable, list) else [observable]
+        batch_size = max([arr.size for arr in param_values.values()])
+        n_obs = len(observable)
+
+        def _expectation(params: ParamDictType) -> ArrayLike:
+            out_state = self.run(
+                circuit, params, state, endianness, horqify_state=True, unhorqify_state=False
+            )
+            return jnp.array([o.native.forward(out_state, params) for o in observable])
+
+        if batch_size > 1:  # We vmap for batch_size > 1
+            expvals = jax.vmap(_expectation, in_axes=({k: 0 for k in param_values.keys()},))(
+                uniform_batchsize(param_values)
+            )
+        else:
+            expvals = _expectation(param_values)
+        if expvals.size > 1:
+            expvals = jnp.reshape(expvals, (batch_size, n_obs))
+        else:
+            expvals = jnp.squeeze(
+                expvals, 0
+            )  # For the case of batch_size == n_obs == 1, we remove the dims
+        return expvals
+
+    def sample(
+        self,
+        circuit: ConvertedCircuit,
+        param_values: ParamDictType = {},
+        n_shots: int = 1,
+        state: ArrayLike | None = None,
+        noise: Noise | None = None,
+        mitigation: Mitigations | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> list[Counter]:
+        """Samples from a batch of discrete probability distributions.
+
+        Args:
+            circuit: A ConvertedCircuit object holding the native horqrux circuit.
+            param_values: A dict holding the embedded parameters which the native circuit expects.
+            n_shots: The number of samples to generate per distribution.
+            state: The input state.
+            endianness (Endianness): The target endianness of the resulting samples.
+
+        Returns:
+            A list of Counter objects where each key represents a bitstring
+            and its value the number of times it has been sampled from the given wave function.
+ """ + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + def _sample( + _probs: ArrayLike, n_shots: int, endianness: Endianness, n_qubits: int + ) -> Counter: + _logits = jax.vmap(lambda _p: jnp.log(_p / (1 - _p)))(_probs) + + def _smple(accumulator: ArrayLike, i: int) -> tuple[ArrayLike, None]: + accumulator = accumulator.at[i].set( + jax.random.categorical(jax.random.PRNGKey(i), _logits) + ) + return accumulator, None + + samples = jax.lax.scan( + _smple, jnp.empty_like(jnp.arange(n_shots)), jnp.arange(n_shots) + )[0] + return Counter( + { + int_to_basis(k=k, n_qubits=n_qubits, endianness=endianness): count.item() + for k, count in enumerate(jnp.bincount(samples)) + if count > 0 + } + ) + + wf = self.run( + circuit=circuit, + param_values=param_values, + state=state, + horqify_state=True, + unhorqify_state=False, + ) + probs = jnp.abs(jnp.float_power(wf, 2.0)).ravel() + samples = [ + _sample( + _probs=probs, + n_shots=n_shots, + endianness=endianness, + n_qubits=circuit.abstract.n_qubits, + ), + ] + + return samples + + def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any: + raise NotImplementedError + + @staticmethod + def _overlap(bras: ArrayLike, kets: ArrayLike) -> ArrayLike: + # TODO + raise NotImplementedError + + @staticmethod + def default_configuration() -> Configuration: + return Configuration() diff --git a/qadence/backends/horqrux/config.py b/qadence/backends/horqrux/config.py new file mode 100644 index 00000000..d31e79f1 --- /dev/null +++ b/qadence/backends/horqrux/config.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from qadence.backend import BackendConfiguration +from qadence.logger import get_logger +from qadence.transpile import ( + blockfn_to_circfn, + flatten, + scale_primitive_blocks_only, +) + +logger = get_logger(__name__) + + +def default_passes(config: Configuration) -> list[Callable]: + return [ + flatten, + blockfn_to_circfn(scale_primitive_blocks_only), + ] + + +@dataclass +class Configuration(BackendConfiguration): + pass diff --git a/qadence/backends/horqrux/convert_ops.py b/qadence/backends/horqrux/convert_ops.py new file mode 100644 index 00000000..2f0c4b5c --- /dev/null +++ b/qadence/backends/horqrux/convert_ops.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from functools import reduce +from itertools import chain as flatten +from operator import add +from typing import Any, Callable, Dict + +import jax.numpy as jnp +from horqrux.gates import NOT, H, I, Rx, Ry, Rz, X, Y, Z +from horqrux.ops import apply_gate +from horqrux.types import Gate +from horqrux.utils import overlap +from jax import Array +from jax.tree_util import register_pytree_node_class + +from qadence.blocks import ( + AbstractBlock, + AddBlock, + ChainBlock, + CompositeBlock, + KronBlock, + ParametricBlock, + PrimitiveBlock, + ScaleBlock, +) +from qadence.operations import CNOT, CRX, CRY, CRZ +from qadence.types import OpName, ParamDictType + +from .config import Configuration + +ops_map: Dict[str, Callable] = { + OpName.X: X, + OpName.Y: Y, + OpName.Z: Z, + OpName.H: H, + OpName.RX: Rx, + OpName.RY: Ry, + OpName.RZ: Rz, + OpName.CRX: Rx, + OpName.CRY: Ry, + OpName.CRZ: Rz, + OpName.CNOT: NOT, + OpName.I: I, +} + +supported_gates = list(set(list(ops_map.keys()))) + + +@register_pytree_node_class +@dataclass +class HorqruxCircuit: + operators: list[Gate] = field(default_factory=list) + + 
def tree_flatten(self) -> tuple[tuple[list[Any]], tuple[()]]: + children = (self.operators,) + aux_data = () + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: + return cls(*children, *aux_data) + + def forward(self, state: Array, values: ParamDictType) -> Array: + for op in self.operators: + state = op.forward(state, values) + return state + + +@dataclass +class HorqruxObservable(HorqruxCircuit): + def __init__(self, operators: list[Gate]): + super().__init__(operators=operators) + + def _forward(self, state: Array, values: ParamDictType) -> Array: + for op in self.operators: + state = op.forward(state, values) + return state + + def forward(self, state: Array, values: ParamDictType) -> Array: + return overlap(state, self._forward(state, values)) + + +def convert_observable( + block: AbstractBlock, n_qubits: int, config: Configuration +) -> HorqruxObservable: + _ops = convert_block(block, n_qubits, config) + return HorqruxObservable(_ops) + + +def convert_block( + block: AbstractBlock, n_qubits: int = None, config: Configuration = Configuration() +) -> list: + if n_qubits is None: + n_qubits = max(block.qubit_support) + 1 + ops = [] + if isinstance(block, CompositeBlock): + ops = list(flatten(*(convert_block(b, n_qubits, config) for b in block.blocks))) + if isinstance(block, AddBlock): + ops = [HorqAddGate(ops)] + elif isinstance(block, ChainBlock): + ops = [HorqruxCircuit(ops)] + elif isinstance(block, KronBlock): + if all( + [ + isinstance(b, ParametricBlock) and not isinstance(b, ScaleBlock) + for b in block.blocks + ] + ): + param_names = [config.get_param_name(b)[0] for b in block.blocks if b.is_parametric] + ops = [ + HorqKronParametric( + gates=[ops_map[b.name] for b in block.blocks], + target=[b.qubit_support[0] for b in block.blocks], + param_names=param_names, + ) + ] + + elif all([b.name == "CNOT" for b in block.blocks]): + ops = [ + HorqKronCNOT( + gates=[ops_map[b.name] for b in block.blocks], + target=[b.qubit_support[1] for b in block.blocks], + control=[b.qubit_support[0] for b in block.blocks], + ) + ] + else: + ops = [HorqruxCircuit(ops)] + + elif isinstance(block, CNOT): + native_op = ops_map[block.name] + ops = [ + HorqCNOTGate(native_op, block.qubit_support[0], block.qubit_support[1]) + ] # in horqrux target and control are swapped + + elif isinstance(block, (CRX, CRY, CRZ)): + native_op = ops_map[block.name] + param_name = config.get_param_name(block)[0] + + ops = [ + HorqParametricGate( + gate=native_op, + qubit=block.qubit_support[1], + parameter_name=param_name, + control=block.qubit_support[0], + name=block.name, + ) + ] + elif isinstance(block, ScaleBlock): + op = convert_block(block.block, n_qubits, config=config)[0] + param_name = config.get_param_name(block)[0] + ops = [HorqScaleGate(op, param_name)] + + elif isinstance(block, ParametricBlock): + native_op = ops_map[block.name] + if len(block.parameters._uuid_dict) > 1: + raise NotImplementedError("Only single-parameter operations are supported.") + param_name = config.get_param_name(block)[0] + + ops = [ + HorqParametricGate( + gate=native_op, + qubit=block.qubit_support[0], + parameter_name=param_name, + ) + ] + + elif isinstance(block, PrimitiveBlock): + native_op = ops_map[block.name] + qubit = block.qubit_support[0] + ops = [HorqPrimitiveGate(gate=native_op, qubit=qubit, name=block.name)] + + else: + raise NotImplementedError(f"Non-supported operation of type {type(block)}.") + + return ops + + +class HorqPrimitiveGate: + def __init__(self, 
gate: Gate, qubit: int, name: str): + self.gates: Gate = gate + self.target = qubit + self.name = name + + def forward(self, state: Array, values: ParamDictType) -> Array: + return apply_gate(state, self.gates(self.target)) + + def __repr__(self) -> str: + return self.name + f"(target={self.target})" + + +class HorqCNOTGate: + def __init__(self, gate: Gate, control: int, target: int): + self.gates: Callable = gate + self.control: int = control + self.target: int = target + + def forward(self, state: Array, values: ParamDictType) -> Array: + return apply_gate(state, self.gates(self.target, self.control)) + + +class HorqKronParametric: + def __init__(self, gates: list[Gate], param_names: list[str], target: list[int]): + self.operators: list[Gate] = gates + self.target: list[int] = target + self.param_names: list[str] = param_names + + def forward(self, state: Array, values: ParamDictType) -> Array: + return apply_gate( + state, + tuple( + gate(values[param_name], target) + for gate, target, param_name in zip(self.operators, self.target, self.param_names) + ), + ) + + +class HorqKronCNOT(HorqruxCircuit): + def __init__(self, gates: list[Gate], target: list[int], control: list[int]): + self.operators: list[Gate] = gates + self.target: list[int] = target + self.control: list[int] = control + + def forward(self, state: Array, values: ParamDictType) -> Array: + return apply_gate( + state, + tuple( + gate(target, control) + for gate, target, control in zip(self.operators, self.target, self.control) + ), + ) + + +class HorqParametricGate: + def __init__( + self, gate: Gate, qubit: int, parameter_name: str, control: int = None, name: str = "" + ): + self.gates: Callable = gate + self.target: int = qubit + self.parameter: str = parameter_name + self.control: int | None = control + self.name = name + + def forward(self, state: Array, values: ParamDictType) -> Array: + val = jnp.array(values[self.parameter]) + return apply_gate(state, self.gates(val, self.target, self.control)) + + def __repr__(self) -> str: + return ( + self.name + + f"(target={self.target}, parameter={self.parameter}, control={self.control})" + ) + + +class HorqAddGate(HorqruxCircuit): + def __init__(self, operations: list[Gate]): + self.operators = operations + self.name = "Add" + + def forward(self, state: Array, values: ParamDictType = {}) -> Array: + return reduce(add, (op.forward(state, values) for op in self.operators)) + + def __repr__(self) -> str: + return self.name + f"({self.operators})" + + +class HorqScaleGate: + def __init__(self, op: Gate, parameter_name: str): + self.op = op + self.parameter: str = parameter_name + + def forward(self, state: Array, values: ParamDictType) -> Array: + return jnp.array(values[self.parameter]) * self.op.forward(state, values) diff --git a/qadence/backends/jax_utils.py b/qadence/backends/jax_utils.py new file mode 100644 index 00000000..0c013e84 --- /dev/null +++ b/qadence/backends/jax_utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Any + +import jax.numpy as jnp +from jax import Array, device_get +from sympy import Expr +from sympy2jax import SymbolicModule as JaxSympyModule +from torch import Tensor, cdouble, from_numpy + +from qadence.types import ParamDictType + + +def jarr_to_tensor(arr: Array, dtype: Any = cdouble) -> Tensor: + return from_numpy(device_get(arr)).to(dtype=dtype) + + +def tensor_to_jnp(tensor: Tensor, dtype: Any = jnp.complex128) -> Array: + return ( + jnp.array(tensor.numpy(), dtype=dtype) + if not tensor.requires_grad + else 
jnp.array(tensor.detach().numpy(), dtype=dtype) + ) + + +def values_to_jax(param_values: dict[str, Tensor]) -> dict[str, Array]: + return {key: jnp.array(value.detach().numpy()) for key, value in param_values.items()} + + +def jaxify(expr: Expr) -> JaxSympyModule: + return JaxSympyModule(expr) + + +def unhorqify(state: Array) -> Array: + """Convert a state of shape [2] * n_qubits + [batch_size] to (batch_size, 2**n_qubits).""" + return jnp.ravel(state) + + +def uniform_batchsize(param_values: ParamDictType) -> ParamDictType: + max_batch_size = max(p.size for p in param_values.values()) + batched_values = { + k: (v if v.size == max_batch_size else v.repeat(max_batch_size)) + for k, v in param_values.items() + } + return batched_values diff --git a/qadence/backends/pulser/backend.py b/qadence/backends/pulser/backend.py index 25300593..4dab6d12 100644 --- a/qadence/backends/pulser/backend.py +++ b/qadence/backends/pulser/backend.py @@ -28,7 +28,7 @@ from qadence.overlap import overlap_exact from qadence.register import Register from qadence.transpile import transpile -from qadence.types import BackendName, DeviceType, Endianness +from qadence.types import BackendName, DeviceType, Endianness, Engine from .channels import GLOBAL_CHANNEL, LOCAL_CHANNEL from .cloud import get_client @@ -159,6 +159,7 @@ class Backend(BackendInterface): with_noise: bool = False native_endianness: Endianness = Endianness.BIG config: Configuration = field(default_factory=Configuration) + engine: Engine = Engine.TORCH def circuit(self, circ: QuantumCircuit) -> Sequence: passes = self.config.transpilation_passes diff --git a/qadence/backends/pyqtorch/backend.py b/qadence/backends/pyqtorch/backend.py index e8938639..4e8ee532 100644 --- a/qadence/backends/pyqtorch/backend.py +++ b/qadence/backends/pyqtorch/backend.py @@ -30,7 +30,7 @@ scale_primitive_blocks_only, transpile, ) -from qadence.types import BackendName, Endianness +from qadence.types import BackendName, Endianness, Engine from qadence.utils import infer_batchsize, int_to_basis from .config import Configuration, default_passes @@ -52,6 +52,7 @@ class Backend(BackendInterface): with_noise: bool = False native_endianness: Endianness = Endianness.BIG config: Configuration = field(default_factory=Configuration) + engine: Engine = Engine.TORCH def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit: passes = self.config.transpilation_passes diff --git a/qadence/backends/pyqtorch/convert_ops.py b/qadence/backends/pyqtorch/convert_ops.py index f26afa99..85ac7c36 100644 --- a/qadence/backends/pyqtorch/convert_ops.py +++ b/qadence/backends/pyqtorch/convert_ops.py @@ -252,12 +252,12 @@ def sparse_operation(state: Tensor, values: dict[str, Tensor] = None) -> Tensor: convert_block(block, n_qubits, config), ) - def forward(self, state: Tensor, values: dict[str, Tensor]) -> Tensor: - return pyq.overlap(state, self.operation(state, values)) - def run(self, state: Tensor, values: dict[str, Tensor]) -> Tensor: return self.operation(state, values) + def forward(self, state: Tensor, values: dict[str, Tensor]) -> Tensor: + return pyq.overlap(state, self.run(state, values)) + class PyQHamiltonianEvolution(Module): def __init__( diff --git a/qadence/backends/utils.py b/qadence/backends/utils.py index d8428e05..c123cfe7 100644 --- a/qadence/backends/utils.py +++ b/qadence/backends/utils.py @@ -17,8 +17,8 @@ no_grad, rand, ) -from torch import flatten as torchflatten +from qadence.types import ParamDictType from qadence.utils import Endianness, int_to_basis, is_qadence_shape 
FINITE_DIFF_EPS = 1e-06 @@ -92,7 +92,7 @@ def count_bitstrings(sample: Tensor, endianness: Endianness = Endianness.BIG) -> ) -def to_list_of_dicts(param_values: dict[str, Tensor]) -> list[dict[str, float]]: +def to_list_of_dicts(param_values: ParamDictType) -> list[ParamDictType]: if not param_values: return [param_values] @@ -119,7 +119,7 @@ def pyqify(state: Tensor, n_qubits: int = None) -> Tensor: def unpyqify(state: Tensor) -> Tensor: """Convert a state of shape [2] * n_qubits + [batch_size] to (batch_size, 2**n_qubits).""" - return torchflatten(state, start_dim=0, end_dim=-2).t() + return torch.flatten(state, start_dim=0, end_dim=-2).t() def is_pyq_shape(state: Tensor, n_qubits: int) -> bool: @@ -141,6 +141,11 @@ def validate_state(state: Tensor, n_qubits: int) -> None: ) +def infer_batchsize(param_values: ParamDictType = None) -> int: + """Infer the batch_size through the length of the parameter tensors.""" + return max([len(tensor) for tensor in param_values.values()]) if param_values else 1 + + # The following functions can be used to compute potentially higher order gradients using pyqtorch's # native 'jacobian' methods. diff --git a/qadence/blocks/embedding.py b/qadence/blocks/embedding.py index 030abc92..8bd8fb72 100644 --- a/qadence/blocks/embedding.py +++ b/qadence/blocks/embedding.py @@ -2,11 +2,10 @@ from typing import Callable, Iterable, List -import numpy as np import sympy -import sympytorch # type: ignore [import] -import torch -from torch import Tensor +from numpy import array as nparray +from numpy import cdouble as npcdouble +from torch import tensor from qadence.blocks import ( AbstractBlock, @@ -16,9 +15,24 @@ parameters, uuid_to_expression, ) -from qadence.parameters import evaluate, stringify, torchify +from qadence.parameters import evaluate, make_differentiable, stringify +from qadence.types import ArrayLike, DifferentiableExpression, Engine, ParamDictType, TNumber -StrTensorDict = dict[str, Tensor] + +def _concretize_parameter(engine: Engine) -> Callable: + if engine == Engine.JAX: + from jax.numpy import array as jaxarray + from jax.numpy import float64 as jaxfloat64 + + def concretize_parameter(value: TNumber, trainable: bool = False) -> ArrayLike: + return jaxarray([value], dtype=jaxfloat64) + + else: + + def concretize_parameter(value: TNumber, trainable: bool = False) -> ArrayLike: + return tensor([value], requires_grad=trainable) + + return concretize_parameter def unique(x: Iterable) -> List: @@ -26,14 +40,13 @@ def unique(x: Iterable) -> List: def embedding( - block: AbstractBlock, to_gate_params: bool = False -) -> tuple[StrTensorDict, Callable[[StrTensorDict, StrTensorDict], StrTensorDict],]: - """Construct embedding function. + block: AbstractBlock, to_gate_params: bool = False, engine: Engine = Engine.TORCH +) -> tuple[ParamDictType, Callable[[ParamDictType, ParamDictType], ParamDictType],]: + """Construct embedding function which maps user-facing parameters to either *expression-level*. - It maps user-facing parameters to either *expression-level* - parameters or *gate-level* parameters. The construced embedding function has the signature: + parameters or *gate-level* parameters. 
The constructed embedding function has the signature: - embedding_fn(params: StrTensorDict, inputs: StrTensorDict) -> StrTensorDict: + embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType: which means that it maps the *variational* parameter dict `params` and the *feature* parameter dict `inputs` to one new parameter dict `embedded_dict` which holds all parameters that are @@ -56,6 +69,13 @@ def embedding( Returns: A tuple with variational parameter dict and the embedding function. """ + concretize_parameter = _concretize_parameter(engine) + if engine == Engine.TORCH: + cast_dtype = tensor + else: + from jax.numpy import array + + cast_dtype = array unique_expressions = unique(expressions(block)) unique_symbols = [p for p in unique(parameters(block)) if not isinstance(p, sympy.Array)] @@ -77,16 +97,18 @@ def embedding( # we dont need to care about constant symbols if they are contained in an symbolic expression # we only care about gate params which are ONLY a constant - embeddings: dict[sympy.Expr, sympytorch.SymPyModule] = { - expr: torchify(expr) for expr in unique_expressions if not expr.is_number + embeddings: dict[sympy.Expr, DifferentiableExpression] = { + expr: make_differentiable(expr=expr, engine=engine) + for expr in unique_expressions + if not expr.is_number } uuid_to_expr = uuid_to_expression(block) - def embedding_fn(params: StrTensorDict, inputs: StrTensorDict) -> StrTensorDict: - embedded_params: dict[sympy.Expr, Tensor] = {} + def embedding_fn(params: ParamDictType, inputs: ParamDictType) -> ParamDictType: + embedded_params: dict[sympy.Expr, ArrayLike] = {} for expr, fn in embeddings.items(): - angle: Tensor + angle: ArrayLike values = {} for symbol in expr.free_symbols: if symbol.name in inputs: @@ -112,26 +134,26 @@ def embedding_fn(params: StrTensorDict, inputs: StrTensorDict) -> StrTensorDict: embedded_params[e] = params[stringify(e)] if to_gate_params: - gate_lvl_params: StrTensorDict = {} + gate_lvl_params: ParamDictType = {} for uuid, e in uuid_to_expr.items(): gate_lvl_params[uuid] = embedded_params[e] return gate_lvl_params else: return {stringify(k): v for k, v in embedded_params.items()} - params: StrTensorDict - params = {p.name: torch.tensor([p.value], requires_grad=True) for p in trainable_symbols} + params: ParamDictType + params = { + p.name: concretize_parameter(value=p.value, trainable=True) for p in trainable_symbols + } params.update( { - stringify(expr): torch.tensor([evaluate(expr)], requires_grad=False) + stringify(expr): concretize_parameter(value=evaluate(expr), trainable=False) for expr in constant_expressions } ) params.update( { - stringify(expr): torch.tensor( - np.array(expr.tolist(), dtype=np.cdouble), requires_grad=False - ) + stringify(expr): cast_dtype(nparray(expr.tolist(), dtype=npcdouble)) for expr in unique_const_matrices } ) diff --git a/qadence/engines/__init__.py b/qadence/engines/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/qadence/engines/differentiable_backend.py b/qadence/engines/differentiable_backend.py new file mode 100644 index 00000000..dd41258d --- /dev/null +++ b/qadence/engines/differentiable_backend.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import Counter +from dataclasses import dataclass +from typing import Any, Callable + +from qadence.backend import Backend, Converted, ConvertedCircuit, ConvertedObservable +from qadence.blocks import AbstractBlock, PrimitiveBlock +from qadence.blocks.utils import 
uuid_to_block
+from qadence.circuit import QuantumCircuit
+from qadence.measurements import Measurements
+from qadence.mitigations import Mitigations
+from qadence.noise import Noise
+from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+@dataclass(frozen=True, eq=True)
+class DifferentiableBackend(ABC):
+    """The abstract class which wraps any (non-)natively differentiable QuantumBackend.
+
+    It wraps the quantum backend in an automatic differentiation engine.
+
+    Arguments:
+        backend: An instance of the QuantumBackend type used to perform execution.
+        engine: Which automatic differentiation engine the QuantumBackend runs on.
+        diff_mode: A differentiable mode supported by the differentiation engine.
+    """
+
+    backend: Backend
+    engine: Engine
+    diff_mode: DiffMode
+
+    # TODO: Add differentiable overlap calculation
+    _overlap: Callable = None  # type: ignore [assignment]
+
+    def sample(
+        self,
+        circuit: ConvertedCircuit,
+        param_values: ParamDictType = {},
+        n_shots: int = 100,
+        state: ArrayLike | None = None,
+        noise: Noise | None = None,
+        mitigation: Mitigations | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> list[Counter]:
+        """Sample bitstrings from the registered circuit.
+
+        Arguments:
+            circuit: A backend native quantum circuit to be executed.
+            param_values: The values of the parameters after embedding.
+            n_shots: The number of shots. Defaults to 100.
+            state: Initial state.
+            noise: A noise model to use.
+            mitigation: A mitigation protocol to apply to noisy samples.
+            endianness: Endianness of the resulting bitstrings.
+
+        Returns:
+            An iterable with all the sampled bitstrings.
+        """
+
+        return self.backend.sample(
+            circuit=circuit,
+            param_values=param_values,
+            n_shots=n_shots,
+            state=state,
+            noise=noise,
+            mitigation=mitigation,
+            endianness=endianness,
+        )
+
+    def run(
+        self,
+        circuit: ConvertedCircuit,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> ArrayLike:
+        """Run on the underlying backend."""
+        return self.backend.run(
+            circuit=circuit, param_values=param_values, state=state, endianness=endianness
+        )
+
+    @abstractmethod
+    def expectation(
+        self,
+        circuit: ConvertedCircuit,
+        observable: list[ConvertedObservable] | ConvertedObservable,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        measurement: Measurements | None = None,
+        noise: Noise | None = None,
+        mitigation: Mitigations | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> Any:
+        """Compute the expectation value of the `circuit` with the given `observable`.
+
+        Arguments:
+            circuit: A converted circuit as returned by `backend.circuit`.
+            observable: A converted observable as returned by `backend.observable`.
+            param_values: _**Already embedded**_ parameters of the circuit. See
+                [`embedding`][qadence.blocks.embedding.embedding] for more info.
+            state: Initial state.
+            measurement: Optional measurement protocol. If None, use
+                exact expectation value with a statevector simulator.
+            noise: A noise model to use.
+            mitigation: The error mitigation to use.
+            endianness: Endianness of the resulting bit strings.
+        """
+        raise NotImplementedError(
+            "A DifferentiableBackend needs to override the expectation method."
+        )
+
+    def default_configuration(self) -> Any:
+        return self.backend.default_configuration()
+
+    def circuit(self, circuit: QuantumCircuit) -> ConvertedCircuit:
+        if self.diff_mode == DiffMode.GPSR:
+            parametrized_blocks = list(uuid_to_block(circuit.block).values())
+            non_prim_blocks = filter(
+                lambda b: not isinstance(b, PrimitiveBlock), parametrized_blocks
+            )
+            if len(list(non_prim_blocks)) > 0:
+                raise ValueError(
+                    "The circuit contains non-primitive blocks that are currently\
+                    not supported by the PSR differentiable mode."
+                )
+        return self.backend.circuit(circuit)
+
+    def observable(self, observable: AbstractBlock, n_qubits: int) -> ConvertedObservable:
+        if self.diff_mode != DiffMode.AD and observable is not None:
+            msg = (
+                f"Differentiation mode '{self.diff_mode}' does not support parametric observables."
+            )
+            if isinstance(observable, list):
+                for obs in observable:
+                    if obs.is_parametric:
+                        raise ValueError(msg)
+            else:
+                if observable.is_parametric:
+                    raise ValueError(msg)
+        return self.backend.observable(observable, n_qubits)
+
+    def convert(
+        self,
+        circuit: QuantumCircuit,
+        observable: list[AbstractBlock] | AbstractBlock | None = None,
+    ) -> Converted:
+        return self.backend.convert(circuit, observable)
+
+    def assign_parameters(self, circuit: ConvertedCircuit, param_values: ParamDictType) -> Any:
+        return self.backend.assign_parameters(circuit, param_values)
diff --git a/qadence/engines/jax/__init__.py b/qadence/engines/jax/__init__.py
new file mode 100644
index 00000000..a2680d14
--- /dev/null
+++ b/qadence/engines/jax/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import annotations
+
+from jax import config
+
+from .differentiable_backend import DifferentiableBackend
+from .differentiable_expectation import DifferentiableExpectation
+
+config.update("jax_enable_x64", True)
diff --git a/qadence/engines/jax/differentiable_backend.py b/qadence/engines/jax/differentiable_backend.py
new file mode 100644
index 00000000..c65676f4
--- /dev/null
+++ b/qadence/engines/jax/differentiable_backend.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from qadence.backend import Backend, ConvertedCircuit, ConvertedObservable
+from qadence.engines.differentiable_backend import (
+    DifferentiableBackend as DifferentiableBackendInterface,
+)
+from qadence.engines.jax.differentiable_expectation import DifferentiableExpectation
+from qadence.measurements import Measurements
+from qadence.mitigations import Mitigations
+from qadence.noise import Noise
+from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+class DifferentiableBackend(DifferentiableBackendInterface):
+    """A class which wraps a QuantumBackend with the automatic differentiation engine JAX.
+
+    Arguments:
+        backend: An instance of the QuantumBackend type used to perform execution.
+        diff_mode: A differentiable mode supported by the differentiation engine.
+        **psr_args: Arguments that will be passed on to `DifferentiableExpectation`.
+ """ + + def __init__( + self, + backend: Backend, + diff_mode: DiffMode = DiffMode.AD, + **psr_args: int | float | None, + ) -> None: + super().__init__(backend=backend, engine=Engine.JAX, diff_mode=diff_mode) + self.psr_args = psr_args + + def expectation( + self, + circuit: ConvertedCircuit, + observable: list[ConvertedObservable] | ConvertedObservable, + param_values: ParamDictType = {}, + state: ArrayLike | None = None, + measurement: Measurements | None = None, + noise: Noise | None = None, + mitigation: Mitigations | None = None, + endianness: Endianness = Endianness.BIG, + ) -> ArrayLike: + """Compute the expectation value of the `circuit` with the given `observable`. + + Arguments: + circuit: A converted circuit as returned by `backend.circuit`. + observable: A converted observable as returned by `backend.observable`. + param_values: _**Already embedded**_ parameters of the circuit. See + [`embedding`][qadence.blocks.embedding.embedding] for more info. + state: Initial state. + measurement: Optional measurement protocol. If None, use + exact expectation value with a statevector simulator. + noise: A noise model to use. + mitigation: The error mitigation to use. + endianness: Endianness of the resulting bit strings. + """ + observable = observable if isinstance(observable, list) else [observable] + + if self.diff_mode == DiffMode.AD: + expectation = self.backend.expectation(circuit, observable, param_values, state) + else: + expectation = DifferentiableExpectation( + backend=self.backend, + circuit=circuit, + observable=observable, + param_values=param_values, + state=state, + measurement=measurement, + noise=noise, + mitigation=mitigation, + endianness=endianness, + ).psr() + return expectation diff --git a/qadence/engines/jax/differentiable_expectation.py b/qadence/engines/jax/differentiable_expectation.py new file mode 100644 index 00000000..c6f674df --- /dev/null +++ b/qadence/engines/jax/differentiable_expectation.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Tuple + +import jax.numpy as jnp +from jax import Array, custom_vjp, vmap + +from qadence.backend import Backend as QuantumBackend +from qadence.backend import ConvertedCircuit, ConvertedObservable +from qadence.backends.jax_utils import ( + tensor_to_jnp, +) +from qadence.blocks.utils import uuid_to_eigen +from qadence.measurements import Measurements +from qadence.mitigations import Mitigations +from qadence.noise import Noise +from qadence.types import Endianness, Engine, ParamDictType + + +def compute_single_gap(eigen_vals: Array, default_val: float = 2.0) -> Array: + eigen_vals = eigen_vals.reshape(1, 2) + gaps = jnp.abs(jnp.tril(eigen_vals.T - eigen_vals)) + return jnp.unique(jnp.where(gaps > 0.0, gaps, default_val), size=1) + + +@dataclass +class DifferentiableExpectation: + """A handler for differentiating expectation estimation using various engines.""" + + backend: QuantumBackend + circuit: ConvertedCircuit + observable: list[ConvertedObservable] | ConvertedObservable + param_values: ParamDictType + state: Array | None = None + measurement: Measurements | None = None + noise: Noise | None = None + mitigation: Mitigations | None = None + endianness: Endianness = Endianness.BIG + engine: Engine = Engine.JAX + + def psr(self) -> Any: + n_obs = len(self.observable) + + def expectation_fn(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array: + return self.backend.expectation( + circuit=self.circuit, 
observable=self.observable, param_values=values, state=state
+            )
+
+        @custom_vjp
+        def expectation(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Array:
+            return expectation_fn(state, values, psr_params)
+
+        uuid_to_eigs = {
+            k: tensor_to_jnp(v) for k, v in uuid_to_eigen(self.circuit.abstract.block).items()
+        }
+        self.psr_params = {
+            k: self.param_values[k] for k in uuid_to_eigs.keys()
+        }  # Subset of params on which to perform PSR.
+
+        def expectation_fwd(state: Array, values: ParamDictType, psr_params: ParamDictType) -> Any:
+            return expectation_fn(state, values, psr_params), (
+                state,
+                values,
+                psr_params,
+            )
+
+        def expectation_bwd(res: Tuple[Array, ParamDictType, ParamDictType], tangent: Array) -> Any:
+            state, values, psr_params = res
+            grads = {}
+            # FIXME Hardcoding the single spectral_gap to 2.
+            spectral_gap = 2.0
+            shift = jnp.pi / 2
+
+            def shift_circ(param_name: str, values: dict) -> Array:
+                shifted_values = values.copy()
+                shiftvals = jnp.array(
+                    [shifted_values[param_name] + shift, shifted_values[param_name] - shift]
+                )
+
+                def _expectation(val: Array) -> Array:
+                    shifted_values[param_name] = val
+                    return expectation(state, shifted_values, psr_params)
+
+                return vmap(_expectation, in_axes=(0,))(shiftvals)
+
+            for param_name, _ in psr_params.items():
+                f_plus, f_min = shift_circ(param_name, values)
+                grad = spectral_gap * (f_plus - f_min) / (4.0 * jnp.sin(spectral_gap * shift / 2.0))
+                grads[param_name] = jnp.sum(tangent * grad, axis=1) if n_obs > 1 else tangent * grad
+            return None, None, grads
+
+        expectation.defvjp(expectation_fwd, expectation_bwd)
+        return expectation(self.state, self.param_values, self.psr_params)
diff --git a/qadence/engines/torch/__init__.py b/qadence/engines/torch/__init__.py
new file mode 100644
index 00000000..0c857f8f
--- /dev/null
+++ b/qadence/engines/torch/__init__.py
@@ -0,0 +1,4 @@
+from __future__ import annotations
+
+from .differentiable_backend import DifferentiableBackend
+from .differentiable_expectation import DifferentiableExpectation
diff --git a/qadence/engines/torch/differentiable_backend.py b/qadence/engines/torch/differentiable_backend.py
new file mode 100644
index 00000000..c1f50b27
--- /dev/null
+++ b/qadence/engines/torch/differentiable_backend.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from functools import partial
+
+from qadence.backend import Backend as QuantumBackend
+from qadence.backend import ConvertedCircuit, ConvertedObservable
+from qadence.engines.differentiable_backend import (
+    DifferentiableBackend as DifferentiableBackendInterface,
+)
+from qadence.engines.torch.differentiable_expectation import DifferentiableExpectation
+from qadence.extensions import get_gpsr_fns
+from qadence.measurements import Measurements
+from qadence.mitigations import Mitigations
+from qadence.noise import Noise
+from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+class DifferentiableBackend(DifferentiableBackendInterface):
+    """A class which wraps a QuantumBackend with the automatic differentiation engine TORCH.
+
+    Arguments:
+        backend: An instance of the QuantumBackend type used to perform execution.
+        diff_mode: A differentiable mode supported by the differentiation engine.
+        **psr_args: Arguments that will be passed on to `DifferentiableExpectation`.
diff --git a/qadence/engines/torch/__init__.py b/qadence/engines/torch/__init__.py
new file mode 100644
index 00000000..0c857f8f
--- /dev/null
+++ b/qadence/engines/torch/__init__.py
@@ -0,0 +1,4 @@
+from __future__ import annotations
+
+from .differentiable_backend import DifferentiableBackend
+from .differentiable_expectation import DifferentiableExpectation
diff --git a/qadence/engines/torch/differentiable_backend.py b/qadence/engines/torch/differentiable_backend.py
new file mode 100644
index 00000000..c1f50b27
--- /dev/null
+++ b/qadence/engines/torch/differentiable_backend.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from functools import partial
+
+from qadence.backend import Backend as QuantumBackend
+from qadence.backend import ConvertedCircuit, ConvertedObservable
+from qadence.engines.differentiable_backend import (
+    DifferentiableBackend as DifferentiableBackendInterface,
+)
+from qadence.engines.torch.differentiable_expectation import DifferentiableExpectation
+from qadence.extensions import get_gpsr_fns
+from qadence.measurements import Measurements
+from qadence.mitigations import Mitigations
+from qadence.noise import Noise
+from qadence.types import ArrayLike, DiffMode, Endianness, Engine, ParamDictType
+
+
+class DifferentiableBackend(DifferentiableBackendInterface):
+    """A class which wraps a QuantumBackend with the automatic differentiation engine TORCH.
+
+    Arguments:
+        backend: An instance of the QuantumBackend type used to perform execution.
+        diff_mode: A differentiation mode supported by the engine.
+        **psr_args: Arguments that will be passed on to `DifferentiableExpectation`.
+    """
+
+    def __init__(
+        self,
+        backend: QuantumBackend,
+        diff_mode: DiffMode = DiffMode.AD,
+        **psr_args: int | float | None,
+    ) -> None:
+        super().__init__(backend=backend, engine=Engine.TORCH, diff_mode=diff_mode)
+        self.psr_args = psr_args
+
+    def expectation(
+        self,
+        circuit: ConvertedCircuit,
+        observable: list[ConvertedObservable] | ConvertedObservable,
+        param_values: ParamDictType = {},
+        state: ArrayLike | None = None,
+        measurement: Measurements | None = None,
+        noise: Noise | None = None,
+        mitigation: Mitigations | None = None,
+        endianness: Endianness = Endianness.BIG,
+    ) -> ArrayLike:
+        """Compute the expectation value of the `circuit` with the given `observable`.
+
+        Arguments:
+            circuit: A converted circuit as returned by `backend.circuit`.
+            observable: A converted observable as returned by `backend.observable`.
+            param_values: _**Already embedded**_ parameters of the circuit. See
+                [`embedding`][qadence.blocks.embedding.embedding] for more info.
+            state: Initial state.
+            measurement: Optional measurement protocol. If None, use
+                exact expectation value with a statevector simulator.
+            noise: A noise model to use.
+            mitigation: The error mitigation to use.
+            endianness: Endianness of the resulting bit strings.
+        """
+        observable = observable if isinstance(observable, list) else [observable]
+        differentiable_expectation = DifferentiableExpectation(
+            backend=self.backend,
+            circuit=circuit,
+            observable=observable,
+            param_values=param_values,
+            state=state,
+            measurement=measurement,
+            noise=noise,
+            mitigation=mitigation,
+            endianness=endianness,
+        )
+
+        if self.diff_mode == DiffMode.AD:
+            expectation = differentiable_expectation.ad
+        elif self.diff_mode == DiffMode.ADJOINT:
+            expectation = differentiable_expectation.adjoint
+        else:
+            try:
+                fns = get_gpsr_fns()
+                psr_fn = fns[self.diff_mode]
+            except KeyError:
+                raise ValueError(f"{self.diff_mode} differentiation mode is not supported")
+            expectation = partial(differentiable_expectation.psr, psr_fn=psr_fn, **self.psr_args)
+        return expectation()
diff --git a/qadence/backends/pytorch_wrapper.py b/qadence/engines/torch/differentiable_expectation.py
similarity index 99%
rename from qadence/backends/pytorch_wrapper.py
rename to qadence/engines/torch/differentiable_expectation.py
index 34c99850..096474e4 100644
--- a/qadence/backends/pytorch_wrapper.py
+++ b/qadence/engines/torch/differentiable_expectation.py
@@ -13,7 +13,7 @@
 from qadence.backend import Backend as QuantumBackend
 from qadence.backend import Converted, ConvertedCircuit, ConvertedObservable
 from qadence.backends.adjoint import AdjointExpectation
-from qadence.backends.utils import is_pyq_shape, param_dict, pyqify, validate_state
+from qadence.backends.utils import infer_batchsize, is_pyq_shape, param_dict, pyqify, validate_state
 from qadence.blocks.abstract import AbstractBlock
 from qadence.blocks.primitive import PrimitiveBlock
 from qadence.blocks.utils import uuid_to_block, uuid_to_eigen
@@ -24,7 +24,6 @@
 from qadence.ml_tools import promote_to_tensor
 from qadence.noise import Noise
 from qadence.types import DiffMode, Endianness
-from qadence.utils import infer_batchsize


 class PSRExpectation(Function):
diff --git a/qadence/extensions.py b/qadence/extensions.py
index dc9cc99f..0ad97b69 100644
--- a/qadence/extensions.py
+++ b/qadence/extensions.py
@@ -6,15 +6,30 @@
 from qadence.backend import Backend
 from qadence.blocks.abstract import TAbstractBlock
 from qadence.logger import get_logger
-from qadence.types import BackendName, DiffMode
+from qadence.types import BackendName, DiffMode, Engine

 backends_namespace = Template("qadence.backends.$name")

 logger = get_logger(__name__)


+def _available_engines() -> dict:
+    """Returns a dictionary of currently installed, native qadence engines."""
+    res = {}
+    for engine in Engine.list():
+        module_path = f"qadence.engines.{engine}.differentiable_backend"
+        try:
+            module = importlib.import_module(module_path)
+            DifferentiableBackendCls = getattr(module, "DifferentiableBackend")
+            res[engine] = DifferentiableBackendCls
+        except (ImportError, ModuleNotFoundError):
+            pass
+    logger.info(f"Found engines: {res.keys()}")
+    return res
+
+
 def _available_backends() -> dict:
-    """Fallback function for native Qadence available backends if extensions is not present."""
+    """Returns a dictionary of currently installed, native qadence backends."""
     res = {}
     for backend in BackendName.list():
         module_path = f"qadence.backends.{backend}.backend"
@@ -24,11 +39,12 @@
             res[backend] = BackendCls
         except (ImportError, ModuleNotFoundError):
             pass
+    logger.info(f"Found backends: {res.keys()}")
     return res


 def _supported_gates(name: BackendName | str) -> list[TAbstractBlock]:
-    """Fallback function for native Qadence backend supported gates if extensions is not present."""
+    """Returns a list of supported gates for the queried backend 'name'."""
     from qadence import operations

     name = str(BackendName(name).name.lower())
@@ -102,6 +118,7 @@
     set_backend_config = getattr(module, "set_backend_config")
 except ModuleNotFoundError:
     available_backends = _available_backends
+    available_engines = _available_engines
     supported_gates = _supported_gates
     get_gpsr_fns = _gpsr_fns
     set_backend_config = _set_backend_config
diff --git a/qadence/models/quantum_model.py b/qadence/models/quantum_model.py
index 941a1e3e..8cc5dc23 100644
--- a/qadence/models/quantum_model.py
+++ b/qadence/models/quantum_model.py
@@ -19,6 +19,7 @@
 from qadence.backends.api import backend_factory, config_factory
 from qadence.blocks.abstract import AbstractBlock
 from qadence.circuit import QuantumCircuit
+from qadence.engines.differentiable_backend import DifferentiableBackend
 from qadence.logger import get_logger
 from qadence.measurements import Measurements
 from qadence.mitigations import Mitigations
@@ -36,7 +37,7 @@ class QuantumModel(nn.Module):
     [here](/advanced_tutorials/custom-models.md).
""" - backend: Backend + backend: Backend | DifferentiableBackend embedding_fn: Callable _params: nn.ParameterDict _circuit: ConvertedCircuit diff --git a/qadence/parameters.py b/qadence/parameters.py index 772a47da..31a21e2a 100644 --- a/qadence/parameters.py +++ b/qadence/parameters.py @@ -9,11 +9,11 @@ from sympy import * from sympy import Array, Basic, Expr, Symbol, sympify from sympy.physics.quantum.dagger import Dagger -from sympytorch import SymPyModule +from sympytorch import SymPyModule as torchSympyModule from torch import Tensor, heaviside, no_grad, rand, tensor from qadence.logger import get_logger -from qadence.types import TNumber +from qadence.types import DifferentiableExpression, Engine, TNumber # Modules to be automatically added to the qadence namespace __all__ = ["FeatureParameter", "Parameter", "VariationalParameter"] @@ -190,23 +190,26 @@ def extract_original_param_entry( return param if not param.is_number else evaluate(param) -def torchify(expr: Expr) -> SymPyModule: - """ - Arguments: +def heaviside_func(x: Tensor, _: Any) -> Tensor: + with no_grad(): + res = heaviside(x, tensor(0.5)) + return res - expr: An expression consisting of Parameters. - Returns: - A torchified, differentiable Expression. - """ +def torchify(expr: Expr) -> torchSympyModule: + extra_funcs = {sympy.core.numbers.ImaginaryUnit: 1.0j, sympy.Heaviside: heaviside_func} + return torchSympyModule(expressions=[sympy.N(expr)], extra_funcs=extra_funcs) - def heaviside_func(x: Tensor, _: Any) -> Tensor: - with no_grad(): - res = heaviside(x, tensor(0.5)) - return res - extra_funcs = {sympy.core.numbers.ImaginaryUnit: 1.0j, sympy.Heaviside: heaviside_func} - return SymPyModule(expressions=[sympy.N(expr)], extra_funcs=extra_funcs) +def make_differentiable(expr: Expr, engine: Engine = Engine.TORCH) -> DifferentiableExpression: + diff_expr: DifferentiableExpression + if engine == Engine.JAX: + from qadence.backends.jax_utils import jaxify + + diff_expr = jaxify(expr) + else: + diff_expr = torchify(expr) + return diff_expr def sympy_to_numeric(expr: Basic) -> TNumber: @@ -261,7 +264,7 @@ def evaluate(expr: Expr, values: dict = {}, as_torch: bool = False) -> TNumber | else: raise ValueError(f"No value provided for symbol {s.name}") if as_torch: - res_value = torchify(expr)(**{s.name: tensor(v) for s, v in query.items()}) + res_value = make_differentiable(expr)(**{s.name: tensor(v) for s, v in query.items()}) else: res = expr.subs(query) res_value = sympy_to_numeric(res) diff --git a/qadence/types.py b/qadence/types.py index 5ad9e46d..cb52479e 100644 --- a/qadence/types.py +++ b/qadence/types.py @@ -2,10 +2,11 @@ import importlib from enum import Enum -from typing import Iterable, Tuple, Union +from typing import Callable, Iterable, Tuple, Union import numpy as np import sympy +from numpy.typing import ArrayLike from torch import Tensor, pi TNumber = Union[int, float, complex] @@ -197,6 +198,8 @@ class _BackendName(StrEnum): """The Braket backend.""" PULSER = "pulser" """The Pulser backend.""" + HORQRUX = "horqrux" + """The horqrux backend.""" # If proprietary qadence_extensions is available, import the @@ -386,3 +389,12 @@ class OpName(StrEnum): class ReadOutOptimization(StrEnum): MLE = "mle" CONSTRAINED = "constrained" + + +class Engine(StrEnum): + TORCH = "torch" + JAX = "jax" + + +ParamDictType = dict[str, ArrayLike] +DifferentiableExpression = Callable[..., ArrayLike] diff --git a/tests/backends/horqrux/test_quantum_horqrux.py b/tests/backends/horqrux/test_quantum_horqrux.py new file mode 100644 index 
diff --git a/tests/backends/horqrux/test_quantum_horqrux.py b/tests/backends/horqrux/test_quantum_horqrux.py
new file mode 100644
index 00000000..1b389534
--- /dev/null
+++ b/tests/backends/horqrux/test_quantum_horqrux.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import jax.numpy as jnp
+import pytest
+import sympy
+import torch
+from jax import Array, grad, jacrev, jit, value_and_grad, vmap
+
+from qadence import (
+    CNOT,
+    CRX,
+    CRY,
+    CRZ,
+    RX,
+    RY,
+    RZ,
+    BackendName,
+    FeatureParameter,
+    QuantumCircuit,
+    VariationalParameter,
+    X,
+    Y,
+    Z,
+    chain,
+    expectation,
+    ising_hamiltonian,
+    run,
+    total_magnetization,
+)
+from qadence.backends import backend_factory
+from qadence.backends.jax_utils import jarr_to_tensor
+from qadence.blocks import AbstractBlock
+from qadence.constructors import hea
+
+
+def test_psr_first_order() -> None:
+    circ = QuantumCircuit(2, hea(2, 1))
+    v_list = []
+    grad_dict = {}
+    for diff_mode in ["ad", "gpsr"]:
+        hq_bknd = backend_factory("horqrux", diff_mode)
+        hq_circ, hq_obs, hq_fn, hq_params = hq_bknd.convert(circ, Z(0))
+        embedded_params = hq_fn(hq_params, {})
+        param_names = embedded_params.keys()
+        param_values = embedded_params.values()
+        param_array = jnp.array(jnp.concatenate([arr for arr in param_values]))
+
+        def _exp_fn(values: Array) -> Array:
+            vals = {k: v for k, v in zip(param_names, values)}
+            return hq_bknd.expectation(hq_circ, hq_obs, vals)
+
+        v, grads = value_and_grad(_exp_fn)(param_array)
+        v_list.append(v)
+        grad_dict[diff_mode] = grads
+    assert jnp.allclose(grad_dict["ad"], grad_dict["gpsr"])
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("obs", [Z(0), Z(0) + Z(1)])
+def test_psr_d3fdx(batch_size: int, obs: AbstractBlock) -> None:
+    n_qubits: int = 2
+    circ = QuantumCircuit(n_qubits, RX(0, "theta"))
+    grad_dict = {}
+    for diff_mode in ["ad", "gpsr"]:
+        hq_bknd = backend_factory("horqrux", diff_mode)
+        hq_circ, hq_obs, hq_fn, hq_params = hq_bknd.convert(circ, obs)
+        embedded_params = hq_fn(hq_params, {})
+        param_names = embedded_params.keys()
+
+        def _exp_fn(value: Array) -> Array:
+            vals = {list(param_names)[0]: value}
+            return hq_bknd.expectation(hq_circ, hq_obs, vals)
+
+        d1fdx = grad(_exp_fn)
+        d2fdx = grad(d1fdx)
+        d3fdx = grad(d2fdx)
+        jd3fdx = jit(d3fdx)
+        grad_dict[diff_mode] = vmap(jd3fdx, in_axes=(0,))(jnp.ones(batch_size))
+    assert jnp.allclose(grad_dict["ad"], grad_dict["gpsr"])
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("obs", [Z(0)])
+def test_psr_2nd_order_mixed(batch_size: int, obs: AbstractBlock) -> None:
+    n_qubits: int = 2
+    circ = QuantumCircuit(n_qubits, hea(n_qubits=n_qubits, depth=1))
+    grad_dict = {}
+    for diff_mode in ["ad", "gpsr"]:
+        hq_bknd = backend_factory("horqrux", diff_mode)
+        hq_circ, hq_obs, hq_fn, hq_params = hq_bknd.convert(circ, obs)
+        hessian = jit(
+            jacrev(jacrev(lambda params: hq_bknd.expectation(hq_circ, hq_obs, hq_fn(params, {}))))
+        )
+        grad_dict[diff_mode] = hessian(hq_params)
+
+    def _allclose(d0: dict, d1: dict) -> None:
+        for (k0, dd0), (k1, dd1) in zip(d0.items(), d1.items()):
+            if isinstance(dd0, dict):
+                _allclose(dd0, dd1)  # recurse without returning early, so siblings are checked too
+            else:
+                assert jnp.allclose(dd0, dd1) and k0 == k1
+
+    _allclose(grad_dict["ad"], grad_dict["gpsr"])
+
+
+@pytest.mark.parametrize("block", [RX(0, 1.0), RY(0, 3.0), RZ(0, 4.0), X(0), Y(0), Z(0)])
+def test_singlequbit_primitive_parametric(block: AbstractBlock) -> None:
+    wf_pyq = run(block, backend=BackendName.PYQTORCH)
+    wf_horq = jarr_to_tensor(run(block, backend=BackendName.HORQRUX))
+    assert torch.allclose(wf_horq, wf_pyq)
+
+
+@pytest.mark.parametrize(
+    "block", [CNOT(0, 1), CNOT(1, 0), CRX(0, 1, 1.0), CRY(1, 0, 2.0), CRZ(0, 1, 3.0)]
+)
+def test_control(block: AbstractBlock) -> None:
+    wf_pyq = run(block, backend=BackendName.PYQTORCH)
+    wf_horq = jarr_to_tensor(run(block, backend=BackendName.HORQRUX))
+    assert torch.allclose(wf_horq, wf_pyq)
+
+
+@pytest.mark.parametrize("block", [hea(2, 1), hea(4, 4)])
+def test_hea(block: AbstractBlock) -> None:
+    wf_pyq = run(block, backend=BackendName.PYQTORCH)
+    wf_horq = jarr_to_tensor(run(block, backend=BackendName.HORQRUX))
+    assert torch.allclose(wf_horq, wf_pyq)
+
+
+@pytest.mark.parametrize("block", [hea(2, 1), hea(4, 4)])
+def test_hea_expectation(block: AbstractBlock) -> None:
+    exp_pyq = expectation(block, Z(0), backend=BackendName.PYQTORCH)
+    exp_horq = jarr_to_tensor(
+        expectation(block, Z(0), backend=BackendName.HORQRUX), dtype=torch.double
+    )
+    assert torch.allclose(exp_pyq, exp_horq)
+
+
+def test_multiparam_multiobs() -> None:
+    n_qubits = 1
+    obs = [total_magnetization(n_qubits), ising_hamiltonian(n_qubits)]
+    block = chain(
+        RX(0, sympy.cos(VariationalParameter("theta") * FeatureParameter("phi"))), RY(0, 2.0)
+    )
+    values_jax = {"phi": jnp.array([2.0, 5.0])}
+    values_torch = {"phi": torch.tensor([2.0, 5.0])}
+    exp_pyq = expectation(block, obs, values=values_torch, backend=BackendName.PYQTORCH)
+    exp_horq = jarr_to_tensor(
+        expectation(block, obs, values=values_jax, backend=BackendName.HORQRUX), dtype=torch.double
+    )
+
+    assert torch.allclose(exp_pyq, exp_horq)
diff --git a/tests/backends/pulser_basic/test_differentiation.py b/tests/backends/pulser_basic/test_differentiation.py
index 3b7ea4b3..a4400f6b 100644
--- a/tests/backends/pulser_basic/test_differentiation.py
+++ b/tests/backends/pulser_basic/test_differentiation.py
@@ -5,11 +5,12 @@
 import torch
 from metrics import PULSER_GPSR_ACCEPTANCE

-from qadence import DifferentiableBackend, DiffMode, Parameter, QuantumCircuit
+from qadence import DiffMode, Parameter, QuantumCircuit
 from qadence.backends.pulser import Backend as PulserBackend
 from qadence.backends.pyqtorch import Backend as PyQBackend
 from qadence.blocks import AbstractBlock, chain
 from qadence.constructors import total_magnetization
+from qadence.engines.torch.differentiable_backend import DifferentiableBackend
 from qadence.operations import RX, RY, AnalogRot, AnalogRX, wait
 from qadence.register import Register
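The import move above (from the `qadence` top level to `qadence.engines.torch.differentiable_backend`) is what the low-level tests now rely on. A minimal sketch of that flow under the convert/embed pattern also visible in the tests below; the circuit, observable, and parameter value are illustrative:

```python
import torch

from qadence import DiffMode, Parameter, QuantumCircuit
from qadence.backends.pyqtorch import Backend as PyQBackend
from qadence.constructors import total_magnetization
from qadence.engines.torch.differentiable_backend import DifferentiableBackend
from qadence.operations import RX

theta = Parameter("theta", trainable=False)
circuit = QuantumCircuit(1, RX(0, theta))
observable = total_magnetization(1)

# Convert with the quantum backend, then wrap it for GPSR differentiation.
quantum_backend = PyQBackend()
conv = quantum_backend.convert(circuit, observable)
diff_backend = DifferentiableBackend(quantum_backend, diff_mode=DiffMode.GPSR)

values = {"theta": torch.tensor([0.5], requires_grad=True)}
expval = diff_backend.expectation(
    conv.circuit, conv.observable, conv.embedding_fn(conv.params, values)
)
(dexpval,) = torch.autograd.grad(expval.sum(), values["theta"])
print(dexpval)
```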
diff --git a/tests/backends/test_backends.py b/tests/backends/test_backends.py
index b7f0ced6..81f2b5e9 100644
--- a/tests/backends/test_backends.py
+++ b/tests/backends/test_backends.py
@@ -8,11 +8,13 @@
 import sympy
 import torch
 from hypothesis import given, settings
+from jax import Array
 from metrics import ATOL_DICT, JS_ACCEPTANCE  # type: ignore
 from torch import Tensor

 from qadence.backend import BackendConfiguration
 from qadence.backends.api import backend_factory, config_factory
+from qadence.backends.jax_utils import jarr_to_tensor, tensor_to_jnp
 from qadence.blocks import AbstractBlock, chain, kron
 from qadence.circuit import QuantumCircuit
 from qadence.constructors import total_magnetization
@@ -66,7 +68,7 @@ def flatten_counter(c: Counter | list[Counter]) -> Counter:
 def test_simple_circuits(backend: str, circuit: QuantumCircuit) -> None:
     bknd = backend_factory(backend=backend)
     wf = bknd.run(bknd.circuit(circuit))
-    assert isinstance(wf, Tensor)
+    assert isinstance(wf, (Tensor, Array))


 def test_expectation_value(parametric_circuit: QuantumCircuit) -> None:
@@ -77,6 +79,8 @@
     for b in BACKENDS:
         bkd = backend_factory(backend=b, diff_mode=None)
         conv = bkd.convert(parametric_circuit, observable)
+        if b == BackendName.HORQRUX:
+            values = {k: tensor_to_jnp(v) for k, v in values.items()}
         expval = bkd.expectation(
             conv.circuit, conv.observable, conv.embedding_fn(conv.params, values)  # type: ignore
         )
@@ -91,7 +95,17 @@
     assert np.all(np.isclose(wfs_np, wfs_np[0]))


-@pytest.mark.parametrize("backend", BACKENDS)
+@pytest.mark.parametrize(
+    "backend",
+    [
+        BackendName.PYQTORCH,
+        BackendName.BRAKET,
+        pytest.param(
+            BackendName.HORQRUX,
+            marks=pytest.mark.xfail(reason="Horqrux uses JAX engine."),
+        ),
+    ],
+)
 def test_qcl_loss(backend: str) -> None:
     np.random.seed(42)
     torch.manual_seed(42)
@@ -131,6 +145,10 @@ def get_training_data(domain: tuple = (-0.99, 0.99), n_teacher: int = 30) -> tuple:
     "backend",
     [
         BackendName.PYQTORCH,
+        pytest.param(
+            BackendName.HORQRUX,
+            marks=pytest.mark.xfail(reason="horqrux doesn't support batching of states."),
+        ),
         pytest.param(
             BackendName.BRAKET,
             marks=pytest.mark.xfail(reason="state-vector initial state not implemented in Braket"),
@@ -197,7 +215,11 @@ def test_run_for_random_circuit(backend: BackendName, circuit: QuantumCircuit) -> None:
     (circ, _, embed, params) = bknd.convert(circuit)
     inputs = rand_featureparameters(circuit, 1)
     wf_pyqtorch = bknd_pyqtorch.run(circ_pyqtorch, embed_pyqtorch(params_pyqtorch, inputs))
+    if inputs and backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v) for k, v in inputs.items()}
     wf = bknd.run(circ, embed(params, inputs))
+    if backend == BackendName.HORQRUX:
+        wf = jarr_to_tensor(wf)
     assert equivalent_state(wf_pyqtorch, wf, atol=ATOL_DICT[backend])
@@ -215,6 +237,8 @@
     pyqtorch_samples = bknd_pyqtorch.sample(
         circ_pyqtorch, embed_pyqtorch(params_pyqtorch, inputs), n_shots=100
     )
+    if inputs and backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v) for k, v in inputs.items()}
     samples = bknd.sample(circ, embed(params, inputs), n_shots=100)

     for pyqtorch_sample, sample in zip(pyqtorch_samples, samples):
@@ -241,7 +265,12 @@
     pyqtorch_expectation = bknd_pyqtorch.expectation(
         circ_pyqtorch, obs_pyqtorch, embed_pyqtorch(params_pyqtorch, inputs)
     )[0]
-    expectation = bknd.expectation(circ, obs, embed(params, inputs))[0]
+    if inputs and backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v) for k, v in inputs.items()}
+
+    expectation = bknd.expectation(circ, obs, embed(params, inputs))
+    if backend == BackendName.HORQRUX:
+        expectation = jarr_to_tensor(expectation, dtype=torch.double)
     assert torch.allclose(pyqtorch_expectation, expectation, atol=ATOL_DICT[backend])
@@ -253,8 +282,12 @@
     bknd = backend_factory(backend)
     (conv_circ, _, embed, params) = bknd.convert(circuit)
     inputs = rand_featureparameters(circuit, 1)
+    if inputs and backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v) for k, v in inputs.items()}
     samples = bknd.sample(conv_circ, embed(params, inputs), n_shots=1000)
     wf = bknd.run(conv_circ, embed(params, inputs))
+    if backend == BackendName.HORQRUX:
+        wf = jarr_to_tensor(wf)
     probs = list(torch.abs(torch.pow(wf, 2)).flatten().detach().numpy())
     bitstrngs = nqubits_to_basis(circuit.n_qubits)
     wf_counter = Counter(
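The test changes above repeatedly apply one conversion pattern for the JAX-based horqrux backend: feature values go torch to jax before embedding, and results come back via `jarr_to_tensor` for torch-side comparisons. A condensed sketch of that round trip, with an illustrative one-qubit circuit and value:

```python
import torch

from qadence import RX, BackendName, QuantumCircuit
from qadence.backends.api import backend_factory
from qadence.backends.jax_utils import jarr_to_tensor, tensor_to_jnp

circuit = QuantumCircuit(1, RX(0, "phi"))
bknd = backend_factory(backend=BackendName.HORQRUX)
(circ, _, embed, params) = bknd.convert(circuit)

# Inputs move torch -> jax before embedding...
inputs = {"phi": tensor_to_jnp(torch.tensor([0.5]))}
wf = bknd.run(circ, embed(params, inputs))

# ...and results move back to torch for comparison against other backends.
print(jarr_to_tensor(wf))
```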
diff --git a/tests/backends/test_endianness.py b/tests/backends/test_endianness.py
index 6e9fd4f0..54000b9c 100644
--- a/tests/backends/test_endianness.py
+++ b/tests/backends/test_endianness.py
@@ -2,6 +2,7 @@
 from typing import Counter

+import jax.numpy as jnp
 import pytest
 import strategies as st  # type: ignore
 from hypothesis import given, settings
@@ -10,6 +11,7 @@
 from qadence import QuantumCircuit, block_to_tensor, run, sample
 from qadence.backends.api import backend_factory
+from qadence.backends.jax_utils import jarr_to_tensor, tensor_to_jnp
 from qadence.blocks import AbstractBlock, MatrixBlock, chain, kron
 from qadence.divergences import js_divergence
 from qadence.ml_tools.utils import rand_featureparameters
@@ -136,6 +138,8 @@ def test_backend_wf_endianness(circ: QuantumCircuit, truth: Tensor, backend: BackendName
     wf = run(circ, {}, backend=backend, endianness=endianness)
     if endianness == Endianness.LITTLE:
         truth = invert_endianness(truth)
+    if backend == BackendName.HORQRUX:
+        wf = jarr_to_tensor(wf)
     assert equivalent_state(wf, truth, atol=ATOL_DICT[backend])
@@ -257,21 +261,28 @@ def test_sample_inversion_for_random_circuit(backend: str, circuit: QuantumCircuit
     bknd = backend_factory(backend=backend)
     (circ, _, embed, params) = bknd.convert(circuit)
     inputs = rand_featureparameters(circuit, 1)
+    if backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v, dtype=jnp.float64) for k, v in inputs.items()}
     for endianness in Endianness:
         samples = bknd.sample(circ, embed(params, inputs), n_shots=100, endianness=endianness)
         for _sample in samples:
-            double_inv_wf = invert_endianness(invert_endianness(_sample))
-            assert js_divergence(double_inv_wf, _sample) < JS_ACCEPTANCE
+            double_inv_sample = invert_endianness(invert_endianness(_sample))
+            assert js_divergence(double_inv_sample, _sample) < JS_ACCEPTANCE


 @given(st.restricted_circuits())
 @settings(deadline=None)
-@pytest.mark.parametrize("backend", BACKENDS)
+@pytest.mark.parametrize("backend", [BackendName.PYQTORCH, BackendName.BRAKET])
 def test_wf_inversion_for_random_circuit(backend: str, circuit: QuantumCircuit) -> None:
     bknd = backend_factory(backend=backend)
     (circ, _, embed, params) = bknd.convert(circuit)
     inputs = rand_featureparameters(circuit, 1)
+    if backend == BackendName.HORQRUX:
+        inputs = {k: tensor_to_jnp(v, dtype=jnp.float64) for k, v in inputs.items()}
     for endianness in Endianness:
         wf = bknd.run(circ, embed(params, inputs), endianness=endianness)
         double_inv_wf = invert_endianness(invert_endianness(wf))
+        if backend == BackendName.HORQRUX:
+            double_inv_wf = jarr_to_tensor(double_inv_wf)
+            wf = jarr_to_tensor(wf)
         assert equivalent_state(double_inv_wf, wf)
diff --git a/tests/backends/test_gpsr.py b/tests/backends/test_gpsr.py
index bcd2df36..778a10b7 100644
--- a/tests/backends/test_gpsr.py
+++ b/tests/backends/test_gpsr.py
@@ -8,11 +8,12 @@
 import torch
 from metrics import GPSR_ACCEPTANCE, PSR_ACCEPTANCE

-from qadence import DifferentiableBackend, DiffMode, Parameter, QuantumCircuit
+from qadence import DiffMode, Parameter, QuantumCircuit
 from qadence.analog import add_background_hamiltonian
 from qadence.backends.pyqtorch import Backend as PyQBackend
 from qadence.blocks import add, chain
 from qadence.constructors import total_magnetization
+from qadence.engines.torch.differentiable_backend import DifferentiableBackend
 from qadence.operations import CNOT, CRX, CRY, RX, RY, ConstantAnalogRotation, HamEvo, X, Y, Z
 from qadence.parameters import ParamMap
 from qadence.register import Register
diff --git a/tests/backends/test_pytorch_wrapper.py b/tests/engines/test_torch.py
similarity index 100%
rename from tests/backends/test_pytorch_wrapper.py
rename to tests/engines/test_torch.py diff --git a/tests/metrics.py b/tests/metrics.py index ad8d49ca..fff42c79 100644 --- a/tests/metrics.py +++ b/tests/metrics.py @@ -15,6 +15,7 @@ PULSER_GPSR_ACCEPTANCE = 6.0e-2 ATOL_DICT = { BackendName.PYQTORCH: ATOL_32, + BackendName.HORQRUX: ATOL_32, BackendName.PULSER: 1e-02, BackendName.BRAKET: 1e-02, }
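Taken together, these changes enable an end-to-end JAX workflow. The sketch below condenses what `test_psr_first_order` above exercises: convert a circuit with the horqrux backend and differentiate the expectation through `jax.value_and_grad`, with either `"ad"` or `"gpsr"` as the differentiation mode:

```python
import jax.numpy as jnp
from jax import Array, value_and_grad

from qadence import QuantumCircuit, Z
from qadence.backends import backend_factory
from qadence.constructors import hea

circ = QuantumCircuit(2, hea(2, 1))
hq_bknd = backend_factory("horqrux", "gpsr")  # or "ad"
hq_circ, hq_obs, hq_fn, hq_params = hq_bknd.convert(circ, Z(0))

# Embed the variational parameters and flatten them into a single array.
embedded_params = hq_fn(hq_params, {})
param_names = embedded_params.keys()
param_array = jnp.concatenate([arr for arr in embedded_params.values()])


def exp_fn(values: Array) -> Array:
    # Rebuild the parameter dict expected by the backend from the flat array.
    return hq_bknd.expectation(hq_circ, hq_obs, dict(zip(param_names, values)))


value, grads = value_and_grad(exp_fn)(param_array)
print(value, grads)
```

With `"gpsr"`, the gradients flow through the `custom_vjp` shift rule defined in `qadence/engines/jax/differentiable_expectation.py` above; with `"ad"`, they come from plain JAX tracing, and the two are expected to agree within tolerance, which is exactly what the new test asserts.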