Skip to content

Commit

Permalink
Optimize bitmask finding for chunk size 1 and single chunk cases (#360)
Browse files Browse the repository at this point in the history
* Optimize bitmask finding for chunk size 1.

* Fix benchmark.

* bugfix

* Add single chunk benchmark

* Optimize single chunk case.

* Add test
  • Loading branch information
dcherian authored Apr 27, 2024
1 parent 13cb229 commit 627bf2b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
38 changes: 24 additions & 14 deletions asv_bench/benchmarks/cohorts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import cached_property

import dask
import numpy as np
import pandas as pd
Expand All @@ -11,6 +13,10 @@ class Cohorts:
def setup(self, *args, **kwargs):
raise NotImplementedError

@cached_property
def dask(self):
return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask

def containment(self):
asfloat = self.bitmask().astype(float)
chunks_per_label = asfloat.sum(axis=0)
Expand Down Expand Up @@ -43,26 +49,17 @@ def time_find_group_cohorts(self):
pass

def time_graph_construct(self):
flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")
flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)

def track_num_tasks(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
return len(result.dask.to_dict())
return len(self.dask.to_dict())

def track_num_tasks_optimized(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
(opt,) = dask.optimize(result)
return len(opt.dask.to_dict())
(opt,) = dask.optimize(self.dask)
return len(opt.to_dict())

def track_num_layers(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
return len(result.dask.layers)
return len(self.dask.layers)

track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy
Expand Down Expand Up @@ -193,6 +190,19 @@ def setup(self, *args, **kwargs):
self.expected = pd.RangeIndex(self.by.max() + 1)


class SingleChunk(Cohorts):
"""Single chunk along reduction axis: always blockwise."""

def setup(self, *args, **kwargs):
index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
self.time = pd.Series(index)
TIME = len(self.time)
self.axis = (2,)
self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, -1))
self.by = codes_for_resampling(index, freq="5D")
self.expected = pd.RangeIndex(self.by.max() + 1)


class OISST(Cohorts):
def setup(self, *args, **kwargs):
self.array = dask.array.ones((1, 14532), chunks=(1, 10))
Expand Down
32 changes: 23 additions & 9 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,22 @@ def slices_from_chunks(chunks):


def _compute_label_chunk_bitmask(labels, chunks, nlabels):
def make_bitmask(rows, cols):
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))

assert isinstance(labels, np.ndarray)
shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)

labels = np.broadcast_to(labels, shape[-labels.ndim :])
# Shortcut for 1D with size-1 chunks
if shape == (nchunks,):
rows_array = np.arange(nchunks)
cols_array = labels
mask = labels >= 0
return make_bitmask(rows_array[mask], cols_array[mask])

labels = np.broadcast_to(labels, shape[-labels.ndim :])
cols = []
# Add one to handle the -1 sentinel value
label_is_present = np.zeros((nlabels + 1,), dtype=bool)
Expand All @@ -272,10 +282,8 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
label_is_present[:] = False
rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
cols_array = np.concatenate(cols)
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))

return bitmask
return make_bitmask(rows_array, cols_array)


# @memoize
Expand Down Expand Up @@ -312,13 +320,18 @@ def find_group_cohorts(
labels = np.asarray(labels)

shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)

# assumes that `labels` are factorized
if expected_groups is None:
nlabels = labels.max() + 1
else:
nlabels = expected_groups[-1] + 1

# 1. Single chunk, blockwise always
if nchunks == 1:
return "blockwise", {(0,): list(range(nlabels))}

labels = np.broadcast_to(labels, shape[-labels.ndim :])
bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)

Expand Down Expand Up @@ -346,21 +359,21 @@ def invert(x) -> tuple[np.ndarray, ...]:

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

# 1. Every group is contained to one block, use blockwise here.
# 2. Every group is contained to one block, use blockwise here.
if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
logger.info("find_group_cohorts: blockwise is preferred.")
return "blockwise", chunks_cohorts

# 2. Perfectly chunked so there is only a single cohort
# 3. Perfectly chunked so there is only a single cohort
if len(chunks_cohorts) == 1:
logger.info("Only found a single cohort. 'map-reduce' is preferred.")
return "map-reduce", chunks_cohorts if merge else {}

# 3. Our dataset has chunksize one along the axis,
# 4. Our dataset has chunksize one along the axis,
single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
# 4. Every chunk only has a single group, but that group might extend across multiple chunks
# 5. Every chunk only has a single group, but that group might extend across multiple chunks
one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
# 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
# 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
Expand Down Expand Up @@ -393,6 +406,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
sparsity, MAX_SPARSITY_FOR_COHORTS
)
)
# 7. Groups seem fairly randomly distributed, use "map-reduce".
if sparsity > MAX_SPARSITY_FOR_COHORTS:
if not merge:
logger.info(
Expand Down
19 changes: 17 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,12 +946,12 @@ def test_verify_complex_cohorts(chunksize: int) -> None:
@pytest.mark.parametrize("chunksize", (12,) + tuple(range(1, 13)) + (-1,))
def test_method_guessing(chunksize):
# just a regression test
labels = np.tile(np.arange(1, 13), 30)
labels = np.tile(np.arange(0, 12), 30)
by = dask.array.from_array(labels, chunks=chunksize) - 1
preferred_method, chunks_cohorts = find_group_cohorts(labels, by.chunks[slice(-1, None)])
if chunksize == -1:
assert preferred_method == "blockwise"
assert chunks_cohorts == {(0,): list(range(1, 13))}
assert chunks_cohorts == {(0,): list(range(12))}
elif chunksize in (1, 2, 3, 4, 6):
assert preferred_method == "cohorts"
assert len(chunks_cohorts) == 12 // chunksize
Expand All @@ -960,6 +960,21 @@ def test_method_guessing(chunksize):
assert chunks_cohorts == {}


@requires_dask
@pytest.mark.parametrize("ndim", [1, 2, 3])
def test_single_chunk_method_is_blockwise(ndim):
for by_ndim in range(1, ndim + 1):
chunks = (5,) * (ndim - by_ndim) + (-1,) * by_ndim
assert len(chunks) == ndim
array = dask.array.ones(shape=(10,) * ndim, chunks=chunks)
by = np.zeros(shape=(10,) * by_ndim, dtype=int)
method, chunks_cohorts = find_group_cohorts(
by, chunks=[array.chunks[ax] for ax in range(-by.ndim, 0)]
)
assert method == "blockwise"
assert chunks_cohorts == {(0,): [0]}


@requires_dask
@pytest.mark.parametrize(
"chunk_at,expected",
Expand Down

0 comments on commit 627bf2b

Please sign in to comment.