From bd054f81976f95776a88f4f78aa647b2c9d76331 Mon Sep 17 00:00:00 2001 From: Jack Gallagher Date: Tue, 11 Oct 2022 21:45:15 -0700 Subject: [PATCH] add sharding property to GDA should improve forward compatibility with Array --- jax/experimental/global_device_array.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index f4743a922653..0798bbe13f9b 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -456,6 +456,10 @@ def block_until_ready(self): else: self._sharded_buffer.block_until_ready() # type: ignore return self + + @property + def sharding(self): + return jax.sharding.MeshPspecSharding(self._global_mesh, self.mesh_axes) @classmethod def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,