Add Numba implementation of Blockwise
Restricted to 3 outputs, due to limitations in jitting of Numba functions
ricardoV94 committed Oct 6, 2024
1 parent 06d9a49 commit 2014cd9
Showing 8 changed files with 290 additions and 6 deletions.
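
The feature added here can be exercised end to end, mirroring the benchmark test introduced at the bottom of this diff. A minimal usage sketch, assuming PyTensor with Numba installed (the variable names are illustrative):

import numpy as np

from pytensor import function
from pytensor.tensor import tensor
from pytensor.tensor.slinalg import cholesky

# Batched (blockwise) Cholesky: one 3x3 decomposition per entry of the leading dimension.
x = tensor("x", shape=(5, 3, 3))
out = cholesky(x)

# With this commit, the Blockwise node is compiled through the Numba backend.
fn = function([x], out, mode="NUMBA")

x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
print(fn(x_test).shape)  # (5, 3, 3)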
9 changes: 5 additions & 4 deletions pytensor/link/numba/dispatch/__init__.py
@@ -2,15 +2,16 @@
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify

# Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.extra_ops
import pytensor.link.numba.dispatch.nlinalg
import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic

# isort: on
90 changes: 90 additions & 0 deletions pytensor/link/numba/dispatch/blockwise.py
@@ -0,0 +1,90 @@
from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple

from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


@numba_funcify.register
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
[blockwise_node] = op.fgraph.apply_nodes
blockwise_op: Blockwise = blockwise_node.op
core_op = blockwise_op.core_op
nin = len(blockwise_node.inputs)
nout = len(blockwise_node.outputs)
if nout > 3:
raise NotImplementedError(
"Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs."
)

core_shapes_len = [get_vector_length(sh) for sh in node.inputs[nin:]]
core_shape_0 = core_shapes_len[0] if nout > 0 else None
core_shape_1 = core_shapes_len[1] if nout > 1 else None
core_shape_2 = core_shapes_len[2] if nout > 2 else None

core_node = blockwise_op._create_dummy_core_node(blockwise_node.inputs)
core_op_fn = numba_funcify(
core_op,
node=core_node,
parent_node=node,
fastmath=_jit_options["fastmath"],
**kwargs,
)
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

batch_ndim = blockwise_op.batch_ndim(node)

# numba doesn't support nested literals right now...
input_bc_patterns = encode_literals(
tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs)
)
output_bc_patterns = encode_literals(
tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
)
output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
inplace_pattern = encode_literals(())

def blockwise_wrapper(*inputs_and_core_shapes):
inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
# Appease numba Gods :(
# Secular solution welcomed
if nout == 1:
tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),)
elif nout == 2:
tuple_core_shapes = (
to_fixed_tuple(core_shapes[0], core_shape_0),
to_fixed_tuple(core_shapes[1], core_shape_1),
)
else:
tuple_core_shapes = (
to_fixed_tuple(core_shapes[0], core_shape_0),
to_fixed_tuple(core_shapes[1], core_shape_1),
to_fixed_tuple(core_shapes[2], core_shape_2),
)
return _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(), # constant_inputs
inputs,
tuple_core_shapes,
None, # size
)

def blockwise(*inputs_and_core_shapes):
raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")

@overload(blockwise, jit_options=_jit_options)
def ov_blockwise(*inputs_and_core_shapes):
return blockwise_wrapper

return blockwise
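
For context on the three-output cap from the commit message: to_fixed_tuple requires the tuple length to be a compile-time constant, so the wrapper above unrolls one branch per supported number of outputs instead of looping over core_shapes. A minimal standalone sketch of that constraint (the helper name is hypothetical, not part of PyTensor):

import numpy as np
from numba import njit
from numba.np.unsafe.ndarray import to_fixed_tuple

@njit
def alloc_from_shape_vector(shape_vec):
    # The length argument (2) must be a compile-time constant so Numba can
    # type the resulting tuple; this is why the dispatch above cannot build
    # the tuple of core shapes in a loop of runtime-dependent length.
    shape = to_fixed_tuple(shape_vec, 2)
    return np.zeros(shape)

print(alloc_from_shape_vector(np.array([2, 3])).shape)  # (2, 3)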
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/random.py
@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
return rng, draws

def random(core_shape, rng, size, *dist_params):
-pass
+raise NotImplementedError("Non-jitted random variable not implemented")

@overload(random, jit_options=_jit_options)
def ov_random(core_shape, rng, size, *dist_params):
8 changes: 8 additions & 0 deletions pytensor/tensor/blockwise.py
@@ -402,3 +402,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:

class OpWithCoreShape(OpFromGraph):
"""Generalizes an `Op` to include core shape as an additional input."""


class BlockwiseWithCoreShape(OpWithCoreShape):
"""Generalizes a Blockwise `Op` to include a core shape parameter."""

def __str__(self):
[blockwise_node] = self.fgraph.apply_nodes
return f"[{blockwise_node.op!s}]"
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
@@ -9,6 +9,7 @@
import pytensor.tensor.rewriting.jax
import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
112 changes: 112 additions & 0 deletions pytensor/tensor/rewriting/numba.py
@@ -0,0 +1,112 @@
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.basic import applys_between
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature


@node_rewriter([Blockwise])
def introduce_explicit_core_shape_blockwise(fgraph, node):
"""Introduce the core shape of a Blockwise.
We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
This core_shape is used by the numba backend to pre-allocate the output array.
If available, the core shape is extracted from the shape feature of the graph,
which has a higher change of having been simplified, optimized, constant-folded.
If missing, we fall back to the op._supp_shape_from_params method.
This rewrite is required for the numba backend implementation of Blockwise.
Example
-------
.. code-block:: python
import pytensor
import pytensor.tensor as pt
x = pt.tensor("x", shape=(5, None, None))
outs = pt.linalg.svd(x, compute_uv=True)
pytensor.dprint(outs)
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
# └─ x [id B]
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
# └─ ···
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
# └─ ···
# After the rewrite, note the 3 new core shape inputs
fn = pytensor.function([x], outs, mode="NUMBA")
fn.dprint(print_type=False)
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
# ├─ x [id B]
# ├─ MakeVector{dtype='int64'} [id C] 5
# │ ├─ Shape_i{1} [id D] 2
# │ │ └─ x [id B]
# │ └─ Shape_i{1} [id D] 2
# │ └─ ···
# ├─ MakeVector{dtype='int64'} [id E] 4
# │ └─ Minimum [id F] 3
# │ ├─ Shape_i{1} [id D] 2
# │ │ └─ ···
# │ └─ Shape_i{2} [id G] 0
# │ └─ x [id B]
# └─ MakeVector{dtype='int64'} [id H] 1
# ├─ Shape_i{2} [id G] 0
# │ └─ ···
# └─ Shape_i{2} [id G] 0
# └─ ···
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
# └─ ···
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
# └─ ···
"""
if len(node.outputs) > 3:
# Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs.
return None

op: Blockwise = node.op
batch_ndim = op.batch_ndim(node)

shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked]
if shape_feature:
core_shapes = [
[shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
for out in node.outputs
]
else:
input_shapes = [tuple(inp.shape) for inp in node.inputs]
core_shapes = [
out_shape[batch_ndim:]
for out_shape in op.infer_shape(None, node, input_shapes)
]

core_shapes = [
as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
for core_shape in core_shapes
]

if any(
isinstance(node.op, Blockwise)
for node in applys_between(node.inputs, core_shapes)
):
# If Blockwise shows up in the shape graph we can't introduce the core shape
return None

return BlockwiseWithCoreShape(
[*node.inputs, *core_shapes],
node.outputs,
destroy_map=op.destroy_map,
)(*node.inputs, *core_shapes, return_list=True)


optdb.register(
introduce_explicit_core_shape_blockwise.__name__,
out2in(introduce_explicit_core_shape_blockwise),
"numba",
position=100,
)
2 changes: 1 addition & 1 deletion tests/link/numba/test_basic.py
@@ -242,7 +242,7 @@ def compare_numba_and_py(
Parameters
----------
fgraph
-`FunctionGraph` or inputs to compare.
+`FunctionGraph` or tuple(inputs, outputs) to compare.
inputs
Numeric inputs to be passed to the compiled graphs.
assert_fn
72 changes: 72 additions & 0 deletions tests/link/numba/test_blockwise.py
@@ -0,0 +1,72 @@
import numpy as np
import pytest

from pytensor import function
from pytensor.compile.builders import OpFromGraph
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
from pytensor.tensor import tensor
from pytensor.tensor.basic import ARange
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky


# Fails if object mode warning is issued when not expected
pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
def test_blockwise(core_op, shape_opt):
x = tensor(shape=(5, None, None))
outs = Blockwise(core_op=core_op)(x, return_list=True)

mode = (
numba_mode.including("ShapeOpt")
if shape_opt
else numba_mode.excluding("ShapeOpt")
)
x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
compare_numba_and_py(
([x], outs),
[x_test],
numba_mode=mode,
eval_obj_mode=False,
)


def test_non_square_blockwise():
"""Test that Op that cannot always be blockwised at runtime fails gracefully."""
x = tensor(shape=(3,), dtype="int64")
out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)

with pytest.warns(UserWarning, match="Numba will use object mode"):
fn = function([x], out, mode="NUMBA")

np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))

with pytest.raises(ValueError):
fn([3, 4, 5])


def test_too_many_outputs_blockwise():
"""Current implementation of Blockwise does not support more than 3 outputs."""
x = tensor("x", shape=())
core_op = OpFromGraph([x], [x + i for i in range(4)])

xs = tensor("x", shape=(3,))
outs = Blockwise(core_op=core_op, signature="()->(),(),(),()")(xs)

with pytest.warns(UserWarning, match="Numba will use object mode"):
compare_numba_and_py(([xs], outs), [np.arange(3)])


def test_blockwise_benchmark(benchmark):
x = tensor(shape=(5, 3, 3))
out = cholesky(x)
assert isinstance(out.owner.op, Blockwise)

fn = function([x], out, mode="NUMBA")
x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
fn(x_test) # JIT compile
benchmark(fn, x_test)
