Skip to content

Commit

Permalink
Fully support ExtractDiag in numba
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 18, 2023
1 parent 2cd94c6 commit a1c4076
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 7 deletions.
61 changes: 54 additions & 7 deletions pytensor/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from textwrap import indent

import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
Expand Down Expand Up @@ -150,14 +151,60 @@ def split(tensor, axis, indices):


@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
def numba_funcify_ExtractDiag(op, node, **kwargs):
offset = op.offset
# axis1 = op.axis1
# axis2 = op.axis2

@numba_basic.numba_njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
axis1, axis2 = normalize_axis_tuple((op.axis1, op.axis2), node.inputs[0].type.ndim)

if axis1 > axis2:
axis1, axis2 = axis2, axis1
offset = -offset

view = op.view_map

if axis1 == 0 and axis2 == 1 and node.inputs[0].type.ndim == 2:

@numba_basic.numba_njit(inline="always")
def extract_diag(x):
out = np.diag(x, k=offset)

if not view:
out = out.copy()

return out

else:
first_axis = min(axis1, axis2)
last_axis = max(axis1, axis2)
first_axisp1 = first_axis + 1
last_axisp1 = last_axis + 1
leading_dims = (slice(None),) * first_axis
middle_dims = (slice(None),) * (last_axis - first_axis - 1)

@numba_basic.numba_njit(inline="always")
def extract_diag(x):
if offset >= 0:
diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - abs(offset)))
else:
diag_len = min(x.shape[axis2], max(0, x.shape[axis1] - abs(offset)))
base_shape = (
x.shape[:first_axis]
+ x.shape[first_axisp1:last_axis]
+ x.shape[last_axisp1:]
)
out_shape = base_shape + (diag_len,)
out = np.empty(out_shape)

# Empty case
if diag_len == 0:
return out

for i in range(diag_len):
if offset >= 0:
new_entry = x[leading_dims + (i,) + middle_dims + (i + offset,)]
else:
new_entry = x[leading_dims + (i - offset,) + middle_dims + (i,)]
out[..., i] = new_entry
return out

return extract_diag

Expand Down
23 changes: 23 additions & 0 deletions tests/link/numba/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import (
Expand Down Expand Up @@ -370,6 +371,10 @@ def test_Split_view():
set_test_value(at.vector(), np.arange(10, dtype=config.floatX)),
0,
),
(
set_test_value(at.tensor3(), np.arange(2 * 3 * 3).reshape(2, 3, 3)),
1,
),
],
)
def test_ExtractDiag(val, offset):
Expand All @@ -386,6 +391,24 @@ def test_ExtractDiag(val, offset):
)


@pytest.mark.parametrize("k", range(-5, 4))
@pytest.mark.parametrize(
"axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
)
@pytest.mark.parametrize("reverse_axis", (False, True))
@pytest.mark.slow
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
if reverse_axis:
axis1, axis2 = axis2, axis1

x = at.tensor4("x")
x_shape = (2, 3, 4, 5)
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = at.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
np.testing.assert_allclose(numba_fn(x_test), np.diagonal(x_test, k, axis1, axis2))


@pytest.mark.parametrize(
"n, m, k, dtype",
[
Expand Down

0 comments on commit a1c4076

Please sign in to comment.