Skip to content

Commit

Permalink
algebraic rules
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmoutinho committed Aug 30, 2024
1 parent dd9df14 commit 986a9b1
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 26 deletions.
5 changes: 0 additions & 5 deletions pyqtorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,10 @@
from .embed import (
ConcretizedCallable,
Embedding,
add,
cos,
div,
log,
mul,
sin,
sqrt,
square,
sub,
tan,
tanh,
)
Expand Down
75 changes: 54 additions & 21 deletions pyqtorch/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ConcretizedCallable:
def __init__(
self,
call_name: str,
abstract_args: list[str | float | int | ConcretizedCallable],
abstract_args: list[str | float | int | complex | ConcretizedCallable],
instruction_mapping: dict[str, Tuple[str, str]] = dict(),
engine_name: str = "torch",
device: str = "cpu",
Expand Down Expand Up @@ -126,6 +126,59 @@ def evaluate(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
def __call__(self, inputs: dict[str, ArrayLike] = dict()) -> ArrayLike:
return self.evaluate(inputs)

def __mul__(
self, other: str | int | float | complex | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("mul", [self, other])

def __rmul__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("mul", [other, self])

def __add__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("add", [self, other])

def __radd__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("add", [other, self])

def __sub__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("sub", [self, other])

def __rsub__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("sub", [other, self])

def __pow__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("pow", [self, other])

def __rpow__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("pow", [other, self])

def __truediv__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("div", [self, other])

def __rtruediv__(
self, other: str | int | float | ConcretizedCallable
) -> ConcretizedCallable:
return ConcretizedCallable("div", [other, self])

def __neg__(self) -> ConcretizedCallable:
return -1 * self

@property
def device(self) -> str:
return self._device
Expand Down Expand Up @@ -176,26 +229,6 @@ def sqrt(x: str | ConcretizedCallable):
return ConcretizedCallable("sqrt", [x])


def square(x: str | ConcretizedCallable):
return ConcretizedCallable("square", [x])


def mul(x: str | ConcretizedCallable, y: str | ConcretizedCallable):
return ConcretizedCallable("mul", [x, y])


def add(x: str | ConcretizedCallable, y: str | ConcretizedCallable):
return ConcretizedCallable("add", [x, y])


def div(x: str | ConcretizedCallable, y: str | ConcretizedCallable):
return ConcretizedCallable("div", [x, y])


def sub(x: str | ConcretizedCallable, y: str | ConcretizedCallable):
return ConcretizedCallable("sub", [x, y])


class Embedding:
"""A class relating variational and feature parameters used in ConcretizedCallable instances to
parameter names used in gates.
Expand Down

0 comments on commit 986a9b1

Please sign in to comment.