diff --git a/jax/__init__.py b/jax/__init__.py
index 60f39ff858b6..93bb7bbdc413 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -134,6 +134,7 @@
 from jax._src.array import (
     make_array_from_single_device_arrays as make_array_from_single_device_arrays,
     make_array_from_callback as make_array_from_callback,
+    make_array_from_process_local_data as make_array_from_process_local_data,
 )

 from jax._src.tree_util import (
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 555b2f7ac0a4..7c7a98d528d3 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -15,13 +15,12 @@
 from __future__ import annotations

 from collections import defaultdict
+from collections.abc import Sequence
 import enum
+import functools
 import math
 import operator as op
-import numpy as np
-import functools
-from typing import Any, Callable, cast, TYPE_CHECKING
-from collections.abc import Sequence
+from typing import Any, Callable, TYPE_CHECKING, cast

 from jax._src import abstract_arrays
 from jax._src import api
@@ -35,18 +34,19 @@
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
-from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension as xe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
+from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
+from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension as xe
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
-    SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
-    device_replica_id_map, hashed_index)
-from jax._src.layout import DeviceLocalLayout, Layout, AutoLayout
+    PmapSharding, SingleDeviceSharding, XLACompatibleSharding,
+    device_replica_id_map, hashed_index, num_addressable_indices)  # pyformat: disable
 from jax._src.typing import ArrayLike, DLDeviceType
 from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method
+import numpy as np

 Shape = tuple[int, ...]
@@ -627,9 +627,12 @@ def _value(self) -> np.ndarray:
 setattr(ArrayImpl, "__hash__", None)
 setattr(ArrayImpl, "__array_priority__", 100)

+# TODO(yashkatariya): Remove None from callback input type.
+
 def make_array_from_callback(
     shape: Shape, sharding: Sharding | Layout,
     data_callback: Callable[[Index | None], ArrayLike]) -> ArrayImpl:
+  # pyformat: disable
   """Returns a ``jax.Array`` via data fetched from ``data_callback``.

   ``data_callback`` is used to fetch the data for each addressable shard of the
@@ -667,6 +670,7 @@ def make_array_from_callback(
     >>> arr.addressable_data(0).shape
     (4, 2)
   """
+  # pyformat: enable
   dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
   if isinstance(dll, AutoLayout):
     raise TypeError(
@@ -725,6 +729,114 @@ def make_array_from_callback(
   return ArrayImpl(aval, sharding, arrays, committed=True)

+
+def make_array_from_process_local_data(
+    sharding: Sharding,
+    local_data: np.ndarray,
+    global_shape: tuple[int, ...],
+) -> ArrayImpl:
+  # pyformat: disable
+  """Creates a distributed tensor using the data available in this process.
+
+  This function is a common special case of `make_array_from_callback`. It
+  assumes that the data is available in the process and takes care of the
+  index wrangling.
+
+  Note that if two hosts are replicas, their `local_data` should be identical
+  as well. Each dimension of the shape of `local_data` should either match
+  `global_shape` or the number of indices the devices on this process need to
+  address.
+  For example, if dimension `i` is fully sharded, then this size would be
+  `per_device_shape[i] * jax.local_device_count()`.
+
+  If the shape matches the global shape, each device slice will just look up
+  its slice in `local_data`. Otherwise, the global slice of each device will
+  be mapped into a local slice of the `local_data` array. For example, if the
+  given process only addresses slices (8, 12) and (24, 28), then these slices
+  will be mapped into (0, 4) and (4, 8) of the `local_data`.
+
+  This function can be used to create tensors from dataset feeding pipelines.
+
+  The most common case is when the sharding is fully sharded across the batch
+  dimension and each host just loads its corresponding sub-batch. This function
+  supports more general cases as well, such as multi-host replication, but then
+  you would need to compute the size and the contents of the process-local data
+  correctly to satisfy the replication constraints.
+
+  Examples:
+    >>> from jax.sharding import PartitionSpec as P
+    >>> mesh_rows = 2
+    >>> mesh_cols = jax.device_count() // 2
+    ...
+    >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))

+    >>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
+    >>> rows_per_device = 2
+    >>> feature_length = 32
+    >>> per_device_shape = (rows_per_device, feature_length)
+    >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
+    >>> per_host_generator = lambda: np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
+    >>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
+    >>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
+    >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape)
+    ...
+    >>> assert output_global_array.addressable_data(0).shape == per_device_shape
+    >>> assert output_global_array.shape == global_shape
+
+  Args:
+    sharding: sharding of the global tensor.
+    local_data: data on the host to be placed on the local devices. Each
+      dimension should either match `global_shape`, or match
+      `num_addressable_indices(dim)`.
+    global_shape: the target shape of the global tensor. In some cases this
+      parameter can be inferred from `sharding` and `local_data`; however, it
+      is useful for catching common sharding errors.
+
+  Returns:
+    Tensor that will have sharding=sharding.
+  """
+  # pyformat: enable
+  shard_shape = sharding.shard_shape(global_shape)
+  full_dim = []
+  for i, (data_dim, global_dim) in enumerate(
+      zip(local_data.shape, global_shape)
+  ):
+    full_dim.append(data_dim == global_dim)
+    if data_dim != global_dim:
+      process_slice = num_addressable_indices(sharding, i, global_shape)
+      if process_slice != data_dim:
+        raise ValueError(
+            "Invalid host data, each dimension should match either global or "
+            f"process shape. In dimension {i=}, the process data has {data_dim} "
+            f"elements. Process addresses {process_slice} elements and "
+            f"{global_shape=}."
+        )
+  addressable_shards = sharding.addressable_devices_indices_map(global_shape)
+  slices_for_each_dim: list[list[int]] = [[] for _ in global_shape]
+  for shard_index in addressable_shards.values():
+    assert shard_index is not None
+    for i, slc in enumerate(shard_index):
+      slices_for_each_dim[i].append(slc.start or 0)
+  for i in range(len(global_shape)):
+    slices_for_each_dim[i] = sorted(set(slices_for_each_dim[i]))
+
+  def local_slice(i, slc):
+    # Looks up the index of this slice in the list of slices for this dimension.
+    # This determines the corresponding slice in `local_data`.
+    start = slices_for_each_dim[i].index(slc.start or 0) * shard_shape[i]
+    end = start + shard_shape[i]
+    return slice(start, end)
+
+  def cb(index: Index | None) -> ArrayLike:
+    assert index is not None
+    data_slice = [
+        slc if full_dim[i] else local_slice(i, slc)
+        for i, slc in enumerate(index)
+    ]
+    return local_data[tuple(data_slice)]
+
+  return make_array_from_callback(global_shape, sharding, cb)
+
+
 def make_array_from_single_device_arrays(
     shape: Shape, sharding: Sharding, arrays: Sequence[basearray.Array]
 ) -> ArrayImpl:
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 0cf8d30af968..36a276112b90 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -25,18 +25,17 @@
 from typing import Any, NamedTuple, Union, cast

 from jax._src import mesh as mesh_lib
-from jax._src.op_shardings import (
-    is_op_sharding_replicated, are_op_shardings_equal, get_num_ways_dim_sharded,
-    op_sharding_to_indices)
 from jax._src import sharding
 from jax._src import sharding_specs
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge
-from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded,
+    is_op_sharding_replicated,
+    op_sharding_to_indices)  # pyformat: disable
 from jax._src.partition_spec import PartitionSpec
-
+from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 import numpy as np

@@ -1376,3 +1375,54 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
                                    ParsedPartitionSpec('', partitions))]
   else:
     raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
+
+
+def _slice_as_tuple(s: slice):
+  assert s.step is None
+  return (s.start, s.stop)
+
+
+def num_addressable_indices(
+    tensor_sharding: sharding.Sharding,
+    dim: int,
+    global_shape: Shape,
+) -> int:
+  """Returns the number of indices for the given dimension this host has access to.
+
+  Each host can have multiple devices that span possibly discontiguous
+  slices of data. This function computes the total number of unique indices
+  for dimension `dim` that any of its addressable devices hold.
+
+  In most cases the addressable indices form a sparse grid (and in some
+  cases a subcube), and thus each host will hold the same number of
+  indices for each dimension. However, it is possible to design a mesh whose
+  addressable shards form a complicated pattern. In that case, the returned
+  value is the number of indices that are addressable by at least one device.
+
+  For example, suppose the sharding looks like this (the number indicates
+  the host index):
+
+    1221
+    1221
+    0000
+
+  Then on hosts 1 and 2, both dim 0 (rows) and dim 1 (cols) will have size 2,
+  while on host 0, dim 0 will have size 1 and dim 1 will have size 4.
+
+  Args:
+    tensor_sharding: Sharding of the tensor.
+    dim: dimension along which to compute the number of addressable indices.
+    global_shape: global shape of the tensor.
+
+  Returns:
+    The number of indices for dimension `dim` that this host holds.
+  """
+  # TODO(sandler, yashkatariya): Consider making this function public.
+  addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
+  addressables = cast(Mapping[sharding.Device, Index], addressables)
+  num_unique_slices = len({
+      _slice_as_tuple(addressable[dim]) for addressable in addressables.values()
+  })
+  shard_size = tensor_sharding.shard_shape(global_shape)[dim]
+  return shard_size * num_unique_slices
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index f19fed894c91..ef7279d01641 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -13,16 +13,16 @@
 # limitations under the License.
 from __future__ import annotations

-from collections.abc import Generator, Iterable, Sequence
-from contextlib import contextmanager, ExitStack
+from collections.abc import Generator, Iterable, Mapping, Sequence
+from contextlib import ExitStack, contextmanager
 import datetime
-import inspect
-import io
 import functools
 from functools import partial
+import inspect
+import io
 import math
-import re
 import os
+import re
 import tempfile
 import textwrap
 from typing import Any, Callable
@@ -32,33 +32,31 @@
 from absl.testing import absltest
 from absl.testing import parameterized
-
-import numpy as np
-import numpy.random as npr
-
 import jax
 from jax import lax
-from jax.experimental.compilation_cache import compilation_cache
-from jax._src.interpreters import mlir
-from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
 from jax._src import api
-from jax._src import pjit as pjit_lib
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
-from jax._src import linear_util as lu
 from jax._src import dtypes as _dtypes
+from jax._src import linear_util as lu
 from jax._src import monitoring
+from jax._src import pjit as pjit_lib
 from jax._src import stages
-from jax._src.lib import xla_client as xc
+from jax._src import xla_bridge
 from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
+from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
+from jax._src.lib import xla_client as xc
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
-from jax._src.util import unzip2
 from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
-    check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, tolerance, rand_like)
-from jax._src import xla_bridge
+    check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance)
+from jax._src.util import unzip2
+from jax.experimental.compilation_cache import compilation_cache
+from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten
+import numpy as np
+import numpy.random as npr


 # This submodule includes private test utilities that are not exported to
@@ -1911,3 +1909,101 @@ def worker(ctx, s, e, r, v):
     return worker(ctx, scale, exact, reference, value)
   else:
     assert 0  # unreachable
+
+
+def get_process_index_and_count(
+    tensor_sharding: jax.sharding.Sharding,
+    dim: int,
+    global_shape: tuple[int, ...],
+) -> tuple[int, int]:
+  """Returns current process index and total count for the given dimension.
+
+  This function facilitates mapping of process-level data to individual
+  devices. Each process can use its index to obtain the data corresponding
+  to that index. If the process-level data is sharded on multiple dimensions,
+  this function can be used to build the cross product of indices in
+  each sharded axis. Processes that need to load the same data will have
+  the same index. For shardings whose per-process data is not distributed
+  on a grid, the number of distinct shards will be such that it is possible to
+  build the target shape while maintaining a "cube" shape of local-process data.
+
+  For example, in the case of 4 hosts with the sharding distributed like so:
+
+    1234
+    2143
+
+  For dim 0 (rows): all processes need to access all rows, so we return (0, 1).
+  For dim 1 (cols):
+    processes 1 and 2 return index 0 out of 2 (they need cols 0 and 1),
+    processes 3 and 4 return index 1 out of 2 (they need cols 2 and 3).
+
+  On the other hand, for a sharding like:
+
+    1212
+    3434
+
+  Dim 0 (rows): processes 1 and 2 return (0, 2), processes 3 and 4 return (1, 2).
+  Dim 1 (cols): processes 1 and 3 return (0, 2), processes 2 and 4 return (1, 2).
+
+  Note: This function requires the sharding to be process uniform in dimension
+  `dim`: each process has the same number of addressable indices in that
+  dimension and all index sets across processes are either disjoint or the same.
+
+  For a sharding to be process uniform, the addressable shards don't need to
+  form a contiguous subtensor, or even a sparse grid, and for an interleaved
+  high-dimensional tensor it is possible for the sharding to be process uniform
+  only in some dimensions but not others.
+
+  For example:
+
+    1111  and  12   and  1212   and  1212
+    2222       21        2121        1212
+
+  are all process uniform in both dimensions. However,
+
+    1122
+    2121
+    1121
+    1222
+
+  is uniform in dimension 0 (both hosts access all rows), but is not uniform
+  in dimension 1 (host 1 accesses columns 0, 1, and 3, while host 2 accesses
+  columns 0, 1, 2, and 3).
+
+  Returns:
+    A tuple of (index, num_distinct_shards) for the given dimension.
+    It is guaranteed that `index` will cover 0 to `num_distinct_shards - 1`
+    across all processes.
+
+  Raises:
+    ValueError if the sharding is not process uniform in dimension `dim`.
+  """
+  # TODO(sandler, yashkatariya): Consider making this function public.
+
+  if tensor_sharding.is_fully_addressable or tensor_sharding.is_fully_replicated:
+    return (0, 1)
+  # NB: For most types of shardings, global_shape is a superfluous argument
+  # and could be replaced by [d, d, ...., d, d], where d is the number of
+  # devices.
+  device_map: Mapping[jax.sharding.Device, jax.sharding.Index] = (
+      tensor_sharding.devices_indices_map(global_shape)
+  )
+
+  global_slice = {k: v[dim] for k, v in device_map.items()}
+  process_map: dict[int, set[tuple[int, int]]] = {}
+  all_slices = set()
+
+  current_pid = next(iter(tensor_sharding.addressable_devices)).process_index
+  for d, v in global_slice.items():
+    key = (v.start, v.stop)
+    process_map.setdefault(d.process_index, set()).add(key)
+    all_slices.add(key)
+  addressable = frozenset(process_map[current_pid])
+  slices_per_process = len(addressable)
+  if any(len(x) != slices_per_process for x in process_map.values()):
+    raise ValueError(f'{tensor_sharding=} is non-uniform on {dim=}')
+  unique_processes = list({frozenset(x) for x in process_map.values()})
+
+  # After removing duplicate processes, each slice should appear exactly once.
+  if sum(len(h) for h in unique_processes) != len(all_slices):
+    raise ValueError(f'{tensor_sharding=} is non-uniform on {dim=}')
+  return (unique_processes.index(addressable), len(unique_processes))
diff --git a/tests/array_test.py b/tests/array_test.py
index 7be25411b8d1..23bcbfdedd41 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -812,6 +812,19 @@ def test_make_array_from_callback_global_array(self):
     self.assertArraysEqual(out2, arr2)
     self.assertEqual(out2.sharding, sharding2)

+  def test_make_array_from_process_data_single_host_data_sharding(self):
+    data = np.ones((1, 512))
+    mesh = jtu.create_global_mesh((1, 1), ('x', 'unused'))
+    sharding_spec = jax.sharding.NamedSharding(
+        mesh, jax.sharding.PartitionSpec('x')
+    )
+    global_shape = data.shape
+    result = jax.make_array_from_process_local_data(
+        sharding_spec, data, global_shape
+    )
+    self.assertIsInstance(result, jax.Array)
+    self.assertEqual(result.shape, data.shape)
+    self.assertEqual(result.sharding, sharding_spec)


 class ShardingTest(jtu.JaxTestCase):
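
For reference, below is a minimal single-process sketch of how the new `jax.make_array_from_process_local_data` API is intended to be used from a data-loading pipeline, assuming the batch dimension is sharded across every device of a 1-D mesh; the names `rows_per_device`, `feature_length`, and `local_batch` are illustrative only and not part of the change.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard the batch dimension across all devices in a 1-D mesh.
mesh = Mesh(np.array(jax.devices()), ('batch',))
sharding = NamedSharding(mesh, P('batch'))

rows_per_device = 2
feature_length = 32
global_shape = (rows_per_device * jax.device_count(), feature_length)

# Each process loads only the rows its local devices will hold.
local_rows = rows_per_device * jax.local_device_count()
local_batch = np.arange(local_rows * feature_length, dtype=np.float32).reshape(
    local_rows, feature_length)

global_array = jax.make_array_from_process_local_data(
    sharding, local_batch, global_shape)
assert global_array.shape == global_shape
assert global_array.addressable_data(0).shape == (rows_per_device, feature_length)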
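Similarly, a rough sketch of how the `(index, count)` pair returned by the new internal helper in `jax._src.test_util` could drive per-process loading, under the simplifying assumption that each process owns one contiguous, equally sized block of the batch dimension; the `row_block` arithmetic and the zero-filled stand-in for a dataset reader are illustrative, not part of the change.

import jax
import numpy as np
from jax._src import test_util as jtu  # internal module; helper is not public API
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
global_shape = (8 * jax.device_count(), 16)

# Which of the `count` distinct row blocks does this process need to load?
index, count = jtu.get_process_index_and_count(
    sharding, dim=0, global_shape=global_shape)

# Assumes contiguous blocks per process; a real pipeline would read rows
# [start, stop) from its dataset instead of synthesizing zeros here.
row_block = global_shape[0] // count
start, stop = index * row_block, (index + 1) * row_block
local_rows = np.zeros((stop - start, global_shape[1]), dtype=np.float32)

arr = jax.make_array_from_process_local_data(sharding, local_rows, global_shape)
assert arr.shape == global_shape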