Skip to content

Commit

Permalink
Merge pull request #12764 from midjourney:gda_sharding_property
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 480628909
  • Loading branch information
jax authors committed Oct 12, 2022
2 parents 9bb2c99 + bd054f8 commit 58d516c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jax/experimental/global_device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple

import jax
from jax import core
from jax._src import api_util
from jax._src.lib import xla_bridge as xb
Expand Down Expand Up @@ -454,9 +455,13 @@ def block_until_ready(self):
for db in self._device_buffers:
db.block_until_ready()
else:
self._sharded_buffer.block_until_ready() # type: ignore
self._sharded_buffer.block_until_ready() # type: ignore
return self

@property
def sharding(self):
return jax.sharding.MeshPspecSharding(self._global_mesh, self.mesh_axes)

@classmethod
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, data_callback: Callable[[Index],
Expand Down
1 change: 1 addition & 0 deletions tests/global_device_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def cb(index):
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
self.assertEqual(gda.local_shards[1].index, expected_index[1])
self.assertIsInstance(gda.sharding, jax.sharding.MeshPspecSharding)
self.assertArraysEqual(gda.local_data(1),
global_input_data[expected_index[1]])
self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
Expand Down

0 comments on commit 58d516c

Please sign in to comment.