diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index e4398e503b82..7c522daf51f5 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1073,14 +1073,32 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): return _bcast_to(_ensure_ir_value(x, x_aval), shape) -def _integer_pow(a, *, y): - if y == 2: - return a * a - if y == 3: - return a * a * a - if y == -2: - return 1.0 / (a * a) - return jax.lax.pow(a, y) +@register_lowering(lax.integer_pow_p) +def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): + if y == 0: + return _full(x.type, 1) + + is_reciprocal = y < 0 + if is_reciprocal: + y = -y + + acc = None + while y > 0: + y, mod = divmod(y, 2) + if mod: + acc = x if acc is None else _mul(acc, x) + if y > 0: + x = _mul(x, x) + assert acc is not None + + [x_aval] = ctx.avals_in + [out_aval] = ctx.avals_out + acc = _cast(acc, x_aval.dtype, out_aval.dtype) + if is_reciprocal: + signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) + return _truediv(_full(acc.type, 1), acc, signed=signed) + else: + return acc def lower_fun( @@ -1100,7 +1118,6 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.integer_pow_p: _integer_pow, lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), } diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 45f933a18b6b..b672879a0b1d 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1119,6 +1119,17 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.array([1, 2, 3, 4]).astype(y_dtype) np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) + @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) + def test_integer_pow(self, y): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[:] = lax.integer_pow(x_ref[...], y) + + x = jnp.array([1, 2, 3, 4]).astype(jnp.float32) / 10 + np.testing.assert_allclose(kernel(x), lax.integer_pow(x, y)) + @parameterized.parameters("float32", "float64") def test_nextafter(self, dtype): if jtu.test_device_matches(["tpu"]) and dtype == "float64":