diff --git a/pyqtorch/__init__.py b/pyqtorch/__init__.py index fd37fb3f..8945aff1 100644 --- a/pyqtorch/__init__.py +++ b/pyqtorch/__init__.py @@ -58,15 +58,10 @@ from .embed import ( ConcretizedCallable, Embedding, - add, cos, - div, log, - mul, sin, sqrt, - square, - sub, tan, tanh, ) diff --git a/pyqtorch/embed.py b/pyqtorch/embed.py index 7473b04c..e47c2ac1 100644 --- a/pyqtorch/embed.py +++ b/pyqtorch/embed.py @@ -75,7 +75,7 @@ class ConcretizedCallable: def __init__( self, call_name: str, - abstract_args: list[str | float | int | ConcretizedCallable], + abstract_args: list[str | float | int | complex | ConcretizedCallable], instruction_mapping: dict[str, Tuple[str, str]] = dict(), engine_name: str = "torch", device: str = "cpu", @@ -126,6 +126,59 @@ def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike: return self.evaluate(inputs) + def __mul__( + self, other: str | int | float | complex | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("mul", [self, other]) + + def __rmul__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("mul", [other, self]) + + def __add__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("add", [self, other]) + + def __radd__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("add", [other, self]) + + def __sub__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("sub", [self, other]) + + def __rsub__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("sub", [other, self]) + + def __pow__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("pow", [self, other]) + + def __rpow__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("pow", [other, self]) + + def __truediv__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("div", [self, other]) + + def __rtruediv__( + self, other: str | int | float | ConcretizedCallable + ) -> ConcretizedCallable: + return ConcretizedCallable("div", [other, self]) + + def __neg__(self) -> ConcretizedCallable: + return -1 * self + @property def device(self) -> str: return self._device @@ -176,26 +229,6 @@ def sqrt(x: str | ConcretizedCallable): return ConcretizedCallable("sqrt", [x]) -def square(x: str | ConcretizedCallable): - return ConcretizedCallable("square", [x]) - - -def mul(x: str | ConcretizedCallable, y: str | ConcretizedCallable): - return ConcretizedCallable("mul", [x, y]) - - -def add(x: str | ConcretizedCallable, y: str | ConcretizedCallable): - return ConcretizedCallable("add", [x, y]) - - -def div(x: str | ConcretizedCallable, y: str | ConcretizedCallable): - return ConcretizedCallable("div", [x, y]) - - -def sub(x: str | ConcretizedCallable, y: str | ConcretizedCallable): - return ConcretizedCallable("sub", [x, y]) - - class Embedding: """A class relating variational and feature parameters used in ConcretizedCallable instances to parameter names used in gates.