diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c2d33c00e3c6..2cdc5013bf15 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -15,7 +15,6 @@ """Tests for common JAX operations within pallas_call.""" from collections.abc import Sequence -import contextlib import functools import itertools import sys @@ -30,7 +29,6 @@ import jax.numpy as jnp from jax import lax from jax import random -from jax._src import config from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state @@ -786,11 +784,14 @@ def test_elementwise(self, fn, dtype): def kernel(x_ref, o_ref): o_ref[:] = fn(x_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([0.42, 2.4]).astype(dtype) - np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + if ( + not jax.config.x64_enabled + and jnp.dtype(dtype).itemsize == 8 + ): + self.skipTest("64-bit types require x64_enabled") + + x = jnp.array([0.42, 2.4]).astype(dtype) + np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): # see https://github.com/jax-ml/jax/issues/23191 @@ -822,12 +823,15 @@ def test_pow(self, x_dtype, y_dtype): def kernel(x_ref, y_ref, o_ref): o_ref[:] = lax.pow(x_ref[...], y_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(x_dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([1, 2, 3, 4]).astype(x_dtype) - y = jnp.array([1, 2, 3, 4]).astype(y_dtype) - np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) + if ( + not jax.config.x64_enabled + and jnp.dtype(x_dtype).itemsize == 8 + ): + self.skipTest("64-bit types require x64_enabled") + + x = jnp.array([1, 2, 3, 4]).astype(x_dtype) + y = jnp.array([1, 2, 3, 4]).astype(y_dtype) + np.testing.assert_allclose(kernel(x, y), lax.pow(x, y)) @parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3) def test_integer_pow(self, y): @@ -855,12 +859,15 @@ def test_nextafter(self, dtype): def kernel(x_ref, y_ref, o_ref): o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...]) - with contextlib.ExitStack() as stack: - if jnp.dtype(dtype).itemsize == 8: - stack.enter_context(config.enable_x64(True)) - x = jnp.array([1, 2, 3, 4]).astype(dtype) - y = jnp.array([1, 2, 3, 4]).astype(dtype) - np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) + if ( + not jax.config.x64_enabled + and jnp.dtype(dtype).itemsize == 8 + ): + self.skipTest("64-bit types require x64_enabled") + + x = jnp.array([1, 2, 3, 4]).astype(dtype) + y = jnp.array([1, 2, 3, 4]).astype(dtype) + np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y)) COMPARISON_OPS = [ jnp.equal,