Skip to content

Commit

Permalink
Concatenate small input chunks before P2P rechunking (dask#8832)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Hoefler <[email protected]>
  • Loading branch information
hendrikmakait and phofl authored Aug 23, 2024
1 parent c073797 commit ea7d35c
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 22 deletions.
173 changes: 158 additions & 15 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@

from __future__ import annotations

import math
import mmap
import os
from collections import defaultdict
Expand All @@ -111,7 +112,7 @@
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from itertools import product
from itertools import chain, product
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast

Expand All @@ -124,6 +125,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.typing import Key
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.metrics import context_meter
Expand Down Expand Up @@ -220,7 +222,7 @@ def rechunk_p2p(
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
from dask.array.core import new_da_object

prechunked = _calculate_prechunking(x.chunks, chunks)
prechunked = _calculate_prechunking(x.chunks, chunks, x.dtype, block_size_limit)
if prechunked != x.chunks:
x = cast(
"da.Array",
Expand Down Expand Up @@ -433,8 +435,140 @@ def _construct_graph(self) -> _T_LowLevelGraph:


def _calculate_prechunking(
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
old_chunks: ChunkedAxes,
new_chunks: ChunkedAxes,
dtype: np.dtype,
block_size_limit: int | None,
) -> ChunkedAxes:
"""Calculate how to perform the pre-rechunking step
During the pre-rechunking step, we
1. Split input chunks along partial boundaries to make partials completely independent of one another
2. Merge small chunks within partials to reduce the number of transfer tasks and corresponding overhead
"""
split_axes = _split_chunks_along_partial_boundaries(old_chunks, new_chunks)

# We can only determine how to concatenate chunks if we can calculate block sizes.
has_nans = (any(math.isnan(y) for y in x) for x in old_chunks)

if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans):
return tuple(tuple(chain(*axis)) for axis in split_axes)

if dtype is None or dtype.hasobject or dtype.itemsize == 0:
return tuple(tuple(chain(*axis)) for axis in split_axes)

# We made sure that there are no NaNs in split_axes above
return _concatenate_small_chunks(
split_axes, old_chunks, new_chunks, dtype, block_size_limit # type: ignore[arg-type]
)


def _concatenate_small_chunks(
split_axes: list[list[list[int]]],
old_chunks: ChunkedAxes,
new_chunks: ChunkedAxes,
dtype: np.dtype,
block_size_limit: int | None,
) -> ChunkedAxes:
"""Concatenate small chunks within partials.
By concatenating chunks within partials, we reduce the number of P2P transfer tasks and their
corresponding overhead.
The algorithm used in this function is very similar to :func:`dask.array.rechunk.find_merge_rechunk`,
the main difference is that we have to make sure only to merge chunks within partials.
"""
import numpy as np

block_size_limit = block_size_limit or dask.config.get("array.chunk-size")

if isinstance(block_size_limit, str):
block_size_limit = parse_bytes(block_size_limit)

# Make it a number of elements
block_size_limit //= dtype.itemsize

# We verified earlier that we do not have any NaNs
largest_old_block = _largest_block_size(old_chunks) # type: ignore[arg-type]
largest_new_block = _largest_block_size(new_chunks) # type: ignore[arg-type]
block_size_limit = max([block_size_limit, largest_old_block, largest_new_block])

old_largest_width = [max(chain(*axis)) for axis in split_axes]
new_largest_width = [max(c) for c in new_chunks]

# This represents how much each dimension increases (>1) or reduces (<1)
# the graph size during rechunking
graph_size_effect = {
dim: len(new_axis) / sum(map(len, split_axis))
for dim, (split_axis, new_axis) in enumerate(zip(split_axes, new_chunks))
}

ndim = len(old_chunks)

# This represents how much each dimension increases (>1) or reduces (<1) the
# largest block size during rechunking
block_size_effect = {
dim: new_largest_width[dim] / (old_largest_width[dim] or 1)
for dim in range(ndim)
}

# Our goal is to reduce the number of nodes in the rechunk graph
# by concatenating some adjacent chunks, so consider dimensions where we can
# reduce the # of chunks
candidates = [dim for dim in range(ndim) if graph_size_effect[dim] <= 1.0]

# Concatenating along each dimension reduces the graph size by a certain factor
# and increases memory largest block size by a certain factor.
# We want to optimize the graph size while staying below the given
# block_size_limit. This is in effect a knapsack problem, except with
# multiplicative values and weights. Just use a greedy algorithm
# by trying dimensions in decreasing value / weight order.
def key(k: int) -> float:
gse = graph_size_effect[k]
bse = block_size_effect[k]
if bse == 1:
bse = 1 + 1e-9
return (np.log(gse) / np.log(bse)) if bse > 0 else 0

sorted_candidates = sorted(candidates, key=key)

concatenated_axes: list[list[int]] = [[] for i in range(ndim)]

# Sim all the axes that are no candidates
for i in range(ndim):
if i in candidates:
continue
concatenated_axes[i] = list(chain(*split_axes[i]))

# We want to concatenate chunks
for axis_index in sorted_candidates:
concatenated_axis = concatenated_axes[axis_index]
multiplier = math.prod(
old_largest_width[:axis_index] + old_largest_width[axis_index + 1 :]
)
axis_limit = block_size_limit // multiplier

for partial in split_axes[axis_index]:
current = partial[0]
for chunk in partial[1:]:
if (current + chunk) > axis_limit:
concatenated_axis.append(current)
current = chunk
else:
current += chunk
concatenated_axis.append(current)
old_largest_width[axis_index] = max(concatenated_axis)
return tuple(tuple(axis) for axis in concatenated_axes)


def _split_chunks_along_partial_boundaries(
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
) -> list[list[list[float]]]:
"""Split the old chunks along the boundaries of partials, i.e., groups of new chunks that share the same inputs.
By splitting along the boundaries before rechunkin their input tasks become disjunct and each partial conceptually
operates on an independent sub-array.
"""
from dask.array.rechunk import old_to_new

_old_to_new = old_to_new(old_chunks, new_chunks)
Expand All @@ -443,10 +577,13 @@ def _calculate_prechunking(

split_axes = []

# Along each axis, we want to figure out how we have to split input chunks in order to make
# partials disjunct. We then group the resulting input chunks per partial before returning.
for axis_index, slices in enumerate(partials):
old_to_new_axis = _old_to_new[axis_index]
old_axis = old_chunks[axis_index]
split_axis = []
partial_chunks = []
for slice_ in slices:
first_new_chunk = slice_.start
first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0]
Expand All @@ -465,22 +602,28 @@ def _calculate_prechunking(
chunk_size = last_old_slice.stop
if first_old_slice.start != 0:
chunk_size -= first_old_slice.start
split_axis.append(chunk_size)
continue

split_axis.append(first_chunk_size - first_old_slice.start)

split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk])

if last_old_slice.stop is not None:
chunk_size = last_old_slice.stop
partial_chunks.append(chunk_size)
else:
chunk_size = last_chunk_size
partial_chunks.append(first_chunk_size - first_old_slice.start)

split_axis.append(chunk_size)
partial_chunks.extend(old_axis[first_old_chunk + 1 : last_old_chunk])

if last_old_slice.stop is not None:
chunk_size = last_old_slice.stop
else:
chunk_size = last_chunk_size

partial_chunks.append(chunk_size)
split_axis.append(partial_chunks)
partial_chunks = []
if partial_chunks:
split_axis.append(partial_chunks)
split_axes.append(split_axis)
return tuple(tuple(axis) for axis in split_axes)
return split_axes


def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int:
return math.prod(map(max, chunks))


def _split_partials(
Expand Down
97 changes: 90 additions & 7 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,8 @@ async def test_rechunk_avoid_needless_chunking(c, s, *ws):
x = da.ones(16, chunks=2)
y = x.rechunk(8, method="p2p")
dsk = y.__dask_graph__()
assert len(dsk) <= 8 + 2
# 8 inputs, 2 concatenations of small inputs, 2 outputs
assert len(dsk) <= 8 + 2 + 2


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1337,7 +1338,7 @@ async def test_partial_rechunk_taskgroups(c, s):
),
timeout=5,
)
assert len(s.task_groups) < 6
assert len(s.task_groups) < 7


