From 7649d59cf3f42d5a64499bdf2e43a6c9b499a0ef Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 01:01:03 +0200 Subject: [PATCH] Test all parameterizations of solve_triangular --- pytensor/link/numba/dispatch/slinalg.py | 10 ++--- tests/link/numba/test_slinalg.py | 55 +++++++++++++++++++++---- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index afe8730058..c345834c75 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -4,7 +4,7 @@ import numpy as np from numba.core import cgutils, types from numba.extending import get_cython_function_address, intrinsic, overload -from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind from scipy import linalg from pytensor.link.numba.dispatch import basic as numba_basic @@ -114,7 +114,7 @@ def _get_underlying_float(dtype): def _get_addr_and_float_pointer(dtype, name): - d = _blas_kinds[dtype] + d = get_blas_kind(dtype) func_name = f"{d}{name}" float_pointer = _get_float_pointer_for_dtype(d) addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) @@ -154,10 +154,6 @@ class _LAPACK: def __init__(self): ensure_lapack() - @classmethod - def test_blas_kinds(cls, dtype): - return _blas_kinds[dtype] - @classmethod def numba_xtrtrs(cls, dtype): """ @@ -233,7 +229,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): else: transval = ord("C") - B_NDIM = 1 if B_is_1d else B.shape[1] + B_NDIM = 1 if B_is_1d else int(B.shape[1]) UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) TRANS = val_to_int_ptr(transval) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 7a22cd3029..4ab314d1f0 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -14,23 +14,60 @@ rng = np.random.default_rng(42849) +def transpose_func(x, trans): + if trans == 0: + return x + if trans == 1: + return x.conj().T + if trans == 2: + return x.T + + @pytest.mark.parametrize( "b_func, b_size", [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], ids=["b_col_vec", "b_matrix", "b_vec"], ) -def test_solve_triangular(b_func, b_size): - A = pt.matrix("A") - b = b_func("b") +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"]) +@pytest.mark.parametrize( + "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] +) +# @pytest.mark.parametrize('complex', [True, False], ids=['complex', 'real']) +def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex=False): + # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, why? + complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" + dtype = complex_dtype if complex else config.floatX - X = pt.linalg.solve_triangular(A, b, lower=True) + A = pt.matrix("A", dtype=dtype) + b = b_func("b", dtype=dtype) + + X = pt.linalg.solve_triangular( + A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + ) f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)).astype(config.floatX) - A_sym = A_val @ A_val.T - A_tri = np.linalg.cholesky(A_sym) + A_val = np.random.normal(size=(5, 5)) + b = np.random.normal(size=b_size) + + if complex: + A_val = A_val + np.random.normal(size=(5, 5)) * 1j + b = b + np.random.normal(size=b_size) * 1j + A_sym = A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(dtype) + if unit_diag: + adj_mat = np.ones((5, 5)) + adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) + A_tri = A_tri * adj_mat + + A_tri = A_tri.astype(dtype) + b = b.astype(dtype) - b = np.random.normal(size=b_size).astype(config.floatX) + if not lower: + A_tri = A_tri.T X_np = f(A_tri, b) - np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL) + np.testing.assert_allclose( + transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + )