diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index dced8882c357..24a19a2b18d9 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -28,6 +28,7 @@ config.parse_flags_with_absl() +@jtu.run_on_devices('gpu') class MockGPUTest(jtu.JaxTestCase): def setUp(self): @@ -44,10 +45,7 @@ def testMockWithSharding(self): return num_shards = 16 jax.config.update('mock_num_gpus', num_shards) - mesh_shape = (num_shards,) - axis_names = ('x',) - mesh_devices = np.array(jax.devices()).reshape(mesh_shape) - mesh = jax.sharding.Mesh(mesh_devices, axis_names) + mesh = jtu.create_global_mesh((num_shards,), ('x',)) @partial( jax.jit, in_shardings=NamedSharding(mesh, P('x',)),