Skip to content

Commit

Permalink
Skip activation offload tests on GPU and CPU until they are supported
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604698627
  • Loading branch information
yashk2810 authored and jax authors committed Feb 6, 2024
1 parent 87d1670 commit 299b983
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,11 @@ def f_bwd(res, tx):

class ActivationOffloadingTest(jtu.JaxTestCase):

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
super().setUp()

def test_remat_jaxpr_offloadable(self):
mesh = jtu.create_global_mesh((2,), ("x",))
inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))
Expand Down

0 comments on commit 299b983

Please sign in to comment.