Forbid runtime broadcasting by Alloc #390

Merged

Changes from all commits
3 changes: 2 additions & 1 deletion pytensor/link/jax/dispatch/tensor_basic.py
@@ -51,9 +51,10 @@ def allocempty(*shape):


@jax_funcify.register(Alloc)
-def jax_funcify_Alloc(op, **kwargs):
+def jax_funcify_Alloc(op, node, **kwargs):
    def alloc(x, *shape):
        res = jnp.broadcast_to(x, shape)
+        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
        return res

    return alloc
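(Not part of the diff.) For context, a minimal usage sketch of the behavior this dispatch change enforces, mirroring the new TestAlloc.check_runtime_broadcast test added below; the variable names are illustrative, and the JAX mode is obtained the same way the tests do:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile import get_mode

x = pt.vector("x")           # static shape (None,): length unknown at compile time
out = pt.alloc(x, 5, 3)      # request a (5, 3) output filled from x
f = pytensor.function([x], out, mode=get_mode("JAX"))

f(np.zeros(3, dtype=pytensor.config.floatX))  # OK: trailing dims match (3 == 3)
f(np.zeros(1, dtype=pytensor.config.floatX))  # now raises ValueError: "Runtime broadcasting not allowed. ..."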
10 changes: 9 additions & 1 deletion pytensor/link/numba/dispatch/tensor_basic.py
@@ -78,16 +78,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
" " * 4,
)

check_runtime_broadcast = []
for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
if val_static_dim is None:
check_runtime_broadcast.append(
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
)
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)

alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""

alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})

return numba_basic.numba_njit(alloc_fn)
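(Not part of the diff.) To make the string templating above concrete, here is roughly the source the dispatcher now generates for alloc(x, 5, 3) when x is a vector whose static length is unknown. The parameter names and the scalar-conversion lines are placeholders (the real ones come from unique_name_generator and shapes_to_items_src); the single-line if is the guard that check_runtime_broadcast_src inserts:

import numpy as np

# Hypothetical rendering of alloc_def_src for alloc(x, 5, 3), x: vector(shape=(None,))
def alloc(val, shape_0, shape_1):
    val_np = np.asarray(val)
    shape_0_item = int(shape_0)  # placeholder for the generated scalar conversion
    shape_1_item = int(shape_1)
    scalar_shape = (shape_0_item, shape_1_item)
    if val.shape[-1] == 1 and scalar_shape[-1] != 1: raise ValueError("Runtime broadcasting not allowed. ...")
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res

alloc(np.zeros(3), 5, 3)  # fine: the unknown-length dim matches the output
alloc(np.zeros(1), 5, 3)  # raises: broadcasting a dim not declared broadcastable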
85 changes: 57 additions & 28 deletions pytensor/tensor/basic.py
@@ -1431,6 +1431,12 @@ class Alloc(COp):

__props__ = ()

_runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

def make_node(self, value, *shape):
value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape)
@@ -1468,10 +1474,21 @@ def make_node(self, value, *shape):
otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()])

@staticmethod
def _check_runtime_broadcast(node, value, shape):
value_static_shape = node.inputs[0].type.shape
for v_static_dim, value_dim, out_dim in zip(
value_static_shape[::-1], value.shape[::-1], shape[::-1]
):
if v_static_dim is None and value_dim == 1 and out_dim != 1:
raise ValueError(Alloc._runtime_broadcast_error_msg)

def perform(self, node, inputs, out_):
(out,) = out_
v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]])
self._check_runtime_broadcast(node, v, sh)

if out[0] is None or out[0].shape != sh:
if v.size == 1 and v.item() == 0:
out[0] = np.zeros(sh, dtype=v.dtype)
@@ -1484,51 +1501,63 @@ def perform(self, node, inputs, out_):

def c_code(self, node, name, inp, out, sub):
        vv = inp[0]
-        ndim = len(inp[1:])
        (zz,) = out
        fail = sub["fail"]

+        v_static_shape = node.inputs[0].type.shape
+        o_static_shape = node.outputs[0].type.shape
+        v_ndim = len(v_static_shape)
+        o_ndim = len(o_static_shape)
+        assert o_ndim == len(inp[1:])
+
+        # Declare variables
        code = f"""
-        npy_intp shape[{ndim}];
+        npy_intp shape[{o_ndim}];
+        int need_new_out;
        """

        # Initialize shape
        for i, shp_i in enumerate(inp[1:]):
-            code += """
-            shape[%(i)s] = ((dtype_%(shp_i)s*) PyArray_DATA(%(shp_i)s))[0];
-            """ % dict(
-                i=i, shp_i=shp_i
-            )
+            code += f"""
+            shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
+            """

+        # Add checks for runtime broadcasting
+        for i, v_static_dim in enumerate(v_static_shape[::-1]):
+            if v_static_dim is None:
+                code += f"""
+        if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)

[Inline review comment on the generated C check above — resolved]

Member: Can we be sure that the arrays are long enough for the indices? Unless this is guaranteed for some reason even for invalid inputs, I think an explicit check would be good.

ricardoV94 (Member, Author) replied on Aug 7, 2023: The docs say this is guaranteed.

https://pytensor.readthedocs.io/en/latest/extending/creating_a_c_op.html#simple-cop-example

    "Also, in the C code, it is very important to properly validate the inputs
    and outputs storage. PyTensor guarantees that the inputs exist and have the
    right number of dimensions but it does not guarantee their exact shape. For
    instance, if an Op computes the sum of two vectors, it needs to validate that
    its two inputs have the same shape. In our case, we do not need to validate
    the exact shapes of the inputs because we don't have a need that they match
    in any way.

    For the outputs, things are a little bit more subtle. PyTensor does not
    guarantee that they have been allocated, but it does guarantee that, if they
    have been allocated, they have the right number of dimensions. Again, PyTensor
    offers no guarantee on the exact shapes. This means that, in our example, we
    need to validate that the output storage has been allocated and has the same
    shape as our vector input. If it is not the case, we allocate a new output
    storage with the right shape and number of dimensions."

This is not true if fn.trust_input=True, and/or if an Op returns an output with the wrong shape, so I don't know what they mean by "guarantee". However, this probably means most Ops are written with this assumption?

This seems to align with the pre-existing check for the output having the right shape (it indexes in a loop without checking whether there are enough dimensions).

+        {{
+            PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+            {fail}
+        }}
+        """

-        code += """
-            int need_new_out = (NULL == %(zz)s);
-            for (int i = 0; i < %(ndim)s; i++)
-                need_new_out = (need_new_out
-                                    || (PyArray_DIMS(%(zz)s)[i] != shape[i]));
+        code += f"""
+        need_new_out = (NULL == {zz});
+        for (int i = 0; i < {o_ndim}; i++)
+            need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));

        if (need_new_out)
-        {
-            Py_XDECREF(%(zz)s);
-            %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
-                shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s));
-            if (!%(zz)s)
-            {
+        {{
+            Py_XDECREF({zz});
+            {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
+            if (!{zz})
+            {{
                PyErr_SetString(PyExc_MemoryError, "alloc failed");
-                %(fail)s
-            }
-        }
+                {fail}
+            }}
+        }}

        // This function takes care of broadcasting
-        if (PyArray_CopyInto(%(zz)s, %(vv)s) == -1)
-            %(fail)s
-        """ % dict(
-            vv=vv, ndim=ndim, zz=zz, fail=fail
-        )
+        if (PyArray_CopyInto({zz}, {vv}) == -1)
+            {fail}
+        """

        return code

def c_code_cache_version(self):
-        return (2,)
+        return (4,)

def infer_shape(self, fgraph, node, input_shapes):
return [node.inputs[1:]]
@@ -1568,7 +1597,7 @@ def grad(self, inputs, grads):
for idx, axis in enumerate(axis_kept):
new_order[axis] = idx
gx = gx.dimshuffle(new_order)
# Dimshuffle to add back the broadcasted dims
# The *elements* of the output are not connected to
# the inputs that specify the shape. If you grow the
# shape by epsilon, the existing elements do not
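(Not part of the diff.) As the new error message suggests, code that genuinely wants to broadcast a length-1 runtime dimension can opt back in by declaring that dimension broadcastable up front (the tests below achieve the same effect with specify_shape). A minimal sketch, assuming specify_broadcastable is exposed under pytensor.tensor as the message implies:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Declaring dim 0 as broadcastable tells Alloc the broadcast is intentional.
out = pt.alloc(pt.specify_broadcastable(x, 0), 5, 3)
f = pytensor.function([x], out)

f(np.zeros(1, dtype=pytensor.config.floatX))  # fills a (5, 3) array, no error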
4 changes: 2 additions & 2 deletions tests/link/jax/test_elemwise.py
@@ -18,8 +18,8 @@
from tests.tensor.test_elemwise import TestElemwise


-def test_elemwise_runtime_shape_error():
-    TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
+def test_elemwise_runtime_broadcast():
+    TestElemwise.check_runtime_broadcast(get_mode("JAX"))


def test_jax_Dimshuffle():
7 changes: 7 additions & 0 deletions tests/link/jax/test_tensor_basic.py
@@ -1,6 +1,8 @@
import numpy as np
import pytest

from pytensor.compile import get_mode


jax = pytest.importorskip("jax")
import jax.errors
@@ -12,6 +14,7 @@
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import TestAlloc


def test_jax_Alloc():
@@ -50,6 +53,10 @@ def compare_shape_dtype(x, y):
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])


def test_alloc_runtime_broadcast():
TestAlloc.check_runtime_broadcast(get_mode("JAX"))


def test_jax_MakeVector():
x = at.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
4 changes: 2 additions & 2 deletions tests/link/numba/test_elemwise.py
@@ -122,8 +122,8 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):


@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
-def test_elemwise_runtime_shape_error():
-    TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
+def test_elemwise_runtime_broadcast():
+    TestElemwise.check_runtime_broadcast(get_mode("NUMBA"))


def test_elemwise_speed(benchmark):
6 changes: 6 additions & 0 deletions tests/link/numba/test_tensor_basic.py
@@ -5,6 +5,7 @@
import pytensor.tensor as at
import pytensor.tensor.basic as atb
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
@@ -15,6 +16,7 @@
compare_shape_dtype,
set_test_value,
)
from tests.tensor.test_basic import TestAlloc


rng = np.random.default_rng(42849)
@@ -45,6 +47,10 @@ def test_Alloc(v, shape):
assert numba_res.shape == shape


def test_alloc_runtime_broadcast():
TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))


def test_AllocEmpty():
x = at.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x])
57 changes: 56 additions & 1 deletion tests/tensor/test_basic.py
@@ -14,7 +14,7 @@
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.gradient import grad, hessian
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray
@@ -720,6 +720,39 @@ class TestAlloc:
shared = staticmethod(pytensor.shared)
allocs = [Alloc()] * 3

@staticmethod
def check_allocs_in_fgraph(fgraph, n):
assert (
len([node for node in fgraph.apply_nodes if isinstance(node.op, Alloc)])
== n
)

@staticmethod
def check_runtime_broadcast(mode):
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
floatX = config.floatX
x_v = vector("x", shape=(None,))

out = alloc(x_v, 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)

np.testing.assert_array_equal(
f(x=np.zeros((3,), dtype=floatX)),
np.zeros((5, 3), dtype=floatX),
)
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(x=np.zeros((1,), dtype=floatX))

out = alloc(specify_shape(x_v, (1,)), 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)

np.testing.assert_array_equal(
f(x=np.zeros((1,), dtype=floatX)),
np.zeros((5, 3), dtype=floatX),
)

def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed())

@@ -851,6 +884,28 @@ def test_static_shape(self):
with pytest.raises(ValueError, match=msg):
at.alloc(x, 3, 1, 6)

def test_alloc_of_view_linker(self):
"""Check we can allocate a new array properly in the C linker when input is a view."""
floatX = config.floatX

x_v = vector("x", shape=(None,))
dim_len = scalar("dim_len", dtype=int)
out = alloc(specify_shape(x_v, (1,)), 5, dim_len)

f = pytensor.function([x_v, dim_len], out, mode=Mode("c"))
assert equal_computations(
f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)]
)

np.testing.assert_array_equal(
f(x=np.zeros((1,), dtype=floatX), dim_len=3),
np.zeros((5, 3), dtype=floatX),
)

@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
def test_runtime_broadcast(self, mode):
self.check_runtime_broadcast(mode)


def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
10 changes: 5 additions & 5 deletions tests/tensor/test_elemwise.py
@@ -751,7 +751,7 @@ def test_input_dimensions_overflow(self):
g(*[np.zeros(2**11, config.floatX) for i in range(6)])

@staticmethod
-    def check_runtime_shapes_error(mode):
+    def check_runtime_broadcast(mode):
        """Check that we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
x_v = matrix("x")
m_v = vector("m")
@@ -777,15 +777,15 @@ def check_runtime_shapes_error(mode):
with pytest.raises((ValueError, TypeError)):
f(x, m)

-    def test_runtime_shapes_error_python(self):
-        self.check_runtime_shapes_error(Mode(linker="py"))
+    def test_runtime_broadcast_python(self):
+        self.check_runtime_broadcast(Mode(linker="py"))

@pytest.mark.skipif(
not pytensor.config.cxx,
reason="G++ not available, so we need to skip this test.",
)
-    def test_runtime_shapes_error_c(self):
-        self.check_runtime_shapes_error(Mode(linker="c"))
+    def test_runtime_broadcast_c(self):
+        self.check_runtime_broadcast(Mode(linker="c"))

def test_str(self):
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)