Skip to content

Commit

Permalink
mesh_utils.create_hybrid_device_mesh: make sorting granules by key …
Browse files Browse the repository at this point in the history
…user configurable.

When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.

PiperOrigin-RevId: 590228169
  • Loading branch information
jax authors committed Dec 12, 2023
1 parent b077483 commit 94d58b7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
19 changes: 14 additions & 5 deletions jax/experimental/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,13 @@ def create_device_mesh(
device_mesh = np.asarray(devices).reshape(mesh_shape)
return device_mesh

def create_hybrid_device_mesh(mesh_shape: Sequence[int],
dcn_mesh_shape: Sequence[int],
devices: Optional[Sequence[Any]] = None, *,
process_is_granule: bool = False) -> np.ndarray:
def create_hybrid_device_mesh(
mesh_shape: Sequence[int],
dcn_mesh_shape: Sequence[int],
devices: Optional[Sequence[Any]] = None, *,
process_is_granule: bool = False,
should_sort_granules_by_key: bool = True,
) -> np.ndarray:
"""Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
Args:
Expand All @@ -339,6 +342,9 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
of the slower/outer network. Otherwise it will look for slice_index
attributes on devices and use slices as the units. Enabling this is meant
as a fallback for platforms (e.g., GPU) that don't set slice_index.
should_sort_granules_by_key: Whether device granules should be sorted by the
granule key, either slice or process index, depending on
process_is_granule.
Raises:
ValueError: if the number of slices to which the `devices` belong doesn't
Expand All @@ -356,7 +362,10 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
granule_dict = collections.defaultdict(list)
for dev in devices:
granule_dict[getattr(dev, attr)].append(dev)
granules = [granule_dict[key] for key in sorted(granule_dict.keys())]
granules = (
[granule_dict[key] for key in sorted(granule_dict.keys())]
if should_sort_granules_by_key
else granule_dict.values())
if np.prod(dcn_mesh_shape) != len(granules):
raise ValueError(
f'Number of slices {len(granules)} must equal the product of '
Expand Down
44 changes: 43 additions & 1 deletion tests/mesh_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def mock_tpu_device(core_on_chip, xd, yd, zd, xp, yp, zp, slice_index):
_validate_mocked_process_indices(devices, one_device_per_chip)
return devices


# If this function raises, it's a bug in the test code!
def _validate_mocked_process_indices(devices, one_device_per_chip):
process_to_devices = collections.defaultdict(list)
Expand Down Expand Up @@ -202,7 +203,48 @@ def test_create_hybrid_device_mesh(self, mesh_shape, dcn_mesh_shape):
mesh_shape, dcn_mesh_shape, devices)
total_mesh_shape = tuple(
m1 * m2 for m1, m2 in zip(mesh_shape, dcn_mesh_shape))
assert mesh.shape == total_mesh_shape
self.assertEqual(mesh.shape, total_mesh_shape)

@parameterized.named_parameters(
('2X4x4x4a', (1, 16, 4), (2, 1, 1)),
('2X4x4x4b', (1, 4, 16), (1, 2, 1)),
)
def test_create_hybrid_device_mesh_device_sorting(
self,
mesh_shape: tuple[int, ...],
dcn_mesh_shape: tuple[int, ...],
):
devices = mock_tpu_devices(4, 4, 4, 'TPU v4', True, 2)
reversed_slices_devices = list(
np.flip(np.array(devices).reshape(2, -1), axis=0).flat)
mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
devices,
should_sort_granules_by_key=False,
)
sorted_slices_mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
reversed_slices_devices,
should_sort_granules_by_key=True,
)
np.testing.assert_array_equal(mesh, sorted_slices_mesh)
self.assertSetEqual(
{0, 1},
{d.slice_index for d in sorted_slices_mesh.flat},
)

reversed_slices_mesh = mesh_utils.create_hybrid_device_mesh(
mesh_shape,
dcn_mesh_shape,
reversed_slices_devices,
should_sort_granules_by_key=False,
)
self.assertSetEqual(
{1, 0},
{d.slice_index for d in reversed_slices_mesh.flat},
)

@parameterized.named_parameters(
# Physical ring order over tray
Expand Down

0 comments on commit 94d58b7

Please sign in to comment.