Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit environments #23960

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tests for common JAX operations within pallas_call."""

from collections.abc import Sequence
import contextlib
import functools
import itertools
import sys
Expand All @@ -30,7 +29,6 @@
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import config
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import state
Expand Down Expand Up @@ -786,11 +784,11 @@ def test_elementwise(self, fn, dtype):
def kernel(x_ref, o_ref):
o_ref[:] = fn(x_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([0.42, 2.4]).astype(dtype)
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)

def test_abs_weak_type(self):
# see https://github.com/jax-ml/jax/issues/23191
Expand Down Expand Up @@ -822,12 +820,12 @@ def test_pow(self, x_dtype, y_dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = lax.pow(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(x_dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))
if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(x_dtype)
y = jnp.array([1, 2, 3, 4]).astype(y_dtype)
np.testing.assert_allclose(kernel(x, y), lax.pow(x, y))

@parameterized.parameters(0, 1, 2, 3, 4, 5, -1, -2, -3)
def test_integer_pow(self, y):
Expand Down Expand Up @@ -855,12 +853,12 @@ def test_nextafter(self, dtype):
def kernel(x_ref, y_ref, o_ref):
o_ref[:] = jnp.nextafter(x_ref[...], y_ref[...])

with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

x = jnp.array([1, 2, 3, 4]).astype(dtype)
y = jnp.array([1, 2, 3, 4]).astype(dtype)
np.testing.assert_allclose(kernel(x, y), jnp.nextafter(x, y))

COMPARISON_OPS = [
jnp.equal,
Expand Down