diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 0804cf04af9b..ff1d828f27fc 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -62,6 +62,8 @@ def add_one(x_ref, o_ref): @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(apaszke): Test fails on TPU v6e") dtype = jnp.float32 def func(): # Build the inputs here, to reduce the size of the golden inputs.