Skip to content

Commit

Permalink
[Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit enviro…
Browse files Browse the repository at this point in the history
…nments

This PR is similar to #23814.

Background: We run tests on both 32-bit and 64-bit environments. Currently, when the tests encounters 64-bit dtypes on 32-bit environments, it enters into a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary since we also run the same tests on 64-bit environments. This PR makes those test skipped on 32-bit environments.
PiperOrigin-RevId: 683405197
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 8, 2024
1 parent 0854dc2 commit 6a958b9
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for common JAX operations within pallas_call."""

from collections.abc import Sequence
import contextlib
import functools
import itertools
import sys
Expand All @@ -30,7 +29,6 @@
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
Expand Down Expand Up @@ -786,11 +784,11 @@ def test_elementwise(self, fn, dtype):
def kernel(x_ref, o_ref):
o_ref[:] = fn(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/jax-ml/jax/issues/23191
Expand Down Expand Up @@ -822,12 +820,12 @@ def test_pow(self, x_dtype, y_dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = lax.pow(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(x_dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))
if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))

@parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3)
def test_integer_pow(self, y):
Expand Down Expand Up @@ -855,12 +853,12 @@ def test_nextafter(self, dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))

COMPARISON_OPS = [
jnp.equal,
Expand Down

0 comments on commit 6a958b9

Please sign in to comment.