diff --git a/tests/memories_test.py b/tests/memories_test.py index 9353f21ead0f..dde03f6fad67 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1099,6 +1099,11 @@ def f_bwd(res, tx): class ActivationOffloadingTest(jtu.JaxTestCase): + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Memories do not work on CPU and GPU backends yet.") + super().setUp() + def test_remat_jaxpr_offloadable(self): mesh = jtu.create_global_mesh((2,), ("x",)) inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))