diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index f4743a922653..0798bbe13f9b 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -456,6 +456,10 @@ def block_until_ready(self): else: self._sharded_buffer.block_until_ready() # type: ignore return self + + @property + def sharding(self): + return jax.sharding.MeshPspecSharding(self._global_mesh, self.mesh_axes) @classmethod def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,