Skip to content

Commit

Permalink
Test all parameterizations of solve_triangular
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Aug 28, 2023
1 parent 15e23b9 commit 7649d59
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
10 changes: 3 additions & 7 deletions pytensor/link/numba/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from numba.core import cgutils, types
from numba.extending import get_cython_function_address, intrinsic, overload
from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind
from scipy import linalg

from pytensor.link.numba.dispatch import basic as numba_basic
Expand Down Expand Up @@ -114,7 +114,7 @@ def _get_underlying_float(dtype):


def _get_addr_and_float_pointer(dtype, name):
d = _blas_kinds[dtype]
d = get_blas_kind(dtype)
func_name = f"{d}{name}"
float_pointer = _get_float_pointer_for_dtype(d)
addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
Expand Down Expand Up @@ -154,10 +154,6 @@ class _LAPACK:
def __init__(self):
ensure_lapack()

@classmethod
def test_blas_kinds(cls, dtype):
return _blas_kinds[dtype]

@classmethod
def numba_xtrtrs(cls, dtype):
"""
Expand Down Expand Up @@ -233,7 +229,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
else:
transval = ord("C")

B_NDIM = 1 if B_is_1d else B.shape[1]
B_NDIM = 1 if B_is_1d else int(B.shape[1])

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
TRANS = val_to_int_ptr(transval)
Expand Down
55 changes: 46 additions & 9 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,60 @@
rng = np.random.default_rng(42849)


def transpose_func(x, trans):
if trans == 0:
return x
if trans == 1:
return x.conj().T
if trans == 2:
return x.T


@pytest.mark.parametrize(
"b_func, b_size",
[(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_solve_triangular(b_func, b_size):
A = pt.matrix("A")
b = b_func("b")
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"])
@pytest.mark.parametrize(
"unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"]
)
# @pytest.mark.parametrize('complex', [True, False], ids=['complex', 'real'])
def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex=False):
# TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, why?
complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128"
dtype = complex_dtype if complex else config.floatX

X = pt.linalg.solve_triangular(A, b, lower=True)
A = pt.matrix("A", dtype=dtype)
b = b_func("b", dtype=dtype)

X = pt.linalg.solve_triangular(
A, b, lower=lower, trans=trans, unit_diagonal=unit_diag
)
f = pytensor.function([A, b], X, mode="NUMBA")

A_val = np.random.normal(size=(5, 5)).astype(config.floatX)
A_sym = A_val @ A_val.T
A_tri = np.linalg.cholesky(A_sym)
A_val = np.random.normal(size=(5, 5))
b = np.random.normal(size=b_size)

if complex:
A_val = A_val + np.random.normal(size=(5, 5)) * 1j
b = b + np.random.normal(size=b_size) * 1j
A_sym = A_val @ A_val.conj().T

A_tri = np.linalg.cholesky(A_sym).astype(dtype)
if unit_diag:
adj_mat = np.ones((5, 5))
adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri)
A_tri = A_tri * adj_mat

A_tri = A_tri.astype(dtype)
b = b.astype(dtype)

b = np.random.normal(size=b_size).astype(config.floatX)
if not lower:
A_tri = A_tri.T

X_np = f(A_tri, b)
np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL)
np.testing.assert_allclose(
transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL
)

0 comments on commit 7649d59

Please sign in to comment.