Skip to content

Commit

Permalink
[Pallas] Simplify OpsExtraTest by skipping 64-bit tests on 32-bit e…
Browse files Browse the repository at this point in the history
…nvironments

This PR is similar to #23814.

Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 679307748
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 8, 2024
1 parent 76d5938 commit c782f81
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for common JAX operations within pallas_call."""

from collections.abc import Sequence
import contextlib
import functools
import itertools
import sys
Expand All @@ -30,7 +29,6 @@
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
Expand Down Expand Up @@ -786,11 +784,14 @@ def test_elementwise(self, fn, dtype):
def kernel(x_ref, o_ref):
o_ref[:] = fn(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
if (
not jax.config.x64_enabled
and jnp.dtype(dtype).itemsize == 8
):
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/jax-ml/jax/issues/23191
Expand Down Expand Up @@ -822,12 +823,15 @@ def test_pow(self, x_dtype, y_dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = lax.pow(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(x_dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))
if (
not jax.config.x64_enabled
and jnp.dtype(x_dtype).itemsize == 8
):
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))

@parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3)
def test_integer_pow(self, y):
Expand Down Expand Up @@ -855,12 +859,15 @@ def test_nextafter(self, dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))
if (
not jax.config.x64_enabled
and jnp.dtype(dtype).itemsize == 8
):
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))

COMPARISON_OPS = [
jnp.equal,
Expand Down

0 comments on commit c782f81

Please sign in to comment.