diff --git a/pytato/array.py b/pytato/array.py index 6f9221ae8..619cca77d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -582,6 +582,7 @@ def _binary_op( np.dtype[Any]] = _np_result_dtype, reverse: bool = False, cast_to_result_dtype: bool = True, + is_pow: bool = False, ) -> Array: # {{{ sanity checks @@ -601,14 +602,16 @@ def _binary_op( get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) else: result = utils.broadcast_binary_op( self, other, op, get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) assert isinstance(result, Array) return result @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array: __rtruediv__ = partialmethod(_binary_op, operator.truediv, get_result_type=_truediv_result_type, reverse=True) - __pow__ = partialmethod(_binary_op, operator.pow) - __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True) + __pow__ = partialmethod(_binary_op, operator.pow, is_pow=True) + __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True) __neg__ = partialmethod(_unary_op, operator.neg) @@ -2403,7 +2406,8 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(stacklevel=2), - cast_to_result_dtype=False + cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] @@ -2467,6 +2471,7 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] @@ -2484,6 +2489,7 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] diff --git a/pytato/utils.py b/pytato/utils.py index f4261685c..2e0a7d3d2 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], cast_to_result_dtype: bool, + is_pow: bool, ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -225,9 +226,19 @@ def cast_to_result_type( # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ - expr = TypeCast(result_dtype, expr) + # See https://github.com/inducer/pytato/issues/542 + # on why pow() + integers is not typecast to float or complex. + if not (is_pow + and np.issubdtype(array.dtype, np.integer) + and not np.issubdtype(result_dtype, np.integer)): + expr = TypeCast(result_dtype, expr) elif isinstance(expr, SCALAR_CLASSES): - expr = result_dtype.type(expr) + # See https://github.com/inducer/pytato/issues/542 + # on why pow() + integers is not typecast to float or complex. + if not (is_pow + and np.issubdtype(type(expr), np.integer) + and not np.issubdtype(result_dtype, np.integer)): + expr = result_dtype.type(expr) return expr diff --git a/test/test_codegen.py b/test/test_codegen.py index 5380f0676..9aeb60dd8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): "logical_and")) @pytest.mark.parametrize("reverse", (False, True)) def test_array_array_binary_arith(ctx_factory, which, reverse): - if which == "sub": - pytest.skip("https://github.com/inducer/loopy/issues/131") - cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -2008,6 +2005,65 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def test_pow_arg_casting(ctx_factory): + # Check that pow() arguments are not typecast from int + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + types = (int, np.int32, np.int64, float, np.float32, np.float64) + + for base_scalar in (True, False): + for exponent_scalar in (True, False): + if base_scalar and exponent_scalar: + # Not supported in pytato + continue + + for base_tp in types: + if base_scalar: + base_np = base_tp(2) + base = base_np + else: + base_np = np.array([1, 2, 3], base_tp) + base = pt.make_data_wrapper(base_np) + + for exponent_tp in types: + if exponent_scalar: + exponent_np = exponent_tp(2) + exponent = exponent_np + else: + exponent_np = np.array([1, 2, 3], exponent_tp) + exponent = pt.make_data_wrapper(exponent_np) + + out = base ** exponent + + _, (pt_out,) = pt.generate_loopy(out)(cq) + + np_out = np.power(base_np, exponent_np) + + assert pt_out.dtype == np_out.dtype + np.testing.assert_allclose(np_out, pt_out) + + if np.issubdtype(exponent_tp, np.integer): + assert exponent_tp in (int, np.int32, np.int64) + + if exponent_scalar: + # We do cast between different int types + assert (type(out.expr.exponent) in + (int, np.int32, np.int64) + or out.expr.exponent.dtype == np_out.dtype) + else: + assert out.bindings["_in1"].dtype in \ + (int, np.int32, np.int64) + else: + assert exponent_tp in (float, np.float32, np.float64) + if exponent_scalar: + assert type(out.expr.exponent) == np_out.dtype \ + or out.expr.exponent.dtype == np_out.dtype + else: + assert out.bindings["_in1"].dtype in \ + (float, np.float32, np.float64) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])