Skip to content

Commit

Permalink
Blockwise some linalg Ops by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 5, 2023
1 parent 0ff0f29 commit 8df22d7
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 129 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg


def swapaxes(y, axis1, axis2):
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor."
y = as_tensor_variable(y)
ndim = y.ndim
Expand Down
20 changes: 15 additions & 5 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from pytensor.tensor import basic as at
from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector


class MatrixPinv(Op):
__props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"

def __init__(self, hermitian):
self.hermitian = hermitian
Expand Down Expand Up @@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op.
"""
return MatrixPinv(hermitian=hermitian)(x)
return Blockwise(MatrixPinv(hermitian=hermitian))(x)


class MatrixInverse(Op):
Expand All @@ -93,6 +95,8 @@ class MatrixInverse(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)

def __init__(self):
pass
Expand Down Expand Up @@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
return shapes


inv = matrix_inverse = MatrixInverse()
inv = matrix_inverse = Blockwise(MatrixInverse())


def matrix_dot(*args):
Expand Down Expand Up @@ -181,6 +185,8 @@ class Det(Op):
"""

__props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)

def make_node(self, x):
x = as_tensor_variable(x)
Expand Down Expand Up @@ -209,7 +215,7 @@ def __str__(self):
return "Det"


det = Det()
det = Blockwise(Det())


class SLogDet(Op):
Expand All @@ -218,6 +224,8 @@ class SLogDet(Op):
"""

__props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

def make_node(self, x):
x = as_tensor_variable(x)
Expand All @@ -242,7 +250,7 @@ def __str__(self):
return "SLogDet"


slogdet = SLogDet()
slogdet = Blockwise(SLogDet())


class Eig(Op):
Expand All @@ -252,6 +260,8 @@ class Eig(Op):
"""

__props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)

def make_node(self, x):
x = as_tensor_variable(x)
Expand All @@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
return [(n,), (n, n)]


eig = Eig()
eig = Blockwise(Eig())


class Eigh(Eig):
Expand Down
Loading

0 comments on commit 8df22d7

Please sign in to comment.