diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py
index aa063e46dc..9f2852596b 100644
--- a/pytensor/link/numba/dispatch/tensor_basic.py
+++ b/pytensor/link/numba/dispatch/tensor_basic.py
@@ -1,6 +1,7 @@
 from textwrap import indent
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple  # type: ignore
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
@@ -150,14 +151,61 @@ def split(tensor, axis, indices):
 
 
 @numba_funcify.register(ExtractDiag)
-def numba_funcify_ExtractDiag(op, **kwargs):
+def numba_funcify_ExtractDiag(op, node, **kwargs):
+    """Build the Numba implementation of `ExtractDiag` for `node`.
+
+    The diagonal is taken over `op.axis1` and `op.axis2` with offset
+    `op.offset` and is stored along the last axis of the output.
+    """
     offset = op.offset
-    # axis1 = op.axis1
-    # axis2 = op.axis2
-
-    @numba_basic.numba_njit(inline="always")
-    def extract_diag(x):
-        return np.diag(x, k=offset)
+    ndim = node.inputs[0].type.ndim
+    axis1, axis2 = normalize_axis_tuple((op.axis1, op.axis2), ndim)
+
+    # Swapping the two axes mirrors the diagonal, i.e. it negates the offset.
+    if axis1 > axis2:
+        axis1, axis2 = axis2, axis1
+        offset = -offset
+
+    # Freeze a plain bool instead of the `view_map` dict into the jitted code.
+    view = bool(op.view_map)
+
+    if axis1 == 0 and axis2 == 1 and ndim == 2:
+
+        @numba_basic.numba_njit(inline="always")
+        def extract_diag(x):
+            out = np.diag(x, k=offset)
+
+            if not view:
+                out = out.copy()
+
+            return out
+
+    else:
+        # `axis1 < axis2` holds here, so they already are the first/last axis.
+        axis1p1 = axis1 + 1
+        axis2p1 = axis2 + 1
+        leading_dims = (slice(None),) * axis1
+        middle_dims = (slice(None),) * (axis2 - axis1 - 1)
+
+        @numba_basic.numba_njit(inline="always")
+        def extract_diag(x):
+            if offset >= 0:
+                diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
+            else:
+                diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
+            base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
+            out_shape = base_shape + (diag_len,)
+            # Keep the input dtype; `np.empty` would otherwise default to float64.
+            out = np.empty(out_shape, dtype=x.dtype)
+
+            # An empty diagonal (diag_len == 0) simply skips the loop.
+            for i in range(diag_len):
+                if offset >= 0:
+                    new_entry = x[leading_dims + (i,) + middle_dims + (i + offset,)]
+                else:
+                    new_entry = x[leading_dims + (i - offset,) + middle_dims + (i,)]
+                out[..., i] = new_entry
+            return out
 
     return extract_diag
 
diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py
index 047bc18a98..5791a6f230 100644
--- a/tests/link/numba/test_tensor_basic.py
+++ b/tests/link/numba/test_tensor_basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
+from pytensor.link.numba.dispatch import numba_funcify
 from pytensor.scalar import Add
 from pytensor.tensor.shape import Unbroadcast
 from tests.link.numba.test_basic import (
@@ -370,6 +371,10 @@ def test_Split_view():
             set_test_value(at.vector(), np.arange(10, dtype=config.floatX)),
             0,
         ),
+        (
+            set_test_value(at.tensor3(), np.arange(2 * 3 * 3).reshape(2, 3, 3)),
+            1,
+        ),
     ],
 )
 def test_ExtractDiag(val, offset):
@@ -386,6 +391,29 @@ def test_ExtractDiag(val, offset):
     )
 
 
+@pytest.mark.parametrize("k", range(-5, 4))
+@pytest.mark.parametrize(
+    "axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
+)
+@pytest.mark.parametrize("reverse_axis", (False, True))
+@pytest.mark.slow
+def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
+    """Compare the Numba `ExtractDiag` with `np.diagonal` on a 4-d input."""
+    if reverse_axis:
+        axis1, axis2 = axis2, axis1
+
+    x = at.tensor4("x")
+    x_shape = (2, 3, 4, 5)
+    x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
+    out = at.diagonal(x, k, axis1, axis2)
+    numba_fn = numba_funcify(out.owner.op, out.owner)
+    expected = np.diagonal(x_test, k, axis1, axis2)
+    res = numba_fn(x_test)
+    np.testing.assert_allclose(res, expected)
+    # `assert_allclose` ignores dtypes, so check the output dtype explicitly.
+    assert res.dtype == expected.dtype
+
+
 @pytest.mark.parametrize(
     "n, m, k, dtype",
     [