Skip to content

Commit

Permalink
Fix bug in tag_solve_triangular rewrite (#383)
Browse files Browse the repository at this point in the history
* Fix bug in tag_solve_triangular rewrite

* Rename tag_solve_triangular to generic_solve_to_solve_triangular
  • Loading branch information
jessegrabowski authored Jul 15, 2023
1 parent 7a82a3f commit 9be43d0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 29 deletions.
39 changes: 19 additions & 20 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,31 +50,30 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_canonicalize
@node_rewriter([Solve])
def tag_solve_triangular(fgraph, node):
def generic_solve_to_solve_triangular(fgraph, node):
"""
If a general solve() is applied to the output of a cholesky op, then
If any solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if isinstance(node.op, Solve):
if node.op.assume_a == "gen":
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [Solve(assume_a="sym", lower=True)(A, b)]
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, Cholesky):
if A.owner.op.lower:
return [SolveTriangular(lower=True)(A, b)]
else:
return [SolveTriangular(lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [SolveTriangular(lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=False)(A, b)]
if (
A.owner
and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)
):
(A_T,) = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
if A_T.owner.op.lower:
return [Solve(assume_a="sym", lower=False)(A, b)]
else:
return [Solve(assume_a="sym", lower=True)(A, b)]
return [SolveTriangular(lower=True)(A, b)]


@register_canonicalize
Expand Down
40 changes: 31 additions & 9 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy.linalg
import pytest
import scipy.linalg
from numpy.testing import assert_allclose

import pytensor
from pytensor import function
Expand All @@ -12,7 +13,7 @@
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
Expand Down Expand Up @@ -81,25 +82,46 @@ def test_transinv_to_invtrans():
assert node.inputs[0].name == "X"


def test_tag_solve_triangular():
def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A")
x = vector("x")
x = matrix("x")

L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
f = pytensor.function([A, x], b1)

X = np.random.normal(size=(10, 10)).astype(config.floatX)
X = X @ X.T
X_chol = np.linalg.cholesky(X)
eye = np.eye(10, dtype=config.floatX)

if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]

assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)

assert_allclose(
f(X, eye) @ X_chol, eye, atol=1e-8 if config.floatX.endswith("64") else 1e-4
)

f = pytensor.function([A, x], b2)

if config.mode != "FAST_COMPILE":
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.assume_a != "gen" and not node.op.lower
toposort = f.maker.fgraph.toposort()
op_list = [node.op for node in toposort]
assert not any(isinstance(op, Solve) for op in op_list)
assert any(isinstance(op, SolveTriangular) for op in op_list)
assert_allclose(
f(X, eye).T @ X_chol,
eye,
atol=1e-8 if config.floatX.endswith("64") else 1e-4,
)


def test_matrix_inverse_solve():
Expand Down

0 comments on commit 9be43d0

Please sign in to comment.