diff --git a/pyproject.toml b/pyproject.toml index a0f94737..1e7c2f37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "hatchling.build" name = "pyqtorch" description = "An efficient, large-scale emulator designed for quantum machine learning, seamlessly integrated with a PyTorch backend. Please refer to https://pyqtorch.readthedocs.io/en/latest/ for setup and usage info, along with the full documentation." readme = "README.md" -version = "1.4.9" +version = "1.5.0" requires-python = ">=3.8,<3.13" license = { text = "Apache 2.0" } keywords = ["quantum"] @@ -117,7 +117,7 @@ serve = "mkdocs serve --dev-addr localhost:8000" [tool.ruff] lint.select = ["E", "F", "I", "Q"] -lint.extend-ignore = ["F841"] +lint.extend-ignore = ["F841", "E731"] line-length = 100 [tool.ruff.lint.isort] diff --git a/pyqtorch/__init__.py b/pyqtorch/__init__.py index 083b18d7..8945aff1 100644 --- a/pyqtorch/__init__.py +++ b/pyqtorch/__init__.py @@ -55,7 +55,16 @@ Scale, Sequence, ) -from .embed import ConcretizedCallable, Embedding +from .embed import ( + ConcretizedCallable, + Embedding, + cos, + log, + sin, + sqrt, + tan, + tanh, +) from .hamiltonians import HamiltonianEvolution, Observable from .noise import ( AmplitudeDamping, diff --git a/pyqtorch/composite/compose.py b/pyqtorch/composite/compose.py index 7d4ce84a..b14f6cf1 100644 --- a/pyqtorch/composite/compose.py +++ b/pyqtorch/composite/compose.py @@ -10,7 +10,7 @@ from torch.nn import Module, ModuleList, ParameterDict from pyqtorch.apply import apply_operator -from pyqtorch.embed import Embedding +from pyqtorch.embed import ConcretizedCallable, Embedding from pyqtorch.matrices import add_batch_dim from pyqtorch.primitives import CNOT, RX, RY, Parametric, Primitive from pyqtorch.utils import ( @@ -35,7 +35,9 @@ class Scale(Sequence): """ def __init__( - self, operations: Union[Primitive, Sequence, Add], param_name: str | Tensor + self, + operations: Union[Primitive, Sequence, Add], + param_name: str | float | int | Tensor | ConcretizedCallable, ): """ Initializes a Scale object. @@ -46,6 +48,11 @@ def __init__( """ if not isinstance(operations, (Primitive, Sequence, Add)): raise ValueError("Scale only supports a single operation, Sequence or Add.") + if not isinstance(param_name, (str, int, float, Tensor, ConcretizedCallable)): + raise TypeError( + "Only str, int, float, Tensor or ConcretizedCallable types \ + are supported for param_name" + ) super().__init__([operations]) self.param_name = param_name @@ -69,12 +76,14 @@ def forward( if embedding is not None: values = embedding(values) - scale = ( - values[self.param_name] - if isinstance(self.param_name, str) - else self.param_name - ) - return scale * self.operations[0].forward(state, values, embedding) + if isinstance(self.param_name, str): + scale = values[self.param_name] + elif isinstance(self.param_name, Tensor): + scale = self.param_name + elif isinstance(self.param_name, ConcretizedCallable): + scale = self.param_name(values) + + return scale * self.operations[0].forward(state, values) def tensor( self, @@ -97,12 +106,14 @@ def tensor( if embedding is not None: values = embedding(values) - scale = ( - values[self.param_name] - if isinstance(self.param_name, str) - else self.param_name - ) - return scale * self.operations[0].tensor(values, embedding, full_support) + if isinstance(self.param_name, str): + scale = values[self.param_name] + elif isinstance(self.param_name, (Tensor, int, float)): + scale = self.param_name + elif isinstance(self.param_name, ConcretizedCallable): + scale = self.param_name(values) + + return scale * self.operations[0].tensor(values, full_support=full_support) def flatten(self) -> list[Scale]: """This method should only be called in the AdjointExpectation, @@ -121,7 +132,7 @@ def to(self, *args: Any, **kwargs: Any) -> Scale: Converted Scale. """ super().to(*args, **kwargs) - if not isinstance(self.param_name, str): + if not isinstance(self.param_name, (str, float, int)): self.param_name = self.param_name.to(*args, **kwargs) return self diff --git a/pyqtorch/differentiation/adjoint.py b/pyqtorch/differentiation/adjoint.py index fcad2fbc..343307e0 100644 --- a/pyqtorch/differentiation/adjoint.py +++ b/pyqtorch/differentiation/adjoint.py @@ -130,8 +130,8 @@ def backward(ctx: Any, grad_out: Tensor) -> Tuple[None, ...]: grad_out * 2 * inner_prod(ctx.projected_state, mu).real ) - if values[op.param_name].requires_grad: - grads_dict[op.param_name] = grad_out * 2 * -values[op.param_name] + if values[op.param_name].requires_grad: # type: ignore [index] + grads_dict[op.param_name] = grad_out * 2 * -values[op.param_name] # type: ignore [index] ctx.projected_state = apply_operator( ctx.projected_state, op.dagger(values, ctx.embedding), diff --git a/pyqtorch/embed.py b/pyqtorch/embed.py index 80f7174a..95e2685f 100644 --- a/pyqtorch/embed.py +++ b/pyqtorch/embed.py @@ -5,6 +5,7 @@ from typing import Any, Tuple from numpy.typing import ArrayLike, DTypeLike +from torch import Tensor logger = getLogger(__name__) @@ -21,7 +22,7 @@ "sub": ("jax.numpy", "subtract"), "div": ("jax.numpy", "divide"), } -DEFAULT_TORCH_MAPPING: dict = dict() +DEFAULT_TORCH_MAPPING = {"hs": ("pyqtorch.utils", "heaviside")} DEFAULT_NUMPY_MAPPING = { "mul": ("numpy", "multiply"), "sub": ("numpy", "subtract"), @@ -74,8 +75,8 @@ class ConcretizedCallable: def __init__( self, - call_name: str, - abstract_args: list[str | float | int], + call_name: str = "", + abstract_args: list[str | float | int | complex | ConcretizedCallable] = ["x"], instruction_mapping: dict[str, Tuple[str, str]] = dict(), engine_name: str = "torch", device: str = "cpu", @@ -92,6 +93,16 @@ def __init__( self._dtype = dtype self.engine_call = None engine = None + if not all( + [ + isinstance(arg, (str, float, int, complex, Tensor, ConcretizedCallable)) + for arg in abstract_args + ] + ): + raise TypeError( + "Only str, float, int, complex, Tensor or ConcretizedCallable type elements \ + are supported for abstract_args" + ) try: engine_name, fn_name = ARRAYLIKE_FN_MAP[engine_name] engine = import_module(engine_name) @@ -113,7 +124,9 @@ def __init__( def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: arraylike_args = [] for symbol_or_numeric in self.abstract_args: - if isinstance(symbol_or_numeric, (float, int)): + if isinstance(symbol_or_numeric, ConcretizedCallable): + arraylike_args.append(symbol_or_numeric(inputs)) + if isinstance(symbol_or_numeric, (float, int, Tensor)): arraylike_args.append( self.arraylike_fn(symbol_or_numeric, device=self.device) ) @@ -121,9 +134,84 @@ def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: arraylike_args.append(inputs[symbol_or_numeric]) return self.engine_call(*arraylike_args) # type: ignore[misc] + @classmethod + def _get_independent_args(cls, cc: ConcretizedCallable) -> set: + out: set = set() + if len(cc.abstract_args) == 1 and isinstance(cc.abstract_args[0], str): + return set([cc.abstract_args[0]]) + else: + for arg in cc.abstract_args: + if isinstance(arg, ConcretizedCallable): + res = cls._get_independent_args(arg) + out = out.union(res) + else: + if isinstance(arg, str): + out.add(arg) + return out + + @property + def independent_args(self) -> list: + return list(self._get_independent_args(self)) + def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: return self.evaluate(inputs) + def __mul__( + self, other: str | int | float | complex | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("mul", [self, other]) + + def __rmul__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("mul", [other, self]) + + def __add__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("add", [self, other]) + + def __radd__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("add", [other, self]) + + def __sub__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("sub", [self, other]) + + def __rsub__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("sub", [other, self]) + + def __pow__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("pow", [self, other]) + + def __rpow__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("pow", [other, self]) + + def __truediv__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("div", [self, other]) + + def __rtruediv__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("div", [other, self]) + + def __repr__(self) -> str: + return f"{self.call_name}({self.abstract_args})" + + def __neg__(self) -> ConcretizedCallable: + return -1 * self + @property def device(self) -> str: return self._device @@ -150,6 +238,30 @@ def init_param( return engine.random.uniform(0, 1) +def sin(x: str | ConcretizedCallable): + return ConcretizedCallable("sin", [x]) + + +def cos(x: str | ConcretizedCallable): + return ConcretizedCallable("cos", [x]) + + +def log(x: str | ConcretizedCallable): + return ConcretizedCallable("log", [x]) + + +def tan(x: str | ConcretizedCallable): + return ConcretizedCallable("tan", [x]) + + +def tanh(x: str | ConcretizedCallable): + return ConcretizedCallable("tanh", [x]) + + +def sqrt(x: str | ConcretizedCallable): + return ConcretizedCallable("sqrt", [x]) + + class Embedding: """A class relating variational and feature parameters used in ConcretizedCallable instances to parameter names used in gates. diff --git a/pyqtorch/hamiltonians/evolution.py b/pyqtorch/hamiltonians/evolution.py index 9105411e..f91fc43b 100644 --- a/pyqtorch/hamiltonians/evolution.py +++ b/pyqtorch/hamiltonians/evolution.py @@ -12,7 +12,8 @@ from pyqtorch.apply import apply_operator from pyqtorch.circuit import Sequence -from pyqtorch.embed import Embedding +from pyqtorch.composite import Scale +from pyqtorch.embed import ConcretizedCallable, Embedding from pyqtorch.primitives import Primitive from pyqtorch.quantum_operation import QuantumOperation from pyqtorch.time_dependent.sesolve import sesolve @@ -142,9 +143,10 @@ class HamiltonianEvolution(Sequence): def __init__( self, generator: TGenerator, - time: Tensor | str, + time: Tensor | str | ConcretizedCallable, qubit_support: Tuple[int, ...] | None = None, cache_length: int = 1, + duration: float | Tensor = 1.0, steps: int = 100, solver=SolverType.DP5_SE, ): @@ -160,6 +162,17 @@ def __init__( self.solver_type = solver self.steps = steps + self.duration = duration + self.is_time_dependent = None + + if isinstance(time, (str, Tensor, ConcretizedCallable)): + self.time = time + else: + raise ValueError( + "time should be passed as str, Tensor or ConcretizedCallable." + ) + + self.has_time_param = self._has_time_param(generator) if isinstance(generator, Tensor): if qubit_support is None: @@ -185,6 +198,7 @@ def __init__( "Taking support from generator and ignoring qubit_support input." ) qubit_support = generator.qubit_support + if is_parametric(generator): generator = [generator] self.generator_type = GeneratorType.PARAMETRIC_OPERATION @@ -205,11 +219,6 @@ def __init__( super().__init__(generator) self._qubit_support = qubit_support # type: ignore - if isinstance(time, str) or isinstance(time, Tensor): - self.time = time - else: - raise ValueError("time should be passed as str or tensor.") - logger.debug("Hamiltonian Evolution initialized") if logger.isEnabledFor(logging.DEBUG): # When Debugging let's add logging and NVTX markers @@ -246,6 +255,24 @@ def flatten(self) -> ModuleList: def param_name(self) -> Tensor | str: return self.time + def _has_time_param(self, generator: TGenerator) -> bool: + from pyqtorch.primitives import Parametric + + res = False + if isinstance(self.time, Tensor): + return res + else: + if isinstance(generator, (Sequence, QuantumOperation)): + for m in generator.modules(): + if isinstance(m, (Scale, Parametric)): + if self.time in getattr(m.param_name, "independent_args", []): + # param_name is a ConcretizedCallable object + res = True + elif m.param_name == self.time: + # param_name is a string + res = True + return res + def _symbolic_generator( self, values: dict, @@ -341,20 +368,27 @@ def _forward_time( ) -> State: n_qubits = len(state.shape) - 1 batch_size = state.shape[-1] - t_grid = torch.linspace(0, float(self.time), self.steps) + t_grid = torch.linspace(0, float(self.duration), self.steps) - values.update({embedding.tparam_name: torch.tensor(0.0)}) # type: ignore [dict-item] - embedded_params = embedding(values) + if embedding is not None: + values.update({embedding.tparam_name: torch.tensor(0.0)}) # type: ignore [dict-item] + embedded_params = embedding(values) + else: + embedded_params = values def Ht(t: torch.Tensor) -> torch.Tensor: """Accepts a value 't' for time and returns a (2**n_qubits, 2**n_qubits) Hamiltonian evaluated at time 't'. """ - # We use the origial embedded params and return a new dict + # We use the original embedded params and return a new dict # where we reembedded all parameters depending on time with value 't' - reembedded_time_values = embedding.reembed_tparam( - embedded_params, torch.as_tensor(t) - ) + if embedding is not None: + reembedded_time_values = embedding.reembed_tparam( + embedded_params, torch.as_tensor(t) + ) + else: + values[self.time] = torch.as_tensor(t) + reembedded_time_values = values return ( self.generator[0].tensor(reembedded_time_values, embedding).squeeze(2) ) @@ -388,8 +422,10 @@ def forward( Returns: The transformed state. """ - if embedding is not None and getattr(embedding, "tparam_name", None): - return self._forward_time(state, values, embedding) + if self.has_time_param or ( + embedding is not None and getattr(embedding, "tparam_name", None) + ): + return self._forward_time(state, values, embedding) # type: ignore [arg-type] else: return self._forward(state, values, embedding) @@ -420,10 +456,14 @@ def tensor( evolved_op = self._cache_hamiltonian_evo[values_cache_key] else: hamiltonian: torch.Tensor = self.create_hamiltonian(values, embedding) # type: ignore [call-arg] - time_evolution: torch.Tensor = ( - values[self.time] if isinstance(self.time, str) else self.time - ) # If `self.time` is a string / hence, a Parameter, - # we expect the user to pass it in the `values` dict + + if isinstance(self.time, str): + time_evolution = values[self.time] + elif isinstance(self.time, ConcretizedCallable): + time_evolution = self.time(values) + else: + time_evolution = self.time + evolved_op = evolve(hamiltonian, time_evolution) nb_cached = len(self._cache_hamiltonian_evo) diff --git a/pyqtorch/primitives/parametric.py b/pyqtorch/primitives/parametric.py index d72569f7..2e5c8ef2 100644 --- a/pyqtorch/primitives/parametric.py +++ b/pyqtorch/primitives/parametric.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from pyqtorch.embed import Embedding +from pyqtorch.embed import ConcretizedCallable, Embedding from pyqtorch.matrices import ( OPERATIONS_DICT, _jacobian, @@ -34,7 +34,7 @@ def __init__( self, generator: str | Tensor, qubit_support: int | tuple[int, ...] | Support, - param_name: str | int | float | torch.Tensor = "", + param_name: str | int | float | torch.Tensor | ConcretizedCallable = "", noise: NoiseProtocol | None = None, ): """Initializes Parametric. @@ -49,6 +49,11 @@ def __init__( generator_operation = ( OPERATIONS_DICT[generator] if isinstance(generator, str) else generator ) + if not isinstance(param_name, (str, int, float, Tensor, ConcretizedCallable)): + raise TypeError( + "Only str, int, float, Tensor or ConcretizedCallable types \ + are supported for param_name" + ) self.param_name = param_name def parse_values( @@ -105,6 +110,20 @@ def parse_constant( torch.tensor(self.param_name, device=self.device, dtype=self.dtype) ) + def parse_concretized_callable( + values: dict[str, Tensor] | Tensor = dict(), + embedding: Embedding | None = None, + ) -> Tensor: + """Evaluate ConcretizedCallable object with given values. + + Arguments: + values: A dict containing param_name:torch.Tensor pairs + Returns: + A Torch Tensor with which to evaluate the Parametric Gate. + """ + # self.param_name will be a ConcretizedCallable + return Parametric._expand_values(self.param_name(values)) # type: ignore [operator] + if param_name == "": self.parse_values = parse_tensor self.param_name = self.param_name @@ -112,6 +131,8 @@ def parse_constant( self.parse_values = parse_values elif isinstance(param_name, (float, int, torch.Tensor)): self.parse_values = parse_constant + elif isinstance(param_name, ConcretizedCallable): + self.parse_values = parse_concretized_callable # Parametric is defined by generator operation and a function # The function will use parsed parameter values to compute the unitary diff --git a/pyqtorch/utils.py b/pyqtorch/utils.py index 4d3e26be..7174b9ed 100644 --- a/pyqtorch/utils.py +++ b/pyqtorch/utils.py @@ -218,9 +218,7 @@ def is_diag(H: Tensor, atol: Tensor = ATOL) -> bool: Returns: True if diagonal, else False. """ - m = H.shape[0] - p, q = H.stride() - offdiag_view = torch.as_strided(H[:, 1:], (m - 1, m), (p + q, q)) + offdiag_view = H - torch.diag(torch.diag(H)) return torch.count_nonzero(torch.abs(offdiag_view).gt(atol)) == 0 @@ -746,14 +744,45 @@ class SolverType(StrEnum): def is_parametric(operation: pyq.Sequence) -> bool: + """Check if operation is parametric. + + Args: + operation (pyq.Sequence): checked operation + + Returns: + bool: True if operation is parametric, False otherwise + """ + from pyqtorch.primitives import Parametric - params = [] + res = False for m in operation.modules(): if isinstance(m, (pyq.Scale, Parametric)): - params.append(m.param_name) - - res = False - if any(isinstance(p, str) for p in params): - res = True + if isinstance(m.param_name, (str, pyq.ConcretizedCallable)): + res = True + break return res + + +def heaviside(x: Tensor, _: Any = None, slope: float = 1000.0) -> Tensor: + """Torch autograd-compatible Heaviside function implementation. + + Args: + x (Tensor): function argument + _ (Any): unused argument left for signature compatibility reasons + slope (float, optional): slope of Heaviside function (theoretically should be $infty$). + Defaults to 1000.0. + + Returns: + Tensor: function value + """ + + if x.ndim > 1: + raise ValueError("Argument tensor must be 0-d or 1-d.") + + shape = (1, 2) if x.ndim == 0 else (len(x), 2) + a = torch.zeros(shape) + a[:, 0] = x + return torch.clamp( + slope * torch.max(a, dim=1)[0], torch.tensor(0.0), torch.tensor(1.0) + ) diff --git a/tests/helpers.py b/tests/helpers.py index 41d4b731..d3e1a961 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,6 +4,7 @@ import torch +import pyqtorch.embed as pyq_em from pyqtorch.apply import apply_operator, apply_operator_permute from pyqtorch.composite import Add, Scale, Sequence from pyqtorch.primitives import ( @@ -68,6 +69,26 @@ def get_op_support( return supp +def get_random_embed() -> tuple: + fn_list = [ + (pyq_em.sin, torch.sin), + (pyq_em.cos, torch.cos), + (pyq_em.log, torch.log), + (pyq_em.tanh, torch.tanh), + (pyq_em.tan, torch.tan), + (pyq_em.sqrt, torch.sqrt), + ] + + fn1, fn2 = random.choice(fn_list), random.choice(fn_list) + + expr = (1.0 + 2 ** fn1[0]("x")) * fn2[0]("x") + call = lambda x: (12.0 + 2 ** fn1[1](x)) * fn2[1]("x") + + embedding = pyq_em.Embedding(fparam_names=["x"], var_to_call={"expr": expr}) + + return embedding, call + + def random_pauli_hamiltonian( n_qubits: int, k_1q: int = 5, diff --git a/tests/test_analog.py b/tests/test_analog.py index 61a6bf9b..920d7fad 100644 --- a/tests/test_analog.py +++ b/tests/test_analog.py @@ -8,6 +8,7 @@ from helpers import calc_mat_vec_wavefunction, random_pauli_hamiltonian import pyqtorch as pyq +from pyqtorch import RX, Add, ConcretizedCallable, HamiltonianEvolution, Scale, X from pyqtorch.composite import Sequence from pyqtorch.hamiltonians import GeneratorType from pyqtorch.matrices import ( @@ -319,7 +320,8 @@ def test_timedependent( ) hamiltonian_evolution = pyq.HamiltonianEvolution( generator=hamevo_generator, - time=torch.as_tensor(duration), + time=tparam, + duration=duration, steps=n_steps, solver=ode_solver, ) @@ -388,3 +390,34 @@ def apply_hamevo_and_compare_expected(psi, values): apply_hamevo_and_compare_expected(psi, values) assert len(hamevo._cache_hamiltonian_evo) == 2 assert values_cache_key in previous_cache_keys + + +@pytest.mark.parametrize( + "generator, time_param, result", + [ + (RX(0, "x"), "x", True), + (RX(1, 0.5), "y", False), + (RX(0, "x"), "y", False), + (RX(0, "x"), torch.tensor(0.5), False), + (RX(0, torch.tensor(0.5)), torch.tensor(0.5), False), + (Scale(X(1), "y"), "y", True), + (Scale(X(1), 0.2), "x", False), + ( + Add( + [Scale(X(1), ConcretizedCallable("mul", ["y", "x"])), Scale(X(1), "z")] + ), + "x", + True, + ), + ( + Add( + [Scale(X(1), ConcretizedCallable("add", ["y", "x"])), Scale(X(1), "z")] + ), + "t", + False, + ), + ], +) +def test_hamevo_is_time_dependent_generator(generator, time_param, result) -> None: + hamevo = HamiltonianEvolution(generator, time_param) + assert hamevo.has_time_param == result diff --git a/tests/test_digital.py b/tests/test_digital.py index 0605172b..f44d1424 100644 --- a/tests/test_digital.py +++ b/tests/test_digital.py @@ -9,6 +9,7 @@ from torch import Tensor import pyqtorch as pyq +from pyqtorch import ConcretizedCallable from pyqtorch.apply import apply_operator from pyqtorch.matrices import ( DEFAULT_MATRIX_DTYPE, @@ -340,3 +341,19 @@ def test_parametric_constantparam(gate: Parametric) -> None: gate(target, "theta")(state, {"theta": param_val}), gate(target, param_val)(state), ) + + +@pytest.mark.parametrize("gate", [pyq.RX, pyq.RY, pyq.RZ]) +def test_parametric_callableparam(gate: Parametric) -> None: + n_qubits = 4 + max_batch_size = 10 + target = torch.randint(0, n_qubits, (1,)).item() + size = torch.randint(1, max_batch_size, (1,)).item() + param_val_x = torch.rand(size) + param_val_y = torch.rand(size) + state = pyq.random_state(n_qubits) + param = ConcretizedCallable("add", ["x", "y"]) + assert torch.allclose( + gate(target, param)(state, {"x": param_val_x, "y": param_val_y}), + gate(target, param_val_x + param_val_y)(state), + ) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 6d37e1cb..9dcd5a6d 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -10,7 +10,7 @@ from torch.nn import Module import pyqtorch as pyq -from pyqtorch.embed import ConcretizedCallable, Embedding +from pyqtorch.embed import ConcretizedCallable, Embedding, cos, log, sin, sqrt from pyqtorch.primitives import Primitive from pyqtorch.utils import ATOL_embedding @@ -226,3 +226,8 @@ def run( ) wf = custom(state=pyq.zero_state(2), values={"t": torch.rand(1)}, embedding=embed) assert not torch.any(torch.isnan(wf)) + + +def test_get_independent_args() -> None: + expr: ConcretizedCallable = sqrt(sin("x")) + cos("r") * (1.0 / log("z") * "y") + assert set(expr.independent_args) == {"x", "y", "z", "r"} diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..0b1c0271 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import pytest +import torch + +from pyqtorch import RX, RY, ConcretizedCallable, Scale, Sequence, X +from pyqtorch.utils import heaviside, is_parametric + + +@pytest.mark.parametrize( + "operation, result", + [ + (RX(0, "x"), True), + (RY(1, 0.5), False), + (Scale(X(1), "y"), True), + (Scale(X(1), 0.2), False), + (Scale(X(1), ConcretizedCallable("mul", ["y", "x"])), True), + ], +) +def test_is_parametric(operation: Sequence, result: bool) -> None: + assert is_parametric(operation) == result + + +def test_heaviside() -> None: + x = torch.linspace(-1, 1, 50) + assert torch.allclose(heaviside(x), torch.heaviside(x, torch.tensor(0.0)))