[Feature, Testing] Allow nested calls to ConcretizedCallable, more tests #275

Merged: 26 commits, Oct 10, 2024

Commits:
c63ae8f  nested embedding (jpmoutinho, Aug 30, 2024)
4122c09  algebraic rules (jpmoutinho, Aug 30, 2024)
1a8b9bb  no embedding in scale (jpmoutinho, Sep 2, 2024)
9c427ba  revert passing ConcretizedCallable as a parameter (jpmoutinho, Sep 2, 2024)
38862c7  restore scales (jpmoutinho, Sep 2, 2024)
fe056b5  random embed helper (jpmoutinho, Sep 3, 2024)
5500adb  fix forward (jpmoutinho, Sep 5, 2024)
cb9864a  revert (jpmoutinho, Sep 5, 2024)
ac7e7a4  fixed HamiltonianEvolution (vytautas-a, Sep 23, 2024)
bf58c2c  nested embedding (jpmoutinho, Aug 30, 2024)
684e429  algebraic rules (jpmoutinho, Aug 30, 2024)
e6e1adf  no embedding in scale (jpmoutinho, Sep 2, 2024)
00664d4  revert passing ConcretizedCallable as a parameter (jpmoutinho, Sep 2, 2024)
899b507  restore scales (jpmoutinho, Sep 2, 2024)
dac0ce1  random embed helper (jpmoutinho, Sep 3, 2024)
036b169  fix forward (jpmoutinho, Sep 5, 2024)
8439656  revert (jpmoutinho, Sep 5, 2024)
866e125  fixed failing tests (vytautas-a, Sep 24, 2024)
66d6f25  added custom differentiable Heaviside function implementation (vytautas-a, Oct 2, 2024)
a1acc3b  refactor Heaviside function (vytautas-a, Oct 2, 2024)
2e9f781  remove unnecessary code (vytautas-a, Oct 3, 2024)
5416ee9  added tests (vytautas-a, Oct 4, 2024)
140654e  add utils tests; code cleanup (vytautas-a, Oct 7, 2024)
633c70a  re-introduced ConcretizedCallable as possible parameter for parametri… (vytautas-a, Oct 10, 2024)
70a5d50  added test for ConcretizedCallable param in parametric gates (vytautas-a, Oct 10, 2024)
4d8c3dc  added parameter type checking for parametrized gates and operations (vytautas-a, Oct 10, 2024)
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "hatchling.build"
 name = "pyqtorch"
 description = "An efficient, large-scale emulator designed for quantum machine learning, seamlessly integrated with a PyTorch backend. Please refer to https://pyqtorch.readthedocs.io/en/latest/ for setup and usage info, along with the full documentation."
 readme = "README.md"
-version = "1.4.9"
+version = "1.5.0"
 requires-python = ">=3.8,<3.13"
 license = { text = "Apache 2.0" }
 keywords = ["quantum"]
@@ -117,7 +117,7 @@ serve = "mkdocs serve --dev-addr localhost:8000"

 [tool.ruff]
 lint.select = ["E", "F", "I", "Q"]
-lint.extend-ignore = ["F841"]
+lint.extend-ignore = ["F841", "E731"]
 line-length = 100

 [tool.ruff.lint.isort]
11 changes: 10 additions & 1 deletion pyqtorch/__init__.py
@@ -55,7 +55,16 @@
     Scale,
     Sequence,
 )
-from .embed import ConcretizedCallable, Embedding
+from .embed import (
+    ConcretizedCallable,
+    Embedding,
+    cos,
+    log,
+    sin,
+    sqrt,
+    tan,
+    tanh,
+)
 from .hamiltonians import HamiltonianEvolution, Observable
 from .noise import (
     AmplitudeDamping,
41 changes: 26 additions & 15 deletions pyqtorch/composite/compose.py
@@ -10,7 +10,7 @@
 from torch.nn import Module, ModuleList, ParameterDict

 from pyqtorch.apply import apply_operator
-from pyqtorch.embed import Embedding
+from pyqtorch.embed import ConcretizedCallable, Embedding
 from pyqtorch.matrices import add_batch_dim
 from pyqtorch.primitives import CNOT, RX, RY, Parametric, Primitive
 from pyqtorch.utils import (
@@ -35,7 +35,9 @@ class Scale(Sequence):
     """

     def __init__(
-        self, operations: Union[Primitive, Sequence, Add], param_name: str | Tensor
+        self,
+        operations: Union[Primitive, Sequence, Add],
+        param_name: str | float | int | Tensor | ConcretizedCallable,
     ):
         """
         Initializes a Scale object.
@@ -46,6 +48,11 @@ def __init__(
         """
         if not isinstance(operations, (Primitive, Sequence, Add)):
             raise ValueError("Scale only supports a single operation, Sequence or Add.")
+        if not isinstance(param_name, (str, int, float, Tensor, ConcretizedCallable)):
+            raise TypeError(
+                "Only str, int, float, Tensor or ConcretizedCallable types \
+                are supported for param_name"
+            )
         super().__init__([operations])
         self.param_name = param_name

@@ -69,12 +76,14 @@ def forward(
         if embedding is not None:
             values = embedding(values)

-        scale = (
-            values[self.param_name]
-            if isinstance(self.param_name, str)
-            else self.param_name
-        )
-        return scale * self.operations[0].forward(state, values, embedding)
+        if isinstance(self.param_name, str):
+            scale = values[self.param_name]
+        elif isinstance(self.param_name, (Tensor, int, float)):
+            scale = self.param_name
+        elif isinstance(self.param_name, ConcretizedCallable):
+            scale = self.param_name(values)
+
+        return scale * self.operations[0].forward(state, values)

     def tensor(
         self,
@@ -97,12 +106,14 @@ def tensor(
         if embedding is not None:
             values = embedding(values)

-        scale = (
-            values[self.param_name]
-            if isinstance(self.param_name, str)
-            else self.param_name
-        )
-        return scale * self.operations[0].tensor(values, embedding, full_support)
+        if isinstance(self.param_name, str):
+            scale = values[self.param_name]
+        elif isinstance(self.param_name, (Tensor, int, float)):
+            scale = self.param_name
+        elif isinstance(self.param_name, ConcretizedCallable):
+            scale = self.param_name(values)
+
+        return scale * self.operations[0].tensor(values, full_support=full_support)

     def flatten(self) -> list[Scale]:
         """This method should only be called in the AdjointExpectation,
@@ -121,7 +132,7 @@ def to(self, *args: Any, **kwargs: Any) -> Scale:
             Converted Scale.
         """
         super().to(*args, **kwargs)
-        if not isinstance(self.param_name, str):
+        if not isinstance(self.param_name, (str, float, int)):
             self.param_name = self.param_name.to(*args, **kwargs)

         return self
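
With param_name widened as above, a Scale can be driven by an expression rather than a bare parameter name; per the final commits, parametric gates accept the same. A usage sketch, assuming pyq.X, pyq.RX and pyq.zero_state keep their usual pyqtorch signatures:

import torch

import pyqtorch as pyq
from pyqtorch.embed import sin

# Scale X(0) by sin(theta); forward() evaluates the expression against `values`.
scaled = pyq.Scale(pyq.X(0), sin("theta"))
out = scaled(pyq.zero_state(1), {"theta": torch.tensor(0.5)})

# The same kind of expression can parametrize a rotation gate directly.
rx = pyq.RX(0, sin("theta"))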
4 changes: 2 additions & 2 deletions pyqtorch/differentiation/adjoint.py
@@ -130,8 +130,8 @@ def backward(ctx: Any, grad_out: Tensor) -> Tuple[None, ...]:
                 grad_out * 2 * inner_prod(ctx.projected_state, mu).real
             )

-            if values[op.param_name].requires_grad:
-                grads_dict[op.param_name] = grad_out * 2 * -values[op.param_name]
+            if values[op.param_name].requires_grad:  # type: ignore [index]
+                grads_dict[op.param_name] = grad_out * 2 * -values[op.param_name]  # type: ignore [index]
             ctx.projected_state = apply_operator(
                 ctx.projected_state,
                 op.dagger(values, ctx.embedding),
120 changes: 116 additions & 4 deletions pyqtorch/embed.py
@@ -5,6 +5,7 @@
 from typing import Any, Tuple

 from numpy.typing import ArrayLike, DTypeLike
+from torch import Tensor

 logger = getLogger(__name__)

@@ -21,7 +22,7 @@
     "sub": ("jax.numpy", "subtract"),
     "div": ("jax.numpy", "divide"),
 }
-DEFAULT_TORCH_MAPPING: dict = dict()
+DEFAULT_TORCH_MAPPING = {"hs": ("pyqtorch.utils", "heaviside")}
 DEFAULT_NUMPY_MAPPING = {
     "mul": ("numpy", "multiply"),
     "sub": ("numpy", "subtract"),
@@ -74,8 +75,8 @@ class ConcretizedCallable:

     def __init__(
         self,
-        call_name: str,
-        abstract_args: list[str | float | int],
+        call_name: str = "",
+        abstract_args: list[str | float | int | complex | ConcretizedCallable] = ["x"],
         instruction_mapping: dict[str, Tuple[str, str]] = dict(),
         engine_name: str = "torch",
         device: str = "cpu",
@@ -92,6 +93,16 @@ def __init__(
         self._dtype = dtype
         self.engine_call = None
         engine = None
+        if not all(
+            [
+                isinstance(arg, (str, float, int, complex, Tensor, ConcretizedCallable))
+                for arg in abstract_args
+            ]
+        ):
+            raise TypeError(
+                "Only str, float, int, complex, Tensor or ConcretizedCallable type elements \
+                are supported for abstract_args"
+            )
         try:
             engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name]
             engine = import_module(engine_name)
@@ -113,17 +124,94 @@ def __init__(
     def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
         arraylike_args = []
         for symbol_or_numeric in self.abstract_args:
-            if isinstance(symbol_or_numeric, (float, int)):
+            if isinstance(symbol_or_numeric, ConcretizedCallable):
+                arraylike_args.append(symbol_or_numeric(inputs))
+            if isinstance(symbol_or_numeric, (float, int, Tensor)):
                 arraylike_args.append(
                     self.arraylike_fn(symbol_or_numeric, device=self.device)
                 )
             elif isinstance(symbol_or_numeric, str):
                 arraylike_args.append(inputs[symbol_or_numeric])
         return self.engine_call(*arraylike_args)  # type: ignore[misc]

+    @classmethod
+    def _get_independent_args(cls, cc: ConcretizedCallable) -> set:
+        out: set = set()
+        if len(cc.abstract_args) == 1 and isinstance(cc.abstract_args[0], str):
+            return set([cc.abstract_args[0]])
+        else:
+            for arg in cc.abstract_args:
+                if isinstance(arg, ConcretizedCallable):
+                    res = cls._get_independent_args(arg)
+                    out = out.union(res)
+                else:
+                    if isinstance(arg, str):
+                        out.add(arg)
+        return out
+
+    @property
+    def independent_args(self) -> list:
+        return list(self._get_independent_args(self))
+
     def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
         return self.evaluate(inputs)

+    def __mul__(
+        self, other: str | int | float | complex | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("mul", [self, other])
+
+    def __rmul__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("mul", [other, self])
+
+    def __add__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("add", [self, other])
+
+    def __radd__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("add", [other, self])
+
+    def __sub__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("sub", [self, other])
+
+    def __rsub__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("sub", [other, self])
+
+    def __pow__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("pow", [self, other])
+
+    def __rpow__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("pow", [other, self])
+
+    def __truediv__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("div", [self, other])
+
+    def __rtruediv__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("div", [other, self])
+
     def __repr__(self) -> str:
         return f"{self.call_name}({self.abstract_args})"

+    def __neg__(self) -> ConcretizedCallable:
+        return -1 * self
+
     @property
     def device(self) -> str:
         return self._device
@@ -150,6 +238,30 @@ def init_param(
     return engine.random.uniform(0, 1)


+def sin(x: str | ConcretizedCallable):
+    return ConcretizedCallable("sin", [x])
+
+
+def cos(x: str | ConcretizedCallable):
+    return ConcretizedCallable("cos", [x])
+
+
+def log(x: str | ConcretizedCallable):
+    return ConcretizedCallable("log", [x])
+
+
+def tan(x: str | ConcretizedCallable):
+    return ConcretizedCallable("tan", [x])
+
+
+def tanh(x: str | ConcretizedCallable):
+    return ConcretizedCallable("tanh", [x])
+
+
+def sqrt(x: str | ConcretizedCallable):
+    return ConcretizedCallable("sqrt", [x])
+
+
 class Embedding:
     """A class relating variational and feature parameters used in ConcretizedCallable instances to
     parameter names used in gates.
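
Together, the operator overloads and the helper constructors above form a small expression language over ConcretizedCallable. A usage sketch (torch engine assumed; note independent_args is derived from a set, so its ordering is not guaranteed):

import torch

from pyqtorch.embed import cos, sin

# sin(x) * cos(y) + 0.5, built entirely from the overloads;
# every operator application returns a new nested ConcretizedCallable.
expr = sin("x") * cos("y") + 0.5

print(expr.independent_args)  # ["x", "y"], in some order
print(expr({"x": torch.tensor(0.2), "y": torch.tensor(1.0)}))  # approx. tensor(0.6073)

The new "hs" entry in DEFAULT_TORCH_MAPPING routes the same mechanism to the custom differentiable Heaviside added to pyqtorch.utils in this PR.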