Skip to content

Commit

Permalink
Extend cholesky of triangular dot rewrite to matmul Ops
Browse files Browse the repository at this point in the history
Also restrict to 2D Dot cases
  • Loading branch information
ricardoV94 committed Sep 27, 2023
1 parent 1a94585 commit c06fdb0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
16 changes: 14 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
Expand Down Expand Up @@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node):
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
Also works with matmul.
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
"""
if not isinstance(node.op.core_op, Cholesky):
return

A = node.inputs[0]
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
if not (
A.owner is not None
and (
(
isinstance(A.owner.op, (Dot, Dot22))
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
)
):
return

l, r = A.owner.inputs
Expand Down
41 changes: 28 additions & 13 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np
import numpy.linalg
import pytest
Expand All @@ -9,13 +11,14 @@
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op

Expand Down Expand Up @@ -137,33 +140,38 @@ def test_matrix_inverse_solve():
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
@pytest.mark.parametrize("op", (dot, matmul))
def test_cholesky_ldotlt(tag, cholesky_form, product, op):
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag

A = matrix("L")
ndim = 2 if op == dot else 3
A = tensor("L", shape=(None,) * ndim)
if tag:
setattr(A.tag, tag + "_triangular", True)

if product == "lower":
M = A.dot(A.T)
M = op(A, swapaxes(A, -1, -2))
elif product == "upper":
M = A.T.dot(A)
M = op(swapaxes(A, -1, -2), A)
else:
M = A

C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))

no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
isinstance(node.op, Cholesky)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky))
for node in f.maker.fgraph.apply_nodes
)

assert no_cholesky_in_graph == transform_removes_chol

if transform_transposes:
expected_order = (1, 0) if ndim == 2 else (0, 2, 1)
assert any(
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
isinstance(node.op, DimShuffle) and node.op.new_order == expected_order
for node in f.maker.fgraph.apply_nodes
)

Expand All @@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
]
)

cholesky_vect_fn = np.vectorize(
partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")),
signature="(a, a)->(a, a)",
)

for Av in Avs:
if tag == "upper":
Av = Av.T
Expand All @@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else:
Mv = Av

assert np.all(
np.isclose(
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
f(Av),
)
if ndim == 3:
Av = np.broadcast_to(Av, (5, *Av.shape))
Mv = np.broadcast_to(Mv, (5, *Mv.shape))

np.testing.assert_allclose(
cholesky_vect_fn(Mv),
f(Av),
)


Expand Down

0 comments on commit c06fdb0

Please sign in to comment.