-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Numba implementation of Blockwise #1015
Open
ricardoV94
wants to merge
2
commits into
pymc-devs:main
Choose a base branch
from
ricardoV94:numba_blockwise
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from typing import cast | ||
|
||
from numba.core.extending import overload | ||
from numba.np.unsafe.ndarray import to_fixed_tuple | ||
|
||
from pytensor.link.numba.dispatch.basic import numba_funcify | ||
from pytensor.link.numba.dispatch.vectorize_codegen import ( | ||
_jit_options, | ||
_vectorized, | ||
encode_literals, | ||
store_core_outputs, | ||
) | ||
from pytensor.tensor import TensorVariable, get_vector_length | ||
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape | ||
|
||
|
||
@numba_funcify.register | ||
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): | ||
[blockwise_node] = op.fgraph.apply_nodes | ||
blockwise_op: Blockwise = blockwise_node.op | ||
core_op = blockwise_op.core_op | ||
nin = len(blockwise_node.inputs) | ||
nout = len(blockwise_node.outputs) | ||
if nout > 3: | ||
raise NotImplementedError( | ||
"Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs." | ||
) | ||
|
||
core_shapes_len = [get_vector_length(sh) for sh in node.inputs[nin:]] | ||
core_shape_0 = core_shapes_len[0] if nout > 0 else None | ||
core_shape_1 = core_shapes_len[1] if nout > 1 else None | ||
core_shape_2 = core_shapes_len[2] if nout > 2 else None | ||
|
||
core_node = blockwise_op._create_dummy_core_node( | ||
cast(tuple[TensorVariable], blockwise_node.inputs) | ||
) | ||
core_op_fn = numba_funcify( | ||
core_op, | ||
node=core_node, | ||
parent_node=node, | ||
fastmath=_jit_options["fastmath"], | ||
**kwargs, | ||
) | ||
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) | ||
|
||
batch_ndim = blockwise_op.batch_ndim(node) | ||
|
||
# numba doesn't support nested literals right now... | ||
input_bc_patterns = encode_literals( | ||
tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs) | ||
) | ||
output_bc_patterns = encode_literals( | ||
tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs) | ||
) | ||
output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs)) | ||
inplace_pattern = encode_literals(()) | ||
|
||
def blockwise_wrapper(*inputs_and_core_shapes): | ||
inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:] | ||
# Appease numba Gods :( | ||
# Secular solution welcomed | ||
if nout == 1: | ||
tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),) | ||
elif nout == 2: | ||
tuple_core_shapes = ( | ||
to_fixed_tuple(core_shapes[0], core_shape_0), | ||
to_fixed_tuple(core_shapes[1], core_shape_1), | ||
) | ||
else: | ||
tuple_core_shapes = ( | ||
to_fixed_tuple(core_shapes[0], core_shape_0), | ||
to_fixed_tuple(core_shapes[1], core_shape_1), | ||
to_fixed_tuple(core_shapes[2], core_shape_2), | ||
) | ||
return _vectorized( | ||
core_op_fn, | ||
input_bc_patterns, | ||
output_bc_patterns, | ||
output_dtypes, | ||
inplace_pattern, | ||
(), # constant_inputs | ||
inputs, | ||
tuple_core_shapes, | ||
None, # size | ||
) | ||
|
||
def blockwise(*inputs_and_core_shapes): | ||
raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented") | ||
|
||
@overload(blockwise, jit_options=_jit_options) | ||
def ov_blockwise(*inputs_and_core_shapes): | ||
return blockwise_wrapper | ||
|
||
return blockwise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from pytensor.compile import optdb | ||
from pytensor.graph import node_rewriter | ||
from pytensor.graph.basic import applys_between | ||
from pytensor.graph.rewriting.basic import out2in | ||
from pytensor.tensor.basic import as_tensor, constant | ||
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape | ||
from pytensor.tensor.rewriting.shape import ShapeFeature | ||
|
||
|
||
@node_rewriter([Blockwise]) | ||
def introduce_explicit_core_shape_blockwise(fgraph, node): | ||
"""Introduce the core shape of a Blockwise. | ||
|
||
We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph | ||
that has an extra "non-functional" input that represents the core shape of the Blockwise variable. | ||
This core_shape is used by the numba backend to pre-allocate the output array. | ||
|
||
If available, the core shape is extracted from the shape feature of the graph, | ||
which has a higher change of having been simplified, optimized, constant-folded. | ||
If missing, we fall back to the op._supp_shape_from_params method. | ||
|
||
This rewrite is required for the numba backend implementation of Blockwise. | ||
|
||
Example | ||
------- | ||
|
||
.. code-block:: python | ||
|
||
import pytensor | ||
import pytensor.tensor as pt | ||
|
||
x = pt.tensor("x", shape=(5, None, None)) | ||
outs = pt.linalg.svd(x, compute_uv=True) | ||
pytensor.dprint(outs) | ||
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A] | ||
# └─ x [id B] | ||
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A] | ||
# └─ ··· | ||
# Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A] | ||
# └─ ··· | ||
|
||
# After the rewrite, note the new 3 core shape inputs | ||
fn = pytensor.function([x], outs, mode="NUMBA") | ||
fn.dprint(print_type=False) | ||
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6 | ||
# ├─ x [id B] | ||
# ├─ MakeVector{dtype='int64'} [id C] 5 | ||
# │ ├─ Shape_i{1} [id D] 2 | ||
# │ │ └─ x [id B] | ||
# │ └─ Shape_i{1} [id D] 2 | ||
# │ └─ ··· | ||
# ├─ MakeVector{dtype='int64'} [id E] 4 | ||
# │ └─ Minimum [id F] 3 | ||
# │ ├─ Shape_i{1} [id D] 2 | ||
# │ │ └─ ··· | ||
# │ └─ Shape_i{2} [id G] 0 | ||
# │ └─ x [id B] | ||
# └─ MakeVector{dtype='int64'} [id H] 1 | ||
# ├─ Shape_i{2} [id G] 0 | ||
# │ └─ ··· | ||
# └─ Shape_i{2} [id G] 0 | ||
# └─ ··· | ||
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6 | ||
# └─ ··· | ||
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6 | ||
# └─ ··· | ||
""" | ||
if len(node.outputs) > 3: | ||
# Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs. | ||
return None | ||
|
||
op: Blockwise = node.op # type: ignore[annotation-unchecked] | ||
batch_ndim = op.batch_ndim(node) | ||
|
||
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked] | ||
if shape_feature: | ||
core_shapes = [ | ||
[shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] | ||
for out in node.outputs | ||
] | ||
else: | ||
input_shapes = [tuple(inp.shape) for inp in node.inputs] | ||
core_shapes = [ | ||
out_shape[batch_ndim:] | ||
for out_shape in op.infer_shape(None, node, input_shapes) | ||
] | ||
|
||
core_shapes = [ | ||
as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64") | ||
for core_shape in core_shapes | ||
] | ||
|
||
if any( | ||
isinstance(node.op, Blockwise) | ||
for node in applys_between(node.inputs, core_shapes) | ||
): | ||
# If Blockwise shows up in the shape graph we can't introduce the core shape | ||
return None | ||
|
||
return BlockwiseWithCoreShape( | ||
[*node.inputs, *core_shapes], | ||
node.outputs, | ||
destroy_map=op.destroy_map, | ||
)(*node.inputs, *core_shapes, return_list=True) | ||
|
||
|
||
optdb.register( | ||
introduce_explicit_core_shape_blockwise.__name__, | ||
out2in(introduce_explicit_core_shape_blockwise), | ||
"numba", | ||
position=100, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pytensor import function | ||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.tensor import tensor | ||
from pytensor.tensor.basic import ARange | ||
from pytensor.tensor.blockwise import Blockwise | ||
from pytensor.tensor.nlinalg import SVD, Det | ||
from pytensor.tensor.slinalg import Cholesky, cholesky | ||
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode | ||
|
||
|
||
# Fails if object mode warning is issued when not expected | ||
pytestmark = pytest.mark.filterwarnings("error") | ||
|
||
|
||
@pytest.mark.parametrize("shape_opt", [True, False], ids=str) | ||
@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str) | ||
def test_blockwise(core_op, shape_opt): | ||
x = tensor(shape=(5, None, None)) | ||
outs = Blockwise(core_op=core_op)(x, return_list=True) | ||
|
||
mode = ( | ||
numba_mode.including("ShapeOpt") | ||
if shape_opt | ||
else numba_mode.excluding("ShapeOpt") | ||
) | ||
x_test = np.eye(3) * np.arange(1, 6)[:, None, None] | ||
compare_numba_and_py( | ||
([x], outs), | ||
[x_test], | ||
numba_mode=mode, | ||
eval_obj_mode=False, | ||
) | ||
|
||
|
||
def test_non_square_blockwise(): | ||
"""Test that Op that cannot always be blockwised at runtime fails gracefully.""" | ||
x = tensor(shape=(3,), dtype="int64") | ||
out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1) | ||
|
||
with pytest.warns(UserWarning, match="Numba will use object mode"): | ||
fn = function([x], out, mode="NUMBA") | ||
|
||
np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5))) | ||
|
||
with pytest.raises(ValueError): | ||
fn([3, 4, 5]) | ||
|
||
|
||
def test_too_many_outputs_blockwise(): | ||
"""Current implementation of Blockwise does not support more than 3 outputs.""" | ||
x = tensor("x", shape=()) | ||
core_op = OpFromGraph([x], [x + i for i in range(4)]) | ||
|
||
xs = tensor("x", shape=(3,)) | ||
outs = Blockwise(core_op=core_op, signature="()->(),(),(),()")(xs) | ||
|
||
with pytest.warns(UserWarning, match="Numba will use object mode"): | ||
compare_numba_and_py(([xs], outs), [np.arange(3)]) | ||
|
||
|
||
def test_blockwise_benchmark(benchmark): | ||
x = tensor(shape=(5, 3, 3)) | ||
out = cholesky(x) | ||
assert isinstance(out.owner.op, Blockwise) | ||
|
||
fn = function([x], out, mode="NUMBA") | ||
x_test = np.eye(3) * np.arange(1, 6)[:, None, None] | ||
fn(x_test) # JIT compile | ||
benchmark(fn, x_test) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If anybody has an idea on how to do this dynamically would be great. Do we have to do string generation 😭?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you opposed to cheesing it?
(I don't have full context)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numba doesn't support that in this context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ewww. Maybe you could try a bunch of
eval
statements?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would need to go down the string generation as we do for some other Ops (like Scan). But I didn't want to :)