Skip to content

Commit

Permalink
Removed unnecessary skip in pallas_test.py::SoftmaxTest
Browse files Browse the repository at this point in the history
The Triton bug, whatever it was, seems to have been fixed.

PiperOrigin-RevId: 644293465
  • Loading branch information
superbobry authored and jax authors committed Jun 18, 2024
1 parent 3fd9326 commit 5bfd6af
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,16 +2148,11 @@ class RmsNormInterpreterTest(PallasTest):

class SoftmaxTest(PallasTest):

@parameterized.parameters(
(shape, dtype)
for shape in [(1024, 125), (4, 1024, 125)]
for dtype in (jnp.bfloat16, jnp.float16, jnp.float32)
@parameterized.product(
shape=[(1024, 125), (4, 1024, 125)],
dtype=[jnp.bfloat16, jnp.float16, jnp.float32]
)
def test_softmax(self, shape, dtype):
# TODO(bchetioui): add Triton bug reference when filed
if dtype == jnp.bfloat16:
raise absltest.SkipTest("Disabled due to Triton lowering bug")

x = jax.random.normal(random.key(0), shape, dtype=dtype)

atol, rtol = {
Expand All @@ -2166,9 +2161,11 @@ def test_softmax(self, shape, dtype):
jnp.float32: (1e-7, 1e-6),
}[dtype]

# We upcast to float32 because NumPy <2.0 does not handle custom dtypes
# properly. See https://github.com/google/jax/issues/11014.
np.testing.assert_allclose(
softmax.softmax(x, axis=-1),
jax.nn.softmax(x, axis=-1),
softmax.softmax(x, axis=-1).astype(jnp.float32),
jax.nn.softmax(x, axis=-1).astype(jnp.float32),
atol=atol,
rtol=rtol,
)
Expand Down

0 comments on commit 5bfd6af

Please sign in to comment.