Skip to content

Commit

Permalink
Disable type casting for pow() int arguments (inducer#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Sep 18, 2024
1 parent 3731ce5 commit 5e7d368
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 10 deletions.
16 changes: 11 additions & 5 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def _binary_op(
np.dtype[Any]] = _np_result_dtype,
reverse: bool = False,
cast_to_result_dtype: bool = True,
is_pow: bool = False,
) -> Array:

# {{{ sanity checks
Expand All @@ -601,14 +602,16 @@ def _binary_op(
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)
else:
result = utils.broadcast_binary_op(
self, other, op,
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)

assert isinstance(result, Array)
return result
Expand Down Expand Up @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array:
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
get_result_type=_truediv_result_type, reverse=True)

__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
__pow__ = partialmethod(_binary_op, operator.pow, is_pow=True)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True)

__neg__ = partialmethod(_unary_op, operator.neg)

Expand Down Expand Up @@ -2403,7 +2406,8 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool:
lambda x, y: np.dtype(np.bool_),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(stacklevel=2),
cast_to_result_dtype=False
cast_to_result_dtype=False,
is_pow=False,
) # type: ignore[return-value]


Expand Down Expand Up @@ -2467,6 +2471,7 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
cast_to_result_dtype=False,
is_pow=False,
) # type: ignore[return-value]


Expand All @@ -2484,6 +2489,7 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
cast_to_result_dtype=False,
is_pow=False,
) # type: ignore[return-value]


Expand Down
15 changes: 13 additions & 2 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
tags: frozenset[Tag],
non_equality_tags: frozenset[Tag],
cast_to_result_dtype: bool,
is_pow: bool,
) -> ArrayOrScalar:
from pytato.array import _get_default_axes

Expand Down Expand Up @@ -225,9 +226,19 @@ def cast_to_result_type(
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_

expr = TypeCast(result_dtype, expr)
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not (is_pow
and np.issubdtype(array.dtype, np.integer)
and not np.issubdtype(result_dtype, np.integer)):
expr = TypeCast(result_dtype, expr)
elif isinstance(expr, SCALAR_CLASSES):
expr = result_dtype.type(expr)
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not (is_pow
and np.issubdtype(type(expr), np.integer)
and not np.issubdtype(result_dtype, np.integer)):
expr = result_dtype.type(expr)

return expr

Expand Down
62 changes: 59 additions & 3 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
"logical_and"))
@pytest.mark.parametrize("reverse", (False, True))
def test_array_array_binary_arith(ctx_factory, which, reverse):
if which == "sub":
pytest.skip("https://github.com/inducer/loopy/issues/131")

cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal",
Expand Down Expand Up @@ -2008,6 +2005,65 @@ def call_bar(tracer, x, y):
np.testing.assert_allclose(result_out[k], expect_out[k])


def test_pow_arg_casting(ctx_factory):
# Check that pow() arguments are not typecast from int
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

types = (int, np.int32, np.int64, float, np.float32, np.float64)

for base_scalar in (True, False):
for exponent_scalar in (True, False):
if base_scalar and exponent_scalar:
# Not supported in pytato
continue

for base_tp in types:
if base_scalar:
base_np = base_tp(2)
base = base_np
else:
base_np = np.array([1, 2, 3], base_tp)
base = pt.make_data_wrapper(base_np)

for exponent_tp in types:
if exponent_scalar:
exponent_np = exponent_tp(2)
exponent = exponent_np
else:
exponent_np = np.array([1, 2, 3], exponent_tp)
exponent = pt.make_data_wrapper(exponent_np)

out = base ** exponent

_, (pt_out,) = pt.generate_loopy(out)(cq)

np_out = np.power(base_np, exponent_np)

assert pt_out.dtype == np_out.dtype
np.testing.assert_allclose(np_out, pt_out)

if np.issubdtype(exponent_tp, np.integer):
assert exponent_tp in (int, np.int32, np.int64)

if exponent_scalar:
# We do cast between different int types
assert (type(out.expr.exponent) in
(int, np.int32, np.int64)
or out.expr.exponent.dtype == np_out.dtype)
else:
assert out.bindings["_in1"].dtype in \
(int, np.int32, np.int64)
else:
assert exponent_tp in (float, np.float32, np.float64)
if exponent_scalar:
assert type(out.expr.exponent) == np_out.dtype \
or out.expr.exponent.dtype == np_out.dtype
else:
assert out.bindings["_in1"].dtype in \
(float, np.float32, np.float64)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 5e7d368

Please sign in to comment.