Skip to content

Commit

Permalink
Fail graciously in local_pow_to_nested_squaring when static type shap…
Browse files Browse the repository at this point in the history
…e is updated
  • Loading branch information
ricardoV94 committed Sep 29, 2023
1 parent 3169197 commit 603d9ae
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 73 deletions.
114 changes: 58 additions & 56 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
"""

if node.op == at_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)

# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
# That happen in the `test_log_erfc` test.
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
# the abs(y) <= 512 fail!
# taking the value outside ndarray solve the problem.
# it could be that in that case, numpy make the comparison
# into the wrong type(do in int8 that overflow.)
if isinstance(y, np.ndarray):
assert y.size == 1
try:
y = y[0]
except IndexError:
pass
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym]
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y)
for i in range(int(np.log2(y_to_do))):
pow2.append(sqr(pow2[i]))
pow2_scal.append(aes.sqr(pow2_scal[i]))
rval1 = None
rval1_scal = None
while y_to_do > 0:
log_to_do = int(np.log2(y_to_do))
if rval1:
rval1 *= pow2[log_to_do]
rval1_scal *= pow2_scal[log_to_do]
else:
rval1 = pow2[log_to_do]
rval1_scal = pow2_scal[log_to_do]
y_to_do -= 2**log_to_do

if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(
aes.Composite([pow2_scal[0]], [rval1_scal])
).make_node(xsym)
if y < 0:
rval = [reciprocal(rval1)]
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)

# the next line is needed to fix a strange case that I don't
# know how to make a separate test.
# That happen in the `test_log_erfc` test.
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
# the abs(y) <= 512 fail!
# taking the value outside ndarray solve the problem.
# it could be that in that case, numpy make the comparison
# into the wrong type(do in int8 that overflow.)
if isinstance(y, np.ndarray):
assert y.size == 1
try:
y = y[0]
except IndexError:
pass
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym]
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y)
for i in range(int(np.log2(y_to_do))):
pow2.append(sqr(pow2[i]))
pow2_scal.append(aes.sqr(pow2_scal[i]))
rval1 = None
rval1_scal = None
while y_to_do > 0:
log_to_do = int(np.log2(y_to_do))
if rval1:
rval1 *= pow2[log_to_do]
rval1_scal *= pow2_scal[log_to_do]
else:
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
return rval
rval1 = pow2[log_to_do]
rval1_scal = pow2_scal[log_to_do]
y_to_do -= 2**log_to_do

if abs(y) > 2:
# We fuse all the pow together here to make
# compilation faster
rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
xsym
)
if y < 0:
rval = [reciprocal(rval1)]
else:
rval = [rval1]
if rval:
rval[0] = cast(rval[0], odtype)
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
return None
return rval


@register_specialize
Expand Down
58 changes: 41 additions & 17 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.scalar import Pow
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, second, switch
from pytensor.tensor.basic import Alloc, as_tensor, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
Expand Down Expand Up @@ -69,7 +70,7 @@
from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min
from pytensor.tensor.math import minimum, mul, neg, neq
from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import (
prod,
rad2deg,
Expand Down Expand Up @@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
utt.assert_allclose(f(val_no0), val_no0 ** (-16))


def test_local_pow_to_nested_squaring_fails_gracefully():
# Reported in #456

x = vector("x", shape=(1,))
# Create an Apply that does not have precise output shape
node = Apply(
op=pt_pow,
inputs=[x, constant([2.0])],
outputs=[tensor(shape=(None,))],
)
y = node.default_output()

fn = function([x], y)

# Check rewrite is not applied (this could change in the future)
assert any(
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
for node in fn.maker.fgraph.apply_nodes
)

np.testing.assert_allclose(fn([2.0]), np.array([4.0]))


class TestFuncInverse:
def setup_method(self):
mode = get_default_mode()
Expand Down Expand Up @@ -2435,21 +2459,21 @@ def test_elemwise(self):
s1 = at.switch(c, a, b)
s2 = at.switch(c, x, y)
for op in (
add,
sub,
mul,
true_div,
int_div,
floor_div,
minimum,
maximum,
gt,
lt,
ge,
le,
eq,
neq,
at_pow,
add,
sub,
mul,
true_div,
int_div,
floor_div,
minimum,
maximum,
gt,
lt,
ge,
le,
eq,
neq,
pt_pow,
):
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert debugprint(g, file="str").count("Switch") == 1
Expand Down

0 comments on commit 603d9ae

Please sign in to comment.