Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend cholesky of triangular dot rewrite to matmul Ops #459

Merged
merged 2 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
fetch-depth: 0

- name: Build wheels
uses: pypa/cibuildwheel@v2.16.0
uses: pypa/cibuildwheel@v2.14.1

- uses: actions/upload-artifact@v3
with:
Expand Down
16 changes: 14 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
Expand Down Expand Up @@ -168,13 +168,25 @@ def cholesky_ldotlt(fgraph, node):
Rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
or cholesky(dot(U.T, U), lower=False) = U, where U is upper triangular.

Also works with matmul.

This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
"""
if not isinstance(node.op.core_op, Cholesky):
return

A = node.inputs[0]
if not (A.owner and isinstance(A.owner.op, (Dot, Dot22))):
if not (
A.owner is not None
and (
(
isinstance(A.owner.op, (Dot, Dot22))
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
)
):
return

l, r = A.owner.inputs
Expand Down
41 changes: 28 additions & 13 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np
import numpy.linalg
import pytest
Expand All @@ -9,13 +11,14 @@
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op

Expand Down Expand Up @@ -137,33 +140,38 @@ def test_matrix_inverse_solve():
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
@pytest.mark.parametrize("op", (dot, matmul))
def test_cholesky_ldotlt(tag, cholesky_form, product, op):
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag

A = matrix("L")
ndim = 2 if op == dot else 3
A = tensor("L", shape=(None,) * ndim)
if tag:
setattr(A.tag, tag + "_triangular", True)

if product == "lower":
M = A.dot(A.T)
M = op(A, swapaxes(A, -1, -2))
elif product == "upper":
M = A.T.dot(A)
M = op(swapaxes(A, -1, -2), A)
else:
M = A

C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))

no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
isinstance(node.op, Cholesky)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky))
for node in f.maker.fgraph.apply_nodes
)

assert no_cholesky_in_graph == transform_removes_chol

if transform_transposes:
expected_order = (1, 0) if ndim == 2 else (0, 2, 1)
assert any(
isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)
isinstance(node.op, DimShuffle) and node.op.new_order == expected_order
for node in f.maker.fgraph.apply_nodes
)

Expand All @@ -183,6 +191,11 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
]
)

cholesky_vect_fn = np.vectorize(
partial(scipy.linalg.cholesky, lower=(cholesky_form == "lower")),
signature="(a, a)->(a, a)",
)

for Av in Avs:
if tag == "upper":
Av = Av.T
Expand All @@ -194,11 +207,13 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else:
Mv = Av

assert np.all(
np.isclose(
scipy.linalg.cholesky(Mv, lower=(cholesky_form == "lower")),
f(Av),
)
if ndim == 3:
Av = np.broadcast_to(Av, (5, *Av.shape))
Mv = np.broadcast_to(Mv, (5, *Mv.shape))

np.testing.assert_allclose(
cholesky_vect_fn(Mv),
f(Av),
)


Expand Down
Loading