[Feature, Testing] Allow nested calls to ConcretizedCallable, more tests #275

Merged
26 commits merged on Oct 10, 2024

Commits
c63ae8f  nested embedding (jpmoutinho, Aug 30, 2024)
4122c09  algebraic rules (jpmoutinho, Aug 30, 2024)
1a8b9bb  no embedding in scale (jpmoutinho, Sep 2, 2024)
9c427ba  revert passing ConcretizedCallable as a parameter (jpmoutinho, Sep 2, 2024)
38862c7  restore scales (jpmoutinho, Sep 2, 2024)
fe056b5  random embed helper (jpmoutinho, Sep 3, 2024)
5500adb  fix forward (jpmoutinho, Sep 5, 2024)
cb9864a  revert (jpmoutinho, Sep 5, 2024)
ac7e7a4  fixed HamiltonianEvolution (vytautas-a, Sep 23, 2024)
bf58c2c  nested embedding (jpmoutinho, Aug 30, 2024)
684e429  algebraic rules (jpmoutinho, Aug 30, 2024)
e6e1adf  no embedding in scale (jpmoutinho, Sep 2, 2024)
00664d4  revert passing ConcretizedCallable as a parameter (jpmoutinho, Sep 2, 2024)
899b507  restore scales (jpmoutinho, Sep 2, 2024)
dac0ce1  random embed helper (jpmoutinho, Sep 3, 2024)
036b169  fix forward (jpmoutinho, Sep 5, 2024)
8439656  revert (jpmoutinho, Sep 5, 2024)
866e125  fixed failing tests (vytautas-a, Sep 24, 2024)
66d6f25  added custom differentiable Heaviside function implementation (vytautas-a, Oct 2, 2024)
a1acc3b  refactor Heaviside function (vytautas-a, Oct 2, 2024)
2e9f781  remove unnecessary code (vytautas-a, Oct 3, 2024)
5416ee9  added tests (vytautas-a, Oct 4, 2024)
140654e  add utils tests; code cleanup (vytautas-a, Oct 7, 2024)
633c70a  re-introduced ConcretizedCallable as possible parameter for parametri… (vytautas-a, Oct 10, 2024)
70a5d50  added test for ConcretizedCallable param in parametric gates (vytautas-a, Oct 10, 2024)
4d8c3dc  added parameter type checking for parametrized gates and operations (vytautas-a, Oct 10, 2024)
pyproject.toml (2 changes: 1 addition, 1 deletion)
@@ -117,7 +117,7 @@ serve = "mkdocs serve --dev-addr localhost:8000"
 
 [tool.ruff]
 lint.select = ["E", "F", "I", "Q"]
-lint.extend-ignore = ["F841"]
+lint.extend-ignore = ["F841", "E731"]
 line-length = 100
 
 [tool.ruff.lint.isort]
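For context, ruff's E731 rule flags assigning a lambda expression to a name; adding it to extend-ignore lets such assignments pass lint. A minimal illustration of the newly permitted pattern (hypothetical, not taken from the diff):

# ruff would normally ask for a `def` here; the new E731 ignore allows the lambda form.
square = lambda x: x * x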
pyqtorch/__init__.py (11 changes: 10 additions, 1 deletion)
@@ -55,7 +55,16 @@
     Scale,
     Sequence,
 )
-from .embed import ConcretizedCallable, Embedding
+from .embed import (
+    ConcretizedCallable,
+    Embedding,
+    cos,
+    log,
+    sin,
+    sqrt,
+    tan,
+    tanh,
+)
 from .hamiltonians import HamiltonianEvolution, Observable
 from .noise import (
     AmplitudeDamping,
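With the widened re-export list, the new expression helpers are importable from the package root. A minimal sketch of what this enables (assuming an install of this branch; torch is pyqtorch's default engine):

import torch
from pyqtorch import cos, sin, sqrt

# Build sqrt(sin(x)**2 + cos(x)**2) as a nested ConcretizedCallable tree
# and evaluate it with a dict of named tensor inputs.
expr = sqrt(sin("x") * sin("x") + cos("x") * cos("x"))
print(expr({"x": torch.tensor(0.3)}))  # tensor close to 1.0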
pyqtorch/composite/compose.py (34 changes: 20 additions, 14 deletions)
@@ -10,7 +10,7 @@
 from torch.nn import Module, ModuleList, ParameterDict
 
 from pyqtorch.apply import apply_operator
-from pyqtorch.embed import Embedding
+from pyqtorch.embed import ConcretizedCallable, Embedding
 from pyqtorch.matrices import add_batch_dim
 from pyqtorch.primitives import CNOT, RX, RY, Parametric, Primitive
 from pyqtorch.utils import (
@@ -35,7 +35,9 @@ class Scale(Sequence):
     """
 
     def __init__(
-        self, operations: Union[Primitive, Sequence, Add], param_name: str | Tensor
+        self,
+        operations: Union[Primitive, Sequence, Add],
+        param_name: str | Tensor,
     ):
         """
         Initializes a Scale object.
@@ -69,12 +71,14 @@ def forward(
         if embedding is not None:
             values = embedding(values)
 
-        scale = (
-            values[self.param_name]
-            if isinstance(self.param_name, str)
-            else self.param_name
-        )
-        return scale * self.operations[0].forward(state, values, embedding)
+        if isinstance(self.param_name, str):
+            scale = values[self.param_name]
+        elif isinstance(self.param_name, Tensor):
+            scale = self.param_name
+        elif isinstance(self.param_name, ConcretizedCallable):
+            scale = self.param_name(values)
+
+        return scale * self.operations[0].forward(state, values)
 
     def tensor(
         self,
@@ -97,12 +101,14 @@ def tensor(
         if embedding is not None:
             values = embedding(values)
 
-        scale = (
-            values[self.param_name]
-            if isinstance(self.param_name, str)
-            else self.param_name
-        )
-        return scale * self.operations[0].tensor(values, embedding, full_support)
+        if isinstance(self.param_name, str):
+            scale = values[self.param_name]
+        elif isinstance(self.param_name, (Tensor, int, float)):
+            scale = self.param_name
+        elif isinstance(self.param_name, ConcretizedCallable):
+            scale = self.param_name(values)
+
+        return scale * self.operations[0].tensor(values, full_support=full_support)
 
     def flatten(self) -> list[Scale]:
         """This method should only be called in the AdjointExpectation,
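Scale's forward and tensor now branch on the parameter type, so a ConcretizedCallable can drive the scaling factor directly (the __init__ annotation still reads str | Tensor; only the runtime branches accept the callable). A minimal usage sketch, assuming pyqtorch's top-level X gate and zero_state helper:

import torch
import pyqtorch as pyq
from pyqtorch.embed import ConcretizedCallable

scale_expr = 2.0 * ConcretizedCallable("sin", ["theta"])  # scale = 2 * sin(theta)
scaled_x = pyq.Scale(pyq.X(0), scale_expr)
out = scaled_x(pyq.zero_state(1), {"theta": torch.tensor(0.5)})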
pyqtorch/embed.py (110 changes: 106 additions, 4 deletions)
@@ -5,6 +5,7 @@
 from typing import Any, Tuple
 
 from numpy.typing import ArrayLike, DTypeLike
+from torch import Tensor
 
 logger = getLogger(__name__)
 
@@ -21,7 +22,7 @@
     "sub": ("jax.numpy", "subtract"),
     "div": ("jax.numpy", "divide"),
 }
-DEFAULT_TORCH_MAPPING: dict = dict()
+DEFAULT_TORCH_MAPPING = {"hs": ("pyqtorch.utils", "heaviside")}
 DEFAULT_NUMPY_MAPPING = {
     "mul": ("numpy", "multiply"),
     "sub": ("numpy", "subtract"),
@@ -74,8 +75,8 @@ class ConcretizedCallable:
 
     def __init__(
         self,
-        call_name: str,
-        abstract_args: list[str | float | int],
+        call_name: str = "",
+        abstract_args: list[str | float | int | complex | ConcretizedCallable] = ["x"],
         instruction_mapping: dict[str, Tuple[str, str]] = dict(),
         engine_name: str = "torch",
         device: str = "cpu",
@@ -113,17 +114,94 @@ def __init__(
     def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
         arraylike_args = []
         for symbol_or_numeric in self.abstract_args:
-            if isinstance(symbol_or_numeric, (float, int)):
+            if isinstance(symbol_or_numeric, ConcretizedCallable):
+                arraylike_args.append(symbol_or_numeric(inputs))
+            if isinstance(symbol_or_numeric, (float, int, Tensor)):
                 arraylike_args.append(
                     self.arraylike_fn(symbol_or_numeric, device=self.device)
                 )
             elif isinstance(symbol_or_numeric, str):
                 arraylike_args.append(inputs[symbol_or_numeric])
         return self.engine_call(*arraylike_args)  # type: ignore[misc]
 
+    @classmethod
+    def _get_independent_args(cls, cc: ConcretizedCallable) -> set:
+        out: set = set()
+        if len(cc.abstract_args) == 1 and isinstance(cc.abstract_args[0], str):
+            return set([cc.abstract_args[0]])
+        else:
+            for arg in cc.abstract_args:
+                if isinstance(arg, ConcretizedCallable):
+                    res = cls._get_independent_args(arg)
+                    out = out.union(res)
+                else:
+                    if isinstance(arg, str):
+                        out.add(arg)
+        return out
+
+    @property
+    def independent_args(self) -> list:
+        return list(self._get_independent_args(self))
+
     def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
         return self.evaluate(inputs)
 
+    def __mul__(
+        self, other: str | int | float | complex | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("mul", [self, other])
+
+    def __rmul__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("mul", [other, self])
+
+    def __add__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("add", [self, other])
+
+    def __radd__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("add", [other, self])
+
+    def __sub__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("sub", [self, other])
+
+    def __rsub__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("sub", [other, self])
+
+    def __pow__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("pow", [self, other])
+
+    def __rpow__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("pow", [other, self])
+
+    def __truediv__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("div", [self, other])
+
+    def __rtruediv__(
+        self, other: str | int | float | ConcretizedCallable
+    ) -> ConcretizedCallable:
+        return ConcretizedCallable("div", [other, self])
+
     def __repr__(self) -> str:
         return f"{self.call_name}({self.abstract_args})"
 
+    def __neg__(self) -> ConcretizedCallable:
+        return -1 * self
+
     @property
     def device(self) -> str:
         return self._device
@@ -150,6 +228,30 @@ def init_param(
         return engine.random.uniform(0, 1)
 
 
+def sin(x: str | ConcretizedCallable):
+    return ConcretizedCallable("sin", [x])
+
+
+def cos(x: str | ConcretizedCallable):
+    return ConcretizedCallable("cos", [x])
+
+
+def log(x: str | ConcretizedCallable):
+    return ConcretizedCallable("log", [x])
+
+
+def tan(x: str | ConcretizedCallable):
+    return ConcretizedCallable("tan", [x])
+
+
+def tanh(x: str | ConcretizedCallable):
+    return ConcretizedCallable("tanh", [x])
+
+
+def sqrt(x: str | ConcretizedCallable):
+    return ConcretizedCallable("sqrt", [x])
+
+
 class Embedding:
     """A class relating variational and feature parameters used in ConcretizedCallable instances to
     parameter names used in gates.
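Together, these additions make ConcretizedCallable closed under arithmetic: the dunder methods build a nested expression tree, evaluate() recurses into ConcretizedCallable arguments, and independent_args walks the tree to collect the leaf parameter names. A rough behavioural sketch (the ordering of independent_args is unspecified since it comes from a set; the Heaviside value for negative inputs is whatever pyqtorch.utils.heaviside returns, approximately 0):

import torch
from pyqtorch.embed import ConcretizedCallable, sin

# -(sin(x) / y) + 1, built purely from the new operator overloads.
expr = -(sin("x") / "y") + 1
print(sorted(expr.independent_args))  # ['x', 'y']
print(expr({"x": torch.tensor(0.5), "y": torch.tensor(2.0)}))

# "hs" now resolves to the differentiable Heaviside added in this PR.
step = ConcretizedCallable("hs", ["x"])
print(step({"x": torch.tensor(-0.1)}))  # approximately 0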
pyqtorch/hamiltonians/evolution.py (82 changes: 62 additions, 20 deletions)
@@ -12,7 +12,8 @@
 
 from pyqtorch.apply import apply_operator
 from pyqtorch.circuit import Sequence
-from pyqtorch.embed import Embedding
+from pyqtorch.composite import Scale
+from pyqtorch.embed import ConcretizedCallable, Embedding
 from pyqtorch.primitives import Primitive
 from pyqtorch.quantum_operation import QuantumOperation
 from pyqtorch.time_dependent.sesolve import sesolve
@@ -142,9 +143,10 @@ class HamiltonianEvolution(Sequence):
     def __init__(
         self,
         generator: TGenerator,
-        time: Tensor | str,
+        time: Tensor | str | ConcretizedCallable,
         qubit_support: Tuple[int, ...] | None = None,
         cache_length: int = 1,
+        duration: float | Tensor = 1.0,
         steps: int = 100,
         solver=SolverType.DP5_SE,
     ):
@@ -160,6 +162,19 @@ def __init__(
 
         self.solver_type = solver
         self.steps = steps
+        self.duration = duration
+        self.is_time_dependent = None
+
+        if (
+            isinstance(time, str)
+            or isinstance(time, Tensor)
+            or isinstance(time, ConcretizedCallable)
+        ):
+            self.time = time
+        else:
+            raise ValueError("time should be passed as str or tensor.")
+
+        self.has_time_param = self._has_time_param(generator)
 
         if isinstance(generator, Tensor):
             if qubit_support is None:
@@ -185,6 +200,7 @@ def __init__(
                     "Taking support from generator and ignoring qubit_support input."
                 )
             qubit_support = generator.qubit_support
+
             if is_parametric(generator):
                 generator = [generator]
                 self.generator_type = GeneratorType.PARAMETRIC_OPERATION
@@ -205,11 +221,6 @@ def __init__(
         super().__init__(generator)
         self._qubit_support = qubit_support  # type: ignore
 
-        if isinstance(time, str) or isinstance(time, Tensor):
-            self.time = time
-        else:
-            raise ValueError("time should be passed as str or tensor.")
-
         logger.debug("Hamiltonian Evolution initialized")
         if logger.isEnabledFor(logging.DEBUG):
             # When Debugging let's add logging and NVTX markers
@@ -246,6 +257,24 @@ def flatten(self) -> ModuleList:
     def param_name(self) -> Tensor | str:
         return self.time
 
+    def _has_time_param(self, generator: TGenerator) -> bool:
+        from pyqtorch.primitives import Parametric
+
+        res = False
+        if isinstance(self.time, Tensor):
+            return res
+        else:
+            if isinstance(generator, (Sequence, QuantumOperation)):
+                for m in generator.modules():
+                    if isinstance(m, (Scale, Parametric)):
+                        if self.time in getattr(m.param_name, "independent_args", []):
+                            # param_name is a ConcretizedCallable object
+                            res = True
+                        elif m.param_name == self.time:
+                            # param_name is a string
+                            res = True
+        return res
+
     def _symbolic_generator(
         self,
         values: dict,
@@ -341,20 +370,27 @@ def _forward_time(
     ) -> State:
         n_qubits = len(state.shape) - 1
         batch_size = state.shape[-1]
-        t_grid = torch.linspace(0, float(self.time), self.steps)
+        t_grid = torch.linspace(0, float(self.duration), self.steps)
 
-        values.update({embedding.tparam_name: torch.tensor(0.0)})  # type: ignore [dict-item]
-        embedded_params = embedding(values)
+        if embedding is not None:
+            values.update({embedding.tparam_name: torch.tensor(0.0)})  # type: ignore [dict-item]
+            embedded_params = embedding(values)
+        else:
+            embedded_params = values
 
         def Ht(t: torch.Tensor) -> torch.Tensor:
             """Accepts a value 't' for time and returns
             a (2**n_qubits, 2**n_qubits) Hamiltonian evaluated at time 't'.
             """
-            # We use the origial embedded params and return a new dict
+            # We use the original embedded params and return a new dict
             # where we reembedded all parameters depending on time with value 't'
-            reembedded_time_values = embedding.reembed_tparam(
-                embedded_params, torch.as_tensor(t)
-            )
+            if embedding is not None:
+                reembedded_time_values = embedding.reembed_tparam(
+                    embedded_params, torch.as_tensor(t)
+                )
+            else:
+                values[self.time] = torch.as_tensor(t)
+                reembedded_time_values = values
             return (
                 self.generator[0].tensor(reembedded_time_values, embedding).squeeze(2)
             )
@@ -388,8 +424,10 @@ def forward(
         Returns:
             The transformed state.
         """
-        if embedding is not None and getattr(embedding, "tparam_name", None):
-            return self._forward_time(state, values, embedding)
+        if self.has_time_param or (
+            embedding is not None and getattr(embedding, "tparam_name", None)
+        ):
+            return self._forward_time(state, values, embedding)  # type: ignore [arg-type]
 
         else:
             return self._forward(state, values, embedding)
@@ -420,10 +458,14 @@ def tensor(
             evolved_op = self._cache_hamiltonian_evo[values_cache_key]
         else:
             hamiltonian: torch.Tensor = self.create_hamiltonian(values, embedding)  # type: ignore [call-arg]
-            time_evolution: torch.Tensor = (
-                values[self.time] if isinstance(self.time, str) else self.time
-            )  # If `self.time` is a string / hence, a Parameter,
-            # we expect the user to pass it in the `values` dict
+
+            if isinstance(self.time, str):
+                time_evolution = values[self.time]
+            elif isinstance(self.time, ConcretizedCallable):
+                time_evolution = self.time(values)
+            else:
+                time_evolution = self.time
+
             evolved_op = evolve(hamiltonian, time_evolution)
             nb_cached = len(self._cache_hamiltonian_evo)
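HamiltonianEvolution now also accepts a ConcretizedCallable for time, resolving it from the values dict as the tensor() branch above does. A minimal sketch (the generator and parameter name are illustrative):

import torch
import pyqtorch as pyq
from pyqtorch.embed import ConcretizedCallable

t_expr = 0.5 * ConcretizedCallable("sin", ["omega"])  # evolution time = 0.5 * sin(omega)
hamevo = pyq.HamiltonianEvolution(pyq.Z(0), t_expr)
unitary = hamevo.tensor({"omega": torch.tensor(1.0)})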