diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py
index 99f9d39061..962b57fb32 100644
--- a/distributed/shuffle/_rechunk.py
+++ b/distributed/shuffle/_rechunk.py
@@ -153,6 +153,8 @@
 ChunkedAxis: TypeAlias = tuple[float, ...]  # chunks must either be an int or NaN
 ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...]
 NDSlice: TypeAlias = tuple[slice, ...]
+SlicedAxis: TypeAlias = tuple[slice, ...]
+SlicedAxes: TypeAlias = tuple[SlicedAxis, ...]
 
 
 def rechunk_transfer(
@@ -187,14 +189,6 @@ class _Partial(NamedTuple):
     old: slice
     #: Slice of the new chunks along this axis that belong to the partial
     new: slice
-    #: Index of the first value of the left-most old chunk along this axis
-    #: to include in this partial. Everything left to this index belongs to
-    #: the previous partial.
-    left_start: int
-    #: Index of the first value of the right-most old chunk along this axis
-    #: to exclude from this partial.
-    #: This corresponds to `left_start` of the subsequent partial.
-    right_stop: int
 
 
 class _NDPartial(NamedTuple):
@@ -204,17 +198,6 @@ class _NDPartial(NamedTuple):
     old: NDSlice
     #: n-dimensional slice of the new chunks along each axis that belong to the partial
     new: NDSlice
-    #: Indices of the first value of the left-most old chunk along each axis
-    #: to include in this partial. Everything left to this index belongs to
-    #: the previous partial.
-    left_starts: NDIndex
-    #: Indices of the first value of the right-most old chunk along each axis
-    #: to exclude from this partial.
-    #: This corresponds to `left_start` of the subsequent partial.
-    right_stops: NDIndex
-    #: Index of the partial among all partials.
-    #: This corresponds to the position of the partial in the n-dimensional grid of
-    #: partials representing the full rechunk.
     ix: NDIndex
 
 
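For reference, the `(old chunk, slice)` recipes that `_Partial` and `_NDPartial` describe come from `old_to_new`, which this diff imports from `dask.array.rechunk`. A minimal sketch of the recipe structure (expected values hand-written as comments, so exact slice reprs may differ by dask version):

```python
from dask.array.rechunk import old_to_new

# One axis, input chunks (2, 2) rechunked to (1, 2, 1). For every new chunk,
# the recipe lists the (old chunk index, slice within that old chunk) pairs
# that the new chunk is assembled from.
recipe = old_to_new(((2, 2),), ((1, 2, 1),))
# recipe[0][0] == [(0, slice(0, 1))]                    # new chunk 0
# recipe[0][1] == [(0, slice(1, 2)), (1, slice(0, 1))]  # new chunk 1
# recipe[0][2] == [(1, slice(1, 2))]                    # new chunk 2
```
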
@@ -222,7 +205,14 @@ def rechunk_name(token: str) -> str:
     return f"rechunk-p2p-{token}"
 
 
-def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
+def rechunk_p2p(
+    x: da.Array,
+    chunks: ChunkedAxes,
+    *,
+    threshold: int | None = None,
+    block_size_limit: int | None = None,
+    balance: bool = False,
+) -> da.Array:
     import dask.array as da
 
     if x.size == 0:
@@ -230,6 +220,19 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
         return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
     from dask.array.core import new_da_object
 
+    prechunked = _calculate_prechunking(x.chunks, chunks)
+    if prechunked != x.chunks:
+        x = cast(
+            "da.Array",
+            x.rechunk(
+                chunks=prechunked,
+                threshold=threshold,
+                block_size_limit=block_size_limit,
+                balance=balance,
+                method="tasks",
+            ),
+        )
+
     token = tokenize(x, chunks)
     name = rechunk_name(token)
     disk: bool = dask.config.get("distributed.p2p.disk")
@@ -396,24 +399,22 @@ def _construct_graph(self) -> _T_LowLevelGraph:
         dsk: _T_LowLevelGraph = {}
 
         _old_to_new = old_to_new(self.chunks_input, self.chunks)
-        chunked_shape = tuple(len(axis) for axis in self.chunks)
 
-        for ndpartial in _split_partials(_old_to_new, chunked_shape):
+        for ndpartial in _split_partials(_old_to_new):
             partial_keepmap = self.keepmap[ndpartial.new]
             output_count = np.sum(partial_keepmap)
             if output_count == 0:
                 continue
             elif output_count == 1:
                 # Single output chunk
-                ndindex = np.argwhere(partial_keepmap)[0]
-                ndpartial = _truncate_partial(ndindex, ndpartial, _old_to_new)
-
                 dsk.update(
                     partial_concatenate(
                         input_name=self.name_input,
                         input_chunks=self.chunks_input,
                         ndpartial=ndpartial,
                         token=self.token,
+                        keepmap=self.keepmap,
+                        old_to_new=_old_to_new,
                     )
                 )
             else:
@@ -431,71 +432,120 @@ def _construct_graph(self) -> _T_LowLevelGraph:
         return dsk
 
 
+def _calculate_prechunking(
+    old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
+) -> ChunkedAxes:
+    from dask.array.rechunk import old_to_new
+
+    _old_to_new = old_to_new(old_chunks, new_chunks)
+
+    partials = _slice_new_chunks_into_partials(_old_to_new)
+
+    split_axes = []
+
+    for axis_index, slices in enumerate(partials):
+        old_to_new_axis = _old_to_new[axis_index]
+        old_axis = old_chunks[axis_index]
+        split_axis = []
+        for slice_ in slices:
+            first_new_chunk = slice_.start
+            first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0]
+            last_new_chunk = slice_.stop - 1
+            last_old_chunk, last_old_slice = old_to_new_axis[last_new_chunk][-1]
+
+            first_chunk_size = old_axis[first_old_chunk]
+            last_chunk_size = old_axis[last_old_chunk]
+
+            if first_old_chunk == last_old_chunk:
+                chunk_size = first_chunk_size
+                if (
+                    last_old_slice.stop is not None
+                    and last_old_slice.stop != last_chunk_size
+                ):
+                    chunk_size = last_old_slice.stop
+                if first_old_slice.start != 0:
+                    chunk_size -= first_old_slice.start
+                split_axis.append(chunk_size)
+                continue
+
+            split_axis.append(first_chunk_size - first_old_slice.start)
+
+            split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk])
+
+            if last_old_slice.stop is not None:
+                chunk_size = last_old_slice.stop
+            else:
+                chunk_size = last_chunk_size
+
+            split_axis.append(chunk_size)
+
+        split_axes.append(split_axis)
+    return tuple(tuple(axis) for axis in split_axes)
+
+
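The parametrized tests at the bottom of this diff double as worked examples for `_calculate_prechunking`. Note that old chunks are only split at partial boundaries, not at every new chunk boundary, and that a no-op result means `rechunk_p2p` skips the extra task-based rechunk entirely. A sketch using two of those test cases:

```python
from distributed.shuffle._rechunk import _calculate_prechunking  # private helper

# (2, 2, 2) -> (1, 2, 2, 1): every output chunk ends up in its own partial,
# and the partial boundaries at offsets 1, 3 and 5 fall inside old chunks,
# so each old chunk of size 2 is pre-split into (1, 1).
assert _calculate_prechunking(((2, 2, 2),), ((1, 2, 2, 1),)) == ((1, 1, 1, 1, 1, 1),)

# Merging is free: the partials already consume whole input chunks,
# so the input passes through unchanged and no pre-rechunk happens.
assert _calculate_prechunking(((2, 2),), ((4,),)) == ((2, 2),)
```
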
 def _split_partials(
     old_to_new: list[Any],
-    chunked_shape: tuple[int, ...],
 ) -> Generator[_NDPartial, None, None]:
     """Split the rechunking into partials that can be performed separately"""
-    partials_per_axis = _split_partials_per_axis(old_to_new, chunked_shape)
+    partials_per_axis = _split_partials_per_axis(old_to_new)
     indices_per_axis = (range(len(partials)) for partials in partials_per_axis)
     for nindex, partial_per_axis in zip(
         product(*indices_per_axis), product(*partials_per_axis)
     ):
-        old, new, left_starts, right_stops = zip(*partial_per_axis)
-        yield _NDPartial(old, new, left_starts, right_stops, nindex)
+        old, new = zip(*partial_per_axis)
+        yield _NDPartial(old, new, nindex)
 
 
-def _split_partials_per_axis(
-    old_to_new: list[Any], chunked_shape: tuple[int, ...]
-) -> tuple[tuple[_Partial, ...], ...]:
+def _split_partials_per_axis(old_to_new: list[Any]) -> tuple[tuple[_Partial, ...], ...]:
     """Split the rechunking into partials that can be performed separately on each axis"""
-    sliced_axes = _partial_slices(old_to_new, chunked_shape)
+    sliced_axes = _slice_new_chunks_into_partials(old_to_new)
 
     partial_axes = []
     for axis_index, slices in enumerate(sliced_axes):
         partials = []
         for slice_ in slices:
             last_old_chunk: int
-            first_old_chunk, first_old_slice = old_to_new[axis_index][slice_.start][0]
-            last_old_chunk, last_old_slice = old_to_new[axis_index][slice_.stop - 1][-1]
+            first_old_chunk, _ = old_to_new[axis_index][slice_.start][0]
+            last_old_chunk, _ = old_to_new[axis_index][slice_.stop - 1][-1]
             partials.append(
                 _Partial(
                     old=slice(first_old_chunk, last_old_chunk + 1),
                     new=slice_,
-                    left_start=first_old_slice.start,
-                    right_stop=last_old_slice.stop,
                 )
             )
         partial_axes.append(tuple(partials))
     return tuple(partial_axes)
 
 
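A hand-evaluated sketch of the splitting logic above, using the module's private helpers: with input chunks (2, 2) and output chunks (1, 1, 2), outputs 0 and 1 read only input chunk 0 while output 2 reads only input chunk 1, so the axis decomposes into two partials that can be culled independently:

```python
from dask.array.rechunk import old_to_new

from distributed.shuffle._rechunk import _split_partials_per_axis

recipe = old_to_new(((2, 2),), ((1, 1, 2),))
(partials,) = _split_partials_per_axis(recipe)
# Expected (hand-evaluated):
#   _Partial(old=slice(0, 1), new=slice(0, 2))  # input chunk 0 -> outputs 0, 1
#   _Partial(old=slice(1, 2), new=slice(2, 3))  # input chunk 1 -> output 2
```
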
-def _partial_slices(
-    old_to_new: list[list[list[tuple[int, slice]]]], chunked_shape: NDIndex
-) -> tuple[tuple[slice, ...], ...]:
-    """Compute the slices of the new chunks that can be computed separately"""
+def _slice_new_chunks_into_partials(
+    old_to_new: list[list[list[tuple[int, slice]]]]
+) -> SlicedAxes:
+    """Slice the new chunks into partials that can be computed separately"""
     sliced_axes = []
 
+    chunk_shape = tuple(len(axis) for axis in old_to_new)
+
     for axis_index, old_to_new_axis in enumerate(old_to_new):
         # Two consecutive output chunks A and B belong to the same partial rechunk
-        # if B is fully included in the right-most input chunk of A, i.e.,
-        # separating A and B would not allow us to cull more input tasks.
+        # if A and B share the same input chunks, i.e., separating A and B would not
+        # allow us to cull more input tasks.
 
-        # Index of the last input chunk of this partial rechunk
-        last_old_chunk: int | None = None
+        # Index of the first input chunk of this partial rechunk
+        first_old_chunk: int | None = None
         partial_splits = [0]
         recipe: list[tuple[int, slice]]
         for new_chunk_index, recipe in enumerate(old_to_new_axis):
             if len(recipe) == 0:
                 continue
-            current_last_old_chunk, old_slice = recipe[-1]
-            if last_old_chunk is None:
-                last_old_chunk = current_last_old_chunk
-            elif last_old_chunk != current_last_old_chunk:
+            current_first_old_chunk, _ = recipe[0]
+            current_last_old_chunk, _ = recipe[-1]
+            if first_old_chunk is None:
+                first_old_chunk = current_first_old_chunk
+            elif first_old_chunk != current_last_old_chunk:
                 partial_splits.append(new_chunk_index)
-                last_old_chunk = current_last_old_chunk
-        partial_splits.append(chunked_shape[axis_index])
+                first_old_chunk = current_first_old_chunk
+        partial_splits.append(chunk_shape[axis_index])
         sliced_axes.append(
             tuple(slice(a, b) for a, b in toolz.sliding_window(2, partial_splits))
         )
@@ -517,6 +567,8 @@ def partial_concatenate(
     input_chunks: ChunkedAxes,
     ndpartial: _NDPartial,
     token: str,
+    keepmap: np.ndarray,
+    old_to_new: list[Any],
 ) -> dict[Key, Any]:
     import numpy as np
 
@@ -527,80 +579,45 @@ def partial_concatenate(
 
     slice_group = f"rechunk-slice-{token}"
 
-    old_offset = tuple(slice_.start for slice_ in ndpartial.old)
+    partial_keepmap = keepmap[ndpartial.new]
+    assert np.sum(partial_keepmap) == 1
+    partial_new_index = np.argwhere(partial_keepmap)[0]
 
-    shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old)
-    rec_cat_arg = np.empty(shape, dtype="O")
+    global_new_index = tuple(
+        int(ix) + slc.start for ix, slc in zip(partial_new_index, ndpartial.new)
+    )
 
-    partial_old = _compute_partial_old_chunks(ndpartial, input_chunks)
+    inputs = tuple(
+        old_to_new_axis[ix] for ix, old_to_new_axis in zip(global_new_index, old_to_new)
+    )
+    shape = tuple(len(axis) for axis in inputs)
+    rec_cat_arg = np.empty(shape, dtype="O")
 
-    for old_partial_index in _partial_ndindex(ndpartial.old):
-        old_global_index = _global_index(old_partial_index, old_offset)
-        # TODO: Precompute slicing to avoid duplicate work
-        ndslice = ndslice_for(
-            old_partial_index, partial_old, ndpartial.left_starts, ndpartial.right_stops
+    for old_partial_index in np.ndindex(shape):
+        old_global_index, old_slice = zip(
+            *(input_axis[index] for index, input_axis in zip(old_partial_index, inputs))
         )
-
         original_shape = tuple(
-            axis[index] for index, axis in zip(old_global_index, input_chunks)
+            old_axis[index] for index, old_axis in zip(old_global_index, input_chunks)
         )
-        if _slicing_is_necessary(ndslice, original_shape):  # type: ignore
+        if _slicing_is_necessary(old_slice, original_shape):
            key = (slice_group,) + ndpartial.ix + old_global_index
            rec_cat_arg[old_partial_index] = key
            dsk[key] = (
                getitem,
                (input_name,) + old_global_index,
-                ndslice,
+                old_slice,
            )
        else:
            rec_cat_arg[old_partial_index] = (input_name,) + old_global_index
-    global_index = tuple(int(slice_.start) for slice_ in ndpartial.new)
-    dsk[(rechunk_name(token),) + global_index] = (
+
+    dsk[(rechunk_name(token),) + global_new_index] = (
         concatenate3,
         rec_cat_arg.tolist(),
     )
     return dsk
 
 
-def _truncate_partial(
-    ndindex: NDIndex,
-    ndpartial: _NDPartial,
-    old_to_new: list[Any],
-) -> _NDPartial:
-    partial_per_axis = []
-    for axis_index, index in enumerate(ndindex):
-        slc = slice(
-            ndpartial.new[axis_index].start + index,
-            ndpartial.new[axis_index].start + index + 1,
-        )
-        first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0]
-        last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1]
-        partial_per_axis.append(
-            _Partial(
-                old=slice(first_old_chunk, last_old_chunk + 1),
-                new=slc,
-                left_start=first_old_slice.start,
-                right_stop=last_old_slice.stop,
-            )
-        )
-
-    old, new, left_starts, right_stops = zip(*partial_per_axis)
-    return _NDPartial(old, new, left_starts, right_stops, ndpartial.ix)
-
-
-def _compute_partial_old_chunks(
-    partial: _NDPartial, chunks: ChunkedAxes
-) -> ChunkedAxes:
-    _partial_old = []
-    for axis_index in range(len(partial.old)):
-        c = list(chunks[axis_index][partial.old[axis_index]])
-        c[0] = c[0] - partial.left_starts[axis_index]
-        if (stop := partial.right_stops[axis_index]) is not None:
-            c[-1] = stop
-        _partial_old.append(tuple(c))
-    return tuple(_partial_old)
-
-
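Because inputs are now pre-chunked along partial boundaries, the removed `_truncate_partial` and `_compute_partial_old_chunks` helpers become unnecessary: a partial's input chunk sizes are a plain sub-tuple of the input chunks, which is exactly what the simplified `partial_rechunk` below computes. A minimal sketch with hypothetical values:

```python
# Hypothetical stand-ins: a partial covering input chunks 1-2 of a single
# axis chunked as (2, 2, 2). No left_start/right_stop trimming is required.
input_chunks = ((2, 2, 2),)
ndpartial_old = (slice(1, 3),)  # plays the role of ndpartial.old

partial_old = tuple(
    chunk_axis[partial_axis]
    for partial_axis, chunk_axis in zip(ndpartial_old, input_chunks)
)
assert partial_old == ((2, 2),)
```
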
 def _slicing_is_necessary(slice: NDSlice, shape: tuple[int | None, ...]) -> bool:
     """Return True if applying the slice alters the shape, False otherwise."""
     return not all(
@@ -618,15 +635,12 @@ def partial_rechunk(
     disk: bool,
     keepmap: np.ndarray,
 ) -> dict[Key, Any]:
-    from dask.array.chunk import getitem
-
     dsk: dict[Key, Any] = {}
 
     old_partial_offset = tuple(slice_.start for slice_ in ndpartial.old)
 
     partial_token = tokenize(token, ndpartial.ix)
     # Use `token` to generate a canonical group for the entire rechunk
-    slice_group = f"rechunk-slice-{token}"
     transfer_group = f"rechunk-transfer-{token}"
     unpack_group = rechunk_name(token)
     # We can use `partial_token` here because the barrier task share their
@@ -636,32 +650,19 @@ def partial_rechunk(
 
     ndim = len(input_chunks)
 
-    partial_old = _compute_partial_old_chunks(ndpartial, input_chunks)
+    partial_old = tuple(
+        chunk_axis[partial_axis]
+        for partial_axis, chunk_axis in zip(ndpartial.old, input_chunks)
+    )
     partial_new: ChunkedAxes = tuple(
         chunks[axis_index][ndpartial.new[axis_index]] for axis_index in range(ndim)
     )
 
     transfer_keys = []
     for partial_index in _partial_ndindex(ndpartial.old):
-        # FIXME: Do not shuffle data for output chunks that we culled
-        ndslice = ndslice_for(
-            partial_index, partial_old, ndpartial.left_starts, ndpartial.right_stops
-        )
-
         global_index = _global_index(partial_index, old_partial_offset)
 
-        original_shape = tuple(
-            axis[index] for index, axis in zip(global_index, input_chunks)
-        )
-        if _slicing_is_necessary(ndslice, original_shape):  # type: ignore
-            input_key = (slice_group,) + ndpartial.ix + global_index
-            dsk[input_key] = (
-                getitem,
-                (input_name,) + global_index,
-                ndslice,
-            )
-        else:
-            input_key = (input_name,) + global_index
+        input_key = (input_name,) + global_index
 
         key = (transfer_group,) + ndpartial.ix + global_index
         transfer_keys.append(key)
@@ -690,26 +691,6 @@ def partial_rechunk(
     return dsk
 
 
-def ndslice_for(
-    partial_index: NDIndex,
-    chunks: ChunkedAxes,
-    left_starts: NDIndex,
-    right_stops: NDIndex,
-) -> NDSlice:
-    slices = []
-    shape = tuple(len(axis) for axis in chunks)
-    for axis_index, chunked_axis in enumerate(chunks):
-        chunk_index = partial_index[axis_index]
-        start = left_starts[axis_index] if chunk_index == 0 else 0
-        stop = (
-            right_stops[axis_index]
-            if chunk_index == shape[axis_index] - 1
-            else chunked_axis[chunk_index] + start
-        )
-        slices.append(slice(start, stop))
-    return tuple(slices)
-
-
 class Split(NamedTuple):
     """Slice of a chunk that is concatenated with other splits to create a new chunk
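After this change, only `partial_concatenate` still emits `rechunk-slice` tasks, guarded by `_slicing_is_necessary`; `partial_rechunk` now transfers whole input chunks. A small usage sketch based on the helper's docstring (return True iff applying the slice would alter the chunk's shape):

```python
from distributed.shuffle._rechunk import _slicing_is_necessary

# A slice covering the full (2, 3) chunk leaves its shape intact,
# so no getitem task is generated for it.
assert not _slicing_is_necessary((slice(0, 2), slice(0, 3)), (2, 3))

# Trimming the first axis changes the shape, so slicing is required.
assert _slicing_is_necessary((slice(1, 2), slice(0, 3)), (2, 3))
```
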
diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py
index 69438b4473..7fa608bef3 100644
--- a/distributed/shuffle/tests/test_rechunk.py
+++ b/distributed/shuffle/tests/test_rechunk.py
@@ -31,6 +31,7 @@
     ArrayRechunkRun,
     ArrayRechunkSpec,
     Split,
+    _calculate_prechunking,
     split_axes,
 )
 from distributed.shuffle.tests.utils import AbstractShuffleTestPool
@@ -188,17 +189,19 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
     --------
     dask.array.tests.test_rechunk.test_rechunk_1d
     """
-    a = np.random.default_rng().uniform(0, 1, 30)
-    x = da.from_array(a, chunks=((10,) * 3,))
-    new = ((6,) * 5,)
+    a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
+    x = da.from_array(a, chunks=(10, 1))
+    new = ((1,) * 10, (10,))
     config = {"array.rechunk.method": config_value} if config_value is not None else {}
     with dask.config.set(config):
         x2 = rechunk(x, chunks=new, method=keyword)
     expected_algorithm = keyword if keyword is not None else config_value
     if expected_algorithm == "p2p":
-        assert all(key[0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
+        assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
     else:
-        assert not any(key[0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
+        assert not any(
+            key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
+        )
 
     assert x2.chunks == new
     assert np.all(await c.compute(x2) == a)
@@ -1315,30 +1318,17 @@ async def test_partial_rechunk_homogeneous_distribution(c, s, *workers):
 async def test_partial_rechunk_taskgroups(c, s):
     """Regression test for https://github.com/dask/distributed/issues/8656"""
     arr = da.random.random(
-        (10, 10, 10),
+        (10, 10),
         chunks=(
-            (
-                2,
-                2,
-                2,
-                2,
-                2,
-            ),
-        )
-        * 3,
+            (1,) * 10,
+            (2,) * 5,
+        ),
     )
     arr = arr.rechunk(
         (
-            (
-                1,
-                2,
-                2,
-                2,
-                2,
-                1,
-            ),
-        )
-        * 3,
+            (2,) * 5,
+            (1,) * 10,
+        ),
         method="p2p",
     )
 
@@ -1350,4 +1340,39 @@
         ),
         timeout=5,
     )
-    assert len(s.task_groups) < 7
+    assert len(s.task_groups) < 6
+
+
+@pytest.mark.parametrize(
+    ["old", "new", "expected"],
+    [
+        [((2, 2),), ((2, 2),), ((2, 2),)],
+        [((2, 2),), ((4,),), ((2, 2),)],
+        [((2, 2),), ((1, 1, 1, 1),), ((2, 2),)],
+        [((2, 2, 2),), ((1, 2, 2, 1),), ((1, 1, 1, 1, 1, 1),)],
+        [((1, np.nan),), ((1, np.nan),), ((1, np.nan),)],
+    ],
+)
+def test_calculate_prechunking_1d(old, new, expected):
+    actual = _calculate_prechunking(old, new)
+    assert actual == expected
+
+
+@pytest.mark.parametrize(
+    ["old", "new", "expected"],
+    [
+        [((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))],
+        [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
+        [((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))],
+        [
+            ((2, 2, 2), (3, 3, 3)),
+            ((1, 2, 2, 1), (2, 3, 4)),
+            ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
+        ],
+        [((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))],
+        [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
+    ],
+)
+def test_calculate_prechunking_2d(old, new, expected):
+    actual = _calculate_prechunking(old, new)
+    assert actual == expected
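To make the expected value of the last 2D case less opaque, here is a hand trace of axis 1 (chunk boundaries written as element offsets):

```python
from distributed.shuffle._rechunk import _calculate_prechunking

# Axis 1: old (3, 3, 3) has boundaries 0|3|6|9, new (2, 3, 4) has 0|2|5|9.
# Each new chunk becomes its own partial, and every old chunk is split at
# the partial boundaries falling inside it:
#   old chunk 0 -> (2, 1), old chunk 1 -> (2, 1), old chunk 2 -> (3,)
assert _calculate_prechunking(
    ((2, 2, 2), (3, 3, 3)), ((1, 2, 2, 1), (2, 3, 4))
) == ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3))
```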