Commit cf5b3fb (1 parent: dd4e621). Showing 4 changed files with 162 additions and 0 deletions.
@@ -0,0 +1,8 @@
# flake8: noqa F401
from __future__ import annotations

from .api import backend_factory, config_factory
from .pytorch_wrapper import DifferentiableBackend, DiffMode

# Modules to be automatically added to the qadence namespace
__all__ = ["backend_factory", "config_factory", "DifferentiableBackend", "DiffMode"]
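These re-exports suggest this file is the `__init__` of the backends subpackage. A minimal import sketch, assuming the package path is `qadence.backends` (inferred from the relative imports and the namespace comment, not confirmed by this diff):

# Sketch only: `qadence.backends` is an inferred path, not confirmed by this diff.
from qadence.backends import backend_factory, config_factory, DifferentiableBackend, DiffMode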
@@ -0,0 +1,62 @@
from __future__ import annotations

from qadence.backend import Backend, BackendConfiguration
from qadence.backends.pytorch_wrapper import DifferentiableBackend, DiffMode
from qadence.extensions import available_backends, set_backend_config
from qadence.types import BackendName

__all__ = ["backend_factory", "config_factory"]


def backend_factory(
    backend: BackendName | str,
    diff_mode: DiffMode | str | None = None,
    configuration: BackendConfiguration | dict | None = None,
) -> Backend | DifferentiableBackend:
    backend_inst: Backend | DifferentiableBackend
    backend_name = BackendName(backend)
    backends = available_backends()

    try:
        BackendCls = backends[backend_name]
    except (KeyError, ValueError):
        raise NotImplementedError(f"The requested backend '{backend_name}' is not implemented.")

    default_config = BackendCls.default_configuration()
    if configuration is None:
        configuration = default_config
    elif isinstance(configuration, dict):
        configuration = config_factory(backend_name, configuration)
    else:
        # NOTE: types have to match exactly, hence we use `type`
        if not isinstance(configuration, type(BackendCls.default_configuration())):
            raise ValueError(
                f"Given config class '{type(configuration)}' does not match the backend"
                f" class: '{BackendCls}'. Expected: '{type(BackendCls.default_configuration())}.'"
            )

    # Create the backend
    backend_inst = BackendCls(
        config=configuration
        if configuration is not None
        else BackendCls.default_configuration()  # type: ignore[attr-defined]
    )

    # Set backend configurations which depend on the differentiation mode
    set_backend_config(backend_inst, diff_mode)

    if diff_mode is not None:
        backend_inst = DifferentiableBackend(backend_inst, DiffMode(diff_mode))
    return backend_inst


def config_factory(name: BackendName | str, config: dict) -> BackendConfiguration:
    backends = available_backends()

    try:
        BackendCls = backends[BackendName(name)]
    except KeyError:
        raise NotImplementedError(f"The requested backend '{name}' is not implemented!")

    BackendConfigCls = type(BackendCls.default_configuration())
    return BackendConfigCls(**config)  # type: ignore[no-any-return]
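To make the factory flow concrete, a hedged usage sketch. The module path `qadence.backends.api`, the backend name "pyqtorch", and the diff-mode value "ad" are assumptions about the surrounding library, not facts established by this diff:

# Usage sketch; the backend name and diff mode are assumed, swap in any registered ones.
from qadence.backends.api import backend_factory, config_factory

config = config_factory("pyqtorch", {})  # dict -> backend-specific BackendConfiguration
backend = backend_factory("pyqtorch", diff_mode="ad", configuration=config)
# With diff_mode given, the result is a DifferentiableBackend wrapping the raw backend;
# with diff_mode=None it would be the raw Backend instance itself.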
Empty file.
@@ -0,0 +1,92 @@
from __future__ import annotations

from collections import Counter
from typing import Sequence

import numpy as np
import torch
from torch import Tensor

from qadence.utils import Endianness, int_to_basis

# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
    int: torch.int64,
    float: torch.float64,
    complex: torch.complex128,
}


def param_dict(keys: Sequence[str], values: Sequence[Tensor]) -> dict[str, Tensor]:
    return {key: val for key, val in zip(keys, values)}


def numpy_to_tensor(
    x: np.ndarray,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.complex128,
    requires_grad: bool = False,
) -> Tensor:
    """This only copies the numpy array if device or dtype are different than the ones of x."""
    return torch.as_tensor(x, dtype=dtype, device=device).requires_grad_(requires_grad)


def promote_to_tensor(
    x: Tensor | np.ndarray | float,
    dtype: torch.dtype = torch.complex128,
    requires_grad: bool = True,
) -> Tensor:
"""Convert the given type inco a torch.Tensor""" | ||
    if isinstance(x, float):
        return torch.tensor([[x]], dtype=dtype, requires_grad=requires_grad)
    elif isinstance(x, np.ndarray):
        return numpy_to_tensor(
            x, dtype=numpy_to_torch_dtype_dict.get(x.dtype), requires_grad=requires_grad
        )
    elif isinstance(x, Tensor):
        return x.requires_grad_(requires_grad)
    else:
        raise ValueError(f"Don't know how to promote {type(x)} to Tensor")


# FIXME: Not being used, maybe remove in v1.0.0
def count_bitstrings(sample: Tensor, endianness: Endianness = Endianness.BIG) -> Counter:
    # Convert to a tensor of integers.
    n_qubits = sample.size()[1]
    base = torch.ones(n_qubits, dtype=torch.int64) * 2
    powers_of_2 = torch.pow(base, reversed(torch.arange(n_qubits)))
    int_tensor = torch.matmul(sample, powers_of_2)
    # Count occurrences of integers.
    count_int = torch.bincount(int_tensor)
    # Return a Counter for non-empty bitstring counts.
    return Counter(
        {
            int_to_basis(k=k, n_qubits=n_qubits, endianness=endianness): count.item()
            for k, count in enumerate(count_int)
            if count > 0
        }
    )


def to_list_of_dicts(param_values: dict[str, Tensor]) -> list[dict[str, float]]:
    if not param_values:
        return [param_values]

    max_batch_size = max(p.size()[0] for p in param_values.values())
    batched_values = {
        k: (v if v.size()[0] == max_batch_size else v.repeat(max_batch_size, 1))
        for k, v in param_values.items()
    }

    return [{k: v[i] for k, v in batched_values.items()} for i in range(max_batch_size)]
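A short sketch of how these helpers compose, assuming the module path `qadence.backends.utils` (inferred from the other files in this commit, not shown here); the values are purely illustrative:

import numpy as np
import torch

from qadence.backends.utils import param_dict, promote_to_tensor, to_list_of_dicts  # assumed path

# Scalars, arrays and tensors are all promoted to tensors with gradients enabled by default.
theta = promote_to_tensor(0.5)                 # shape (1, 1), complex128, requires_grad=True
phi = promote_to_tensor(np.array([0.1, 0.2]))  # dtype looked up in numpy_to_torch_dtype_dict

# Batched parameter values are split into one dict per batch entry.
values = param_dict(["theta"], [torch.rand(4, 1)])
per_sample = to_list_of_dicts(values)  # list of 4 dicts, each holding one slice of "theta"
assert len(per_sample) == 4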