Skip to content

Commit

Permalink
Fix mock_gpu_test on OSS build.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570436380
  • Loading branch information
wang12tao authored and jax authors committed Oct 3, 2023
1 parent c3e73c6 commit ee8af09
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tests/mock_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
config.parse_flags_with_absl()


@jtu.run_on_devices('gpu')
class MockGPUTest(jtu.JaxTestCase):

def setUp(self):
Expand All @@ -44,10 +45,7 @@ def testMockWithSharding(self):
return
num_shards = 16
jax.config.update('mock_num_gpus', num_shards)
mesh_shape = (num_shards,)
axis_names = ('x',)
mesh_devices = np.array(jax.devices()).reshape(mesh_shape)
mesh = jax.sharding.Mesh(mesh_devices, axis_names)
mesh = jtu.create_global_mesh((num_shards,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, P('x',)),
Expand Down

0 comments on commit ee8af09

Please sign in to comment.