Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft] Accept Cubed arrays instead of dask #249

Closed
wants to merge 8 commits into from
53 changes: 31 additions & 22 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
generic_aggregate,
)
from .cache import memoize
from .xrutils import is_duck_array, is_duck_dask_array, isnull
from .duck_array_ops import reshape
from .xrutils import is_chunked_array, is_duck_array, is_duck_dask_array, isnull

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -764,7 +765,7 @@ def chunk_reduce(
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
array = reshape(array, newshape)
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
Expand Down Expand Up @@ -1294,6 +1295,9 @@ def dask_groupby_agg(
) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
import dask.array
from dask.array.core import slices_from_chunks
from xarray.core.parallelcompat import get_chunked_array_type

chunkmanager = get_chunked_array_type(array)
Comment on lines +1298 to +1300
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously this approach would introduce a dependency on xarray, which presumably is not desirable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine just having dask_kwargs and cubed_kwargs instead of all this complexity.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I probably should have just done that in xarray itself 😅


# I think _tree_reduce expects this
assert isinstance(axis, Sequence)
Expand All @@ -1314,14 +1318,18 @@ def dask_groupby_agg(
# Unifying chunks is necessary for argreductions.
# We need to rechunk before zipping up with the index
# let's always do it anyway
if not is_duck_dask_array(by):
if not is_chunked_array(by):
# chunk numpy arrays like the input array
# This removes an extra rechunk-merge layer that would be
# added otherwise
chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

by = dask.array.from_array(by, chunks=chunks)
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
by = chunkmanager.from_array(
by,
chunks=chunks,
spec=array.spec, # cubed needs all arguments to blockwise to have same Spec
)
_, (array, by) = chunkmanager.unify_chunks(array, inds, by, inds[-by.ndim :])

# preprocess the array:
# - for argreductions, this zips the index together with the array block
Expand Down Expand Up @@ -1363,7 +1371,7 @@ def dask_groupby_agg(
blockwise_method = tlz.compose(_expand_dims, blockwise_method)

# apply reduction on chunk
intermediate = dask.array.blockwise(
intermediate = chunkmanager.blockwise(
partial(
blockwise_method,
axis=axis,
Expand All @@ -1379,11 +1387,11 @@ def dask_groupby_agg(
inds,
by,
inds[-by.ndim :],
concatenate=False,
# concatenate=False,
dtype=array.dtype, # this is purely for show
meta=array._meta,
# meta=array._meta,
align_arrays=False,
name=f"{name}-chunk-{token}",
# name=f"{name}-chunk-{token}",
Comment on lines +1392 to +1394
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_meta and name are dask-specific. Are they used for anything important here or just for labelling tasks in the graph?

Copy link
Collaborator

@dcherian dcherian Jun 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you don't provide meta, dask will try to figure it out and then break?

)

group_chunks: tuple[tuple[int | float, ...]]
Expand All @@ -1392,18 +1400,20 @@ def dask_groupby_agg(
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = partial(_simple_combine, reindex=reindex)
combine_name = "simple-combine"
else:
combine = partial(_grouped_combine, engine=engine, sort=sort)
combine_name = "grouped-combine"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a test for these names that will need to be fixed.


def identity(x, axis, keepdims):
return x

tree_reduce = partial(
dask.array.reductions._tree_reduce,
name=f"{name}-reduce-{method}-{combine_name}",
chunkmanager.reduction,
func=identity,
# name=f"{name}-reduce-{method}-{combine_name}",
dtype=array.dtype,
axis=axis,
keepdims=True,
concatenate=False,
# concatenate=False,
)
aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)

Expand All @@ -1415,8 +1425,8 @@ def dask_groupby_agg(
if method == "map-reduce":
reduced = tree_reduce(
intermediate,
combine=partial(combine, agg=agg),
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
combine_func=partial(combine, agg=agg),
aggregate_func=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
)
if is_duck_dask_array(by_input) and expected_groups is None:
groups = _extract_unknown_groups(reduced, dtype=by.dtype)
Expand Down Expand Up @@ -1479,7 +1489,7 @@ def dask_groupby_agg(
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)

# Can't use map_blocks because it forces concatenate=True along drop_axes,
result = dask.array.blockwise(
result = chunkmanager.blockwise(
_extract_result,
out_inds,
reduced,
Expand All @@ -1488,7 +1498,7 @@ def dask_groupby_agg(
dtype=agg.dtype["final"],
key=agg.name,
name=f"{name}-{token}",
concatenate=False,
# concatenate=False,
)

return (result, groups)
Expand Down Expand Up @@ -1889,7 +1899,7 @@ def groupby_reduce(
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
nax = len(axis_)

has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
has_dask = is_chunked_array(array) or is_duck_dask_array(by_)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
has_dask = is_chunked_array(array) or is_duck_dask_array(by_)
is_chunked = is_chunked_array(array) or is_chunked_array(by_)


if _is_first_last_reduction(func):
if has_dask and nax != 1:
Expand Down Expand Up @@ -2008,9 +2018,8 @@ def groupby_reduce(
# nan group labels are factorized to -1, and preserved
# now we get rid of them by reindexing
# This also handles bins with no data
result = reindex_(
result, from_=groups[0], to=expected_groups, fill_value=fill_value
).reshape(result.shape[:-1] + grp_shape)
reindexed = reindex_(result, from_=groups[0], to=expected_groups, fill_value=fill_value)
result = reshape(reindexed, reindexed.shape[:-1] + grp_shape)
groups = final_groups

if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
Expand Down
18 changes: 18 additions & 0 deletions flox/duck_array_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np


def get_array_namespace(x):
    """Return the array-API namespace for *x*, falling back to numpy.

    Objects implementing the Python array API standard expose
    ``__array_namespace__``; anything else is treated as numpy-compatible.
    """
    ns_getter = getattr(x, "__array_namespace__", None)
    return np if ns_getter is None else ns_getter()


def reshape(array, shape):
    """Reshape *array* using its own array-API namespace (numpy fallback).

    Dispatching through the namespace keeps this working for duck arrays
    (e.g. cubed) that do not implement the ``.reshape`` method signature
    numpy uses.
    """
    ns_getter = getattr(array, "__array_namespace__", None)
    xp = np if ns_getter is None else ns_getter()
    return xp.reshape(array, shape)


def asarray(obj, like):
    """Convert *obj* to an array of the same flavour as *like*.

    *like* is consulted only for its array-API namespace; its values are
    never read. Falls back to ``numpy.asarray`` when *like* does not
    implement ``__array_namespace__``.
    """
    ns_getter = getattr(like, "__array_namespace__", None)
    xp = np if ns_getter is None else ns_getter()
    return xp.asarray(obj)
11 changes: 9 additions & 2 deletions flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and (
(hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__")
)
)


def is_chunked_array(x) -> bool:
    """True if *x* is a chunked duck array (dask or cubed)."""
    # Dask arrays are always chunked; otherwise any duck array exposing
    # a ``chunks`` attribute (e.g. cubed) counts as chunked.
    if is_duck_dask_array(x):
        return True
    return is_duck_array(x) and hasattr(x, "chunks")


def is_dask_collection(x):
try:
import dask
Expand Down