@pytest.mark.parametrize(
Expand All @@ -1351,25 +1352,107 @@ async def test_partial_rechunk_taskgroups(c, s):
],
)
def test_calculate_prechunking_1d(old, new, expected):
actual = _calculate_prechunking(old, new)
actual = _calculate_prechunking(old, new, np.dtype, None)
assert actual == expected


@pytest.mark.parametrize(
["old", "new", "expected"],
[
[((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))],
[((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
[((2, 2), (3, 3)), ((4,), (3, 3)), ((4,), (3, 3))],
[((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))],
[
((2, 2, 2), (3, 3, 3)),
((1, 2, 2, 1), (2, 3, 4)),
((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
((1, 2, 2, 1), (2, 3, 4)),
],
[((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (3,))],
],
)
def test_calculate_prechunking_2d(old, new, expected):
actual = _calculate_prechunking(old, new)
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected


@pytest.mark.parametrize(
["old", "new", "expected"],
[
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (4,), (2, 2)),
((2, 2), (4,), (1, 1, 1, 1)),
),
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2)),
),
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (2, 2), (4,)),
((2, 2), (2, 2), (2, 2)),
),
(
((1, 1, 1, 1), (1, 1, 1, 1), (2, 2)),
((2, 2), (4,), (1, 1, 1, 1)),
((2, 2), (2, 2), (2, 2)),
),
],
)
def test_calculate_prechunking_3d(old, new, expected):
with dask.config.set({"array.chunk-size": "16 B"}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected


@pytest.mark.parametrize(
["chunk_size", "expected"],
[
("1 B", ((10,), (1,) * 10)),
("20 B", ((10,), (1,) * 10)),
("40 B", ((10,), (2, 2, 1, 2, 2, 1))),
("100 B", ((10,), (5, 5))),
],
)
def test_calculate_prechunking_concatenation(chunk_size, expected):
old = ((10,), (1,) * 10)
new = ((2,) * 5, (5, 5))
with dask.config.set({"array.chunk-size": chunk_size}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected


def test_calculate_prechunking_does_not_concatenate_object_type():
old = ((10,), (1,) * 10)
new = ((2,) * 5, (5, 5))

# Ensure that int dtypes get concatenated
new = ((2,) * 5, (5, 5))
with dask.config.set({"array.chunk-size": "100 B"}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == ((10,), (5, 5))

# Ensure object dtype chunks do not get concatenated
with dask.config.set({"array.chunk-size": "100 B"}):
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == old


@pytest.mark.parametrize(
["old", "new", "expected"],
[
[((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
[
((2, 2, 2), (3, 3, 3)),
((1, 2, 2, 1), (2, 3, 4)),
((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
],
)
def test_calculate_prechunking_splitting(old, new, expected):
# _calculate_prechunking does not concatenate on object
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == expected

0 comments on commit ea7d35c

Please sign in to comment.