[draft] Accept Cubed arrays instead of dask #249
Changes from all commits: 6eca6f1, 58d2021, 8fdc367, 4777e77, 5582e5e, fabaf35, 858c98a, 786af6a
flox/core.py:

```diff
@@ -35,7 +35,8 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull
+from .duck_array_ops import reshape
+from .xrutils import is_chunked_array, is_duck_array, is_duck_dask_array, isnull

 if TYPE_CHECKING:
     try:
```
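The new `is_chunked_array` predicate generalizes the dask-only check: any duck array exposing a `.chunks` attribute (dask or cubed) counts as chunked. A minimal sketch of what such a predicate can look like, following the duck-typed convention both libraries use; this is an illustration, not necessarily the exact `xrutils` implementation:

```python
def is_chunked_array(x) -> bool:
    # Duck-typed check: both dask.array.Array and cubed.Array expose
    # a `.chunks` attribute describing their block structure.
    return hasattr(x, "chunks")
```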
```diff
@@ -764,7 +765,7 @@ def chunk_reduce(
         group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
         # always reshape to 1D along group dimensions
         newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
-        array = array.reshape(newshape)
+        array = reshape(array, newshape)
         group_idx = group_idx.reshape(-1)

     assert group_idx.ndim == 1
```
```diff
@@ -1294,6 +1295,9 @@
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
+    from xarray.core.parallelcompat import get_chunked_array_type
+
+    chunkmanager = get_chunked_array_type(array)

     # I think _tree_reduce expects this
     assert isinstance(axis, Sequence)
```
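For context, `get_chunked_array_type` is xarray's chunk-manager dispatch: it inspects the arguments and returns an entrypoint object exposing a uniform `from_array` / `blockwise` / `reduction` / `unify_chunks` interface for whichever chunked backend (dask or cubed) the array belongs to. A simplified sketch of the idea — the `registered_chunkmanagers` helper here is hypothetical; the real dispatch lives in `xarray.core.parallelcompat`:

```python
def get_chunked_array_type_sketch(*args):
    # Keep only the chunked (dask/cubed) arguments; plain numpy inputs
    # don't participate in dispatch.
    chunked = [a for a in args if hasattr(a, "chunks")]
    if not chunked:
        raise TypeError("expected at least one chunked array")
    # Each registered manager recognizes its own array type and exposes
    # from_array, blockwise, reduction, unify_chunks, ...
    for manager in registered_chunkmanagers():  # hypothetical registry lookup
        if all(manager.is_chunked_array(a) for a in chunked):
            return manager
    raise TypeError("mixed or unrecognized chunked array types")
```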
```diff
@@ -1314,14 +1318,18 @@ def dask_groupby_agg(
     # Unifying chunks is necessary for argreductions.
     # We need to rechunk before zipping up with the index
     # let's always do it anyway
-    if not is_duck_dask_array(by):
+    if not is_chunked_array(by):
         # chunk numpy arrays like the input array
         # This removes an extra rechunk-merge layer that would be
         # added otherwise
         chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))

-        by = dask.array.from_array(by, chunks=chunks)
-    _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])
+        by = chunkmanager.from_array(
+            by,
+            chunks=chunks,
+            spec=array.spec,  # cubed needs all arguments to blockwise to have same Spec
+        )
+    _, (array, by) = chunkmanager.unify_chunks(array, inds, by, inds[-by.ndim :])

     # preprocess the array:
     # - for argreductions, this zips the index together with the array block
```
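The extra `spec` argument exists because cubed ties every array to a `Spec` (working directory, memory budget, executor), and all inputs to a single `blockwise` must share one. Roughly how that looks on the cubed side — a sketch, with keyword names as of cubed at the time:

```python
import numpy as np
import cubed

# One Spec per computation: bounds the memory each task may use.
spec = cubed.Spec(work_dir="tmp", allowed_mem="100MB")

a = cubed.from_array(np.ones((4, 4)), chunks=(2, 2), spec=spec)
b = cubed.from_array(np.arange(4), chunks=(2,), spec=spec)  # same Spec as `a`
c = a + b  # operations combining arrays require matching Specs
```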
```diff
@@ -1363,7 +1371,7 @@
         blockwise_method = tlz.compose(_expand_dims, blockwise_method)

     # apply reduction on chunk
-    intermediate = dask.array.blockwise(
+    intermediate = chunkmanager.blockwise(
         partial(
             blockwise_method,
             axis=axis,
```
```diff
@@ -1379,11 +1387,11 @@
         inds,
         by,
         inds[-by.ndim :],
-        concatenate=False,
+        # concatenate=False,
         dtype=array.dtype,  # this is purely for show
-        meta=array._meta,
+        # meta=array._meta,
         align_arrays=False,
-        name=f"{name}-chunk-{token}",
+        # name=f"{name}-chunk-{token}",
     )

     group_chunks: tuple[tuple[int | float, ...]]
```

Review comment on lines +1392 to +1394: if you don't provide `meta`, dask will try to figure it out and then break?
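For context on that question: when `meta` is omitted, dask infers the output type by calling the function on zero-shaped probe inputs, which breaks for functions that can't handle empty arrays; passing `meta` explicitly skips that inference. A small illustration with `map_blocks`, which uses the same mechanism as `blockwise`:

```python
import dask.array as da
import numpy as np

x = da.ones((4,), chunks=2)

def fragile(block):
    # Misbehaves on the zero-sized block dask uses for type inference,
    # because an empty block has sum 0.
    return block / block.sum()

# Explicit meta tells dask the output is a float64 ndarray, so it never
# calls `fragile` on an empty probe array.
y = x.map_blocks(fragile, meta=np.array((), dtype=np.float64))
```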
```diff
@@ -1392,18 +1400,20 @@
     combine: Callable[..., IntermediateDict]
     if do_simple_combine:
         combine = partial(_simple_combine, reindex=reindex)
-        combine_name = "simple-combine"
     else:
         combine = partial(_grouped_combine, engine=engine, sort=sort)
-        combine_name = "grouped-combine"

+    def identity(x, axis, keepdims):
+        return x
+
     tree_reduce = partial(
-        dask.array.reductions._tree_reduce,
-        name=f"{name}-reduce-{method}-{combine_name}",
+        chunkmanager.reduction,
+        func=identity,
+        # name=f"{name}-reduce-{method}-{combine_name}",
         dtype=array.dtype,
         axis=axis,
         keepdims=True,
-        concatenate=False,
+        # concatenate=False,
     )
     aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
```

Review comment: There's a test for these names that will need to be fixed.
```diff
@@ -1415,8 +1425,8 @@
     if method == "map-reduce":
         reduced = tree_reduce(
             intermediate,
-            combine=partial(combine, agg=agg),
-            aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
+            combine_func=partial(combine, agg=agg),
+            aggregate_func=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
         )
         if is_duck_dask_array(by_input) and expected_groups is None:
             groups = _extract_unknown_groups(reduced, dtype=by.dtype)
```
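The kwarg renames (`combine` → `combine_func`, `aggregate` → `aggregate_func`) track the chunk-manager interface: xarray's `ChunkManagerEntrypoint.reduction` hook takes roughly the signature below. A sketch based on the xarray version current at the time; treat the exact parameter list as an assumption:

```python
def reduction(arr, func, combine_func=None, aggregate_func=None,
              axis=None, dtype=None, keepdims=False):
    """Tree-reduction hook: `func` maps each block to an intermediate,
    `combine_func` merges intermediates level by level up the tree, and
    `aggregate_func` produces the final result at the root."""
    ...
```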
```diff
@@ -1479,7 +1489,7 @@
         reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)

     # Can't use map_blocks because it forces concatenate=True along drop_axes,
-    result = dask.array.blockwise(
+    result = chunkmanager.blockwise(
         _extract_result,
         out_inds,
         reduced,
```
```diff
@@ -1488,7 +1498,7 @@
         dtype=agg.dtype["final"],
         key=agg.name,
         name=f"{name}-{token}",
-        concatenate=False,
+        # concatenate=False,
     )

     return (result, groups)
```
```diff
@@ -1889,7 +1899,7 @@ def groupby_reduce(
     axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
     nax = len(axis_)

-    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
+    has_dask = is_chunked_array(array) or is_duck_dask_array(by_)

     if _is_first_last_reduction(func):
         if has_dask and nax != 1:
```

Review comment: Suggested change (the suggestion body was not captured).
```diff
@@ -2008,9 +2018,8 @@
         # nan group labels are factorized to -1, and preserved
         # now we get rid of them by reindexing
         # This also handles bins with no data
-        result = reindex_(
-            result, from_=groups[0], to=expected_groups, fill_value=fill_value
-        ).reshape(result.shape[:-1] + grp_shape)
+        reindexed = reindex_(result, from_=groups[0], to=expected_groups, fill_value=fill_value)
+        result = reshape(reindexed, reindexed.shape[:-1] + grp_shape)
         groups = final_groups

     if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
```
flox/duck_array_ops.py (new file, 18 lines):

```python
import numpy as np


def get_array_namespace(x):
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    else:
        return np


def reshape(array, shape):
    xp = get_array_namespace(array)
    return xp.reshape(array, shape)


def asarray(obj, like):
    xp = get_array_namespace(like)
    return xp.asarray(obj)
```
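These helpers dispatch through the Array API standard's `__array_namespace__` protocol and fall back to NumPy when an array doesn't implement it. A quick illustration of how they behave with plain NumPy inputs, using the helpers defined above:

```python
import numpy as np

x = np.arange(6)
y = reshape(x, (2, 3))           # dispatches to np.reshape via the fallback
z = asarray([1.0, 2.0], like=y)  # built with `y`'s namespace (NumPy here)
assert y.shape == (2, 3) and isinstance(z, np.ndarray)
```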
Review comment: Obviously this approach would introduce a dependency on xarray, which presumably is not desirable.

Review comment: I'm fine just having `dask_kwargs` and `cubed_kwargs` instead of all this complexity.

Review comment: I probably should have just done that in xarray itself 😅