Implement Blockwise Op to vectorize existing Ops
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Purna Chandra Mansingh <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>
5 people committed Sep 5, 2023
1 parent f49b2cc commit 0ff0f29
Showing 10 changed files with 966 additions and 46 deletions.
413 changes: 413 additions & 0 deletions pytensor/tensor/blockwise.py

Large diffs are not rendered by default.
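
Since the main `blockwise.py` diff is collapsed, here is a short usage sketch of the new Op, reconstructed from the tests added later in this commit (variable names and shapes are illustrative): `Blockwise` wraps a core Op together with a gufunc-like signature and applies the core Op across any leading batch dimensions.

    import numpy as np
    from pytensor import function
    from pytensor.tensor import tensor3
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.nlinalg import MatrixPinv

    # The core Op maps (m, n) -> (n, m); Blockwise loops it over the
    # leading dimension of a 3d input.
    x = tensor3("x")
    out = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")(x)
    fn = function([x], out)
    res = fn(np.random.normal(size=(2, 3, 4)).astype(x.type.dtype))
    assert res.shape == (2, 4, 3)  # pinv of each (3, 4) matrix in the stack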

73 changes: 39 additions & 34 deletions pytensor/tensor/elemwise.py
@@ -22,13 +22,15 @@
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
from pytensor.tensor.type import (
    TensorType,
    continuous_dtypes,
    discrete_dtypes,
    float_dtypes,
    lvector,
)
from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq

@@ -232,7 +234,7 @@ def __str__(self):
            return f"Transpose{{axes={self.shuffle}}}"
        return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

    def perform(self, node, inp, out, params):
    def perform(self, node, inp, out, params=None):
        (res,) = inp
        (storage,) = out

@@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs):
        # of all inputs in parallel... the all() gives us each output
        # broadcastable bit in turn.

        def get_most_specialized_shape(shapes):
            shapes = set(shapes)
            # All shapes are the same
            if len(shapes) == 1:
                return tuple(shapes)[0]

            # Only valid indeterminate case
            if shapes == {None, 1}:
                return None

            shapes.discard(1)
            shapes.discard(None)
            if len(shapes) > 1:
                raise ValueError
            return tuple(shapes)[0]

        # it is multiplied by nout because Elemwise supports multiple outputs
        # (nout of them)
        try:
            out_shapes = [
                [
                    get_most_specialized_shape(shape)
                    broadcast_static_dim_lengths(shape)
                    for shape in zip(*[inp.type.shape for inp in inputs])
                ]
            ] * shadow.nout
@@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
            impl = "c"

        if getattr(self, "nfunc_spec", None) and impl != "c":
            self.nfunc = getattr(np, self.nfunc_spec[0], None)
            if self.nfunc is None:
                # Not inside NumPy. So probably another package like scipy.
                symb = self.nfunc_spec[0].split(".")
                for idx in range(1, len(self.nfunc_spec[0])):
                    try:
                        module = __import__(".".join(symb[:idx]))
                    except ImportError:
                        break
                for sub in symb[1:]:
                    try:
                        module = getattr(module, sub)
                    except AttributeError:
                        module = None
                        break
                self.nfunc = module
            self.nfunc = import_func_from_string(self.nfunc_spec[0])

        if (
            (len(node.inputs) + len(node.outputs)) <= 32
@@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var):
        return get_vector_length(var.owner.inputs[0])

    raise ValueError(f"Length of {var} cannot be determined")


_vectorize_node.register(Elemwise, vectorize_not_needed)


@_vectorize_node.register(DimShuffle)
def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
    if not batched_ndims:
        return node.op.make_node(x)
    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
    new_order = list(range(batched_ndims)) + [
        "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
    ]
    return DimShuffle(input_broadcastable, new_order).make_node(x)


@_vectorize_node.register(CAReduce)
def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
    if not batched_ndims:
        return node.op.make_node(x)
    axes = op.axis
    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
    if axes is None:
        axes = list(range(node.inputs[0].type.ndim))
    else:
        axes = list(axes)
    new_axes = [axis + batched_ndims for axis in axes]
    new_op = op.clone(axis=new_axes)
    return new_op.make_node(x)
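
A sketch of what the `CAReduce` dispatch above achieves, using the `sum` example from its comments (`vectorize_node` is the dispatching entry point exposed by the new `pytensor.tensor.blockwise` module; variable names are illustrative):

    from pytensor.tensor import matrix, tensor4
    from pytensor.tensor.blockwise import vectorize_node

    x = matrix("x")
    node = x.sum(axis=0).owner            # sum(matrix, axis=0)
    batched_x = tensor4("batched_x")      # two new leading batch dimensions
    new_node = vectorize_node(node, batched_x)
    print(new_node.op.axis)               # -> (2,): axis 0 shifted past the batch dims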
29 changes: 27 additions & 2 deletions pytensor/tensor/random/op.py
@@ -5,19 +5,25 @@

import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
    as_tensor_variable,
    concatenate,
    constant,
    get_underlying_scalar_constant_value,
    get_vector_length,
    infer_static_shape,
)
from pytensor.tensor.blockwise import _vectorize_node
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from pytensor.tensor.random.utils import (
    broadcast_params,
    normalize_size_param,
    params_broadcast_shapes,
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
@@ -383,3 +389,22 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):


default_rng = DefaultGeneratorMakerOp()


@_vectorize_node.register(RandomVariable)
def vectorize_random_variable(
    op: RandomVariable, node: Apply, rng, size, dtype, *dist_params
) -> Apply:
    # If size was provided originally and a new size hasn't been provided,
    # we extend it to accommodate the new input batch dimensions.
    # Otherwise, we assume the new size already has the right values.
    old_size = node.inputs[1]
    len_old_size = get_vector_length(old_size)
    if len_old_size and equal_computations([old_size], [size]):
        bcasted_param = broadcast_params(dist_params, op.ndims_params)[0]
        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
        if new_param_ndim >= 0:
            new_size_dims = bcasted_param.shape[:new_param_ndim]
            size = concatenate([new_size_dims, size])

    return op.make_node(rng, size, dtype, *dist_params)
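
To make the size-extension rule concrete, a sketch mirroring `test_vectorize_node` added below in `tests/tensor/random/test_op.py`: when the 1d parameter of `normal(vec, size=(3,))` is swapped for a 2d one and no new size is given, the old size is prepended with the new batch dimension.

    from pytensor.tensor import tensor
    from pytensor.tensor.blockwise import vectorize_node
    from pytensor.tensor.random import normal

    vec = tensor(shape=(None,))
    mat = tensor(shape=(None, None))

    node = normal(vec, size=(3,)).owner
    new_inputs = node.inputs.copy()
    new_inputs[3] = mat  # replace the 1d location parameter with a 2d one
    vect_node = vectorize_node(node, *new_inputs)
    # the old size (3,) gains the new batch dim, so for a (2, 3) value of
    # `mat` the size input of the vectorized node evaluates to (2, 3)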
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
@@ -2,6 +2,7 @@
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops

41 changes: 41 additions & 0 deletions pytensor/tensor/rewriting/blockwise.py
@@ -0,0 +1,41 @@
from pytensor.compile.mode import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise, vectorize_node


@node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node):
    """
    If there is a dispatch implementation that does not require Blockwise, use that instead.
    This means a user created a Blockwise manually when there was no need.

    Note: this rewrite is not registered by default anywhere.
    """
    op = node.op
    inputs = node.inputs
    dummy_core_node = op._create_dummy_core_node(node.inputs)
    vect_node = vectorize_node(dummy_core_node, *inputs)
    if not isinstance(vect_node.op, Blockwise):
        return copy_stack_trace(node.outputs, vect_node.outputs)


@node_rewriter([Blockwise])
def local_useless_unbatched_blockwise(fgraph, node):
    """Remove Blockwise Ops that don't have any batched dims."""
    op = node.op
    inputs = node.inputs

    if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
        return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)


# We register this rewrite late, so that other rewrites need only target Blockwise Ops
optdb.register(
    "local_useless_unbatched_blockwise",
    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
    "fast_run",
    "fast_compile",
    "blockwise",
    position=49,
)
53 changes: 53 additions & 0 deletions pytensor/tensor/utils.py
@@ -1,3 +1,5 @@
from typing import Sequence, Union

import numpy as np

import pytensor
@@ -107,3 +109,54 @@ def as_list(x):
        return list(x)
    except TypeError:
        return [x]


def import_func_from_string(func_string: str):  # -> Optional[Callable]:
    func = getattr(np, func_string, None)
    if func is not None:
        return func

    # Not in NumPy's top-level namespace. So probably another package like scipy.
    module = None
    items = func_string.split(".")
    for idx in range(1, len(items)):
        try:
            module = __import__(".".join(items[:idx]))
        except ImportError:
            break

    if module:
        for sub in items[1:]:
            try:
                module = getattr(module, sub)
            except AttributeError:
                module = None
                break
    return module
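
A usage sketch for the helper above (the second call assumes SciPy is installed; it exercises the dotted-path branch noted in the comment):

    import numpy as np
    from pytensor.tensor.utils import import_func_from_string

    assert import_func_from_string("sum") is np.sum  # found in NumPy's top-level namespace
    # dotted paths are resolved by importing the package and walking attributes
    gammaln = import_func_from_string("scipy.special.gammaln")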


def broadcast_static_dim_lengths(
    dim_lengths: Sequence[Union[int, None]],
) -> Union[int, None]:
    """Apply static broadcast given static dim lengths of inputs (obtained from var.type.shape).

    Raises
    ------
    ValueError
        When static dim lengths are incompatible
    """
    dim_lengths_set = set(dim_lengths)
    # All dim_lengths are the same
    if len(dim_lengths_set) == 1:
        return tuple(dim_lengths_set)[0]

    # Only valid indeterminate case
    if dim_lengths_set == {None, 1}:
        return None

    dim_lengths_set.discard(1)
    dim_lengths_set.discard(None)
    if len(dim_lengths_set) > 1:
        raise ValueError
    return tuple(dim_lengths_set)[0]
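
The behavior of `broadcast_static_dim_lengths` in each of the cases handled above:

    from pytensor.tensor.utils import broadcast_static_dim_lengths

    broadcast_static_dim_lengths([5, 5])        # 5: all lengths equal
    broadcast_static_dim_lengths([None, 1])     # None: the only valid indeterminate case
    broadcast_static_dim_lengths([5, 1, None])  # 5: 1 and None broadcast to the known length
    broadcast_static_dim_lengths([5, 3])        # raises ValueError: incompatible lengths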
36 changes: 36 additions & 0 deletions tests/tensor/random/test_op.py
@@ -5,7 +5,9 @@
from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad
from pytensor.raise_op import Assert
from pytensor.tensor.blockwise import vectorize_node
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import all_dtypes, iscalar, tensor
@@ -202,3 +204,37 @@ def test_RandomVariable_incompatible_size():
        ValueError, match="Size length is incompatible with batched dimensions"
    ):
        rv_op(np.zeros((2, 4, 3)), 1, size=(4,))


def test_vectorize_node():
    vec = tensor(shape=(None,))
    vec.tag.test_value = [0, 0, 0]
    mat = tensor(shape=(None, None))
    mat.tag.test_value = [[0, 0, 0], [1, 1, 1]]

    # Test without size
    node = normal(vec).owner
    new_inputs = node.inputs.copy()
    new_inputs[3] = mat
    vect_node = vectorize_node(node, *new_inputs)
    assert vect_node.op is normal
    assert vect_node.inputs[3] is mat

    # Test with size, new size provided
    node = normal(vec, size=(3,)).owner
    new_inputs = node.inputs.copy()
    new_inputs[1] = (2, 3)
    new_inputs[3] = mat
    vect_node = vectorize_node(node, *new_inputs)
    assert vect_node.op is normal
    assert tuple(vect_node.inputs[1].eval()) == (2, 3)
    assert vect_node.inputs[3] is mat

    # Test with size, new size not provided
    node = normal(vec, size=(3,)).owner
    new_inputs = node.inputs.copy()
    new_inputs[3] = mat
    vect_node = vectorize_node(node, *new_inputs)
    assert vect_node.op is normal
    assert vect_node.inputs[3] is mat
    assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3)
38 changes: 38 additions & 0 deletions tests/tensor/rewriting/test_blockwise.py
@@ -0,0 +1,38 @@
from pytensor import function
from pytensor.graph import FunctionGraph
from pytensor.scalar import log as scalar_log
from pytensor.tensor import matrix, tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_blockwise


def test_useless_blockwise_of_elemwise():
    x = matrix("x")
    out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)
    assert isinstance(out.owner.op, Blockwise)
    assert isinstance(out.owner.op.core_op, Elemwise)

    fg = FunctionGraph([x], [out], clone=False)
    [new_out] = local_useless_blockwise.transform(fg, out.owner)
    assert isinstance(new_out.owner.op, Elemwise)


def test_useless_unbatched_blockwise():
    x = matrix("x")
    blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")
    out = blockwise_op(x)

    assert isinstance(out.owner.op, Blockwise)
    assert isinstance(out.owner.op.core_op, MatrixPinv)

    fn = function([x], out, mode="FAST_COMPILE")
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)

    # Test that it's not removed when there are batched dims
    x = tensor3("x")
    out = blockwise_op(x)
    fn = function([x], out, mode="FAST_COMPILE")
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)