Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove old reduction implementation #589

Merged
Merged 1 commit on Oct 4, 2024.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions cubed/array_api/linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from cubed.core import blockwise, reduction, squeeze


def matmul(x1, x2, /, use_new_impl=True, split_every=None):
def matmul(x1, x2, /, split_every=None):
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in matmul")

Expand Down Expand Up @@ -51,9 +51,7 @@ def matmul(x1, x2, /, use_new_impl=True, split_every=None):
dtype=dtype,
)

out = _sum_wo_cat(
out, axis=-2, dtype=dtype, use_new_impl=use_new_impl, split_every=split_every
)
out = _sum_wo_cat(out, axis=-2, dtype=dtype, split_every=split_every)

if x1_is_1d:
out = squeeze(out, -2)
Expand All @@ -68,7 +66,7 @@ def _matmul(a, b):
return chunk[..., nxp.newaxis, :]


def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
def _sum_wo_cat(a, axis=None, dtype=None, split_every=None):
if a.shape[axis] == 1:
return squeeze(a, axis)

Expand All @@ -78,7 +76,6 @@ def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
_chunk_sum,
axis=axis,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
Expand All @@ -99,7 +96,7 @@ def matrix_transpose(x, /):
return permute_dims(x, axes)


def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
def tensordot(x1, x2, /, *, axes=2, split_every=None):
from cubed.array_api.statistical_functions import sum

if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down Expand Up @@ -147,7 +144,6 @@ def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
out,
axis=x1_axes,
dtype=dtype,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand All @@ -161,7 +157,7 @@ def _tensordot(a, b, axes):
return x


def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
def vecdot(x1, x2, /, *, axis=-1, split_every=None):
# based on the implementation in array-api-compat
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in vecdot")
Expand All @@ -176,7 +172,6 @@ def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
res = matmul(
x1_[..., None, :],
x2_[..., None],
use_new_impl=use_new_impl,
split_every=split_every,
)
return res[..., 0, 0]
6 changes: 2 additions & 4 deletions cubed/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cubed.core.ops import arg_reduction, elemwise


def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def argmax(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmax")
if axis is None:
Expand All @@ -17,12 +17,11 @@ def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
nxp.argmax,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in argmin")
if axis is None:
Expand All @@ -34,7 +33,6 @@ def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
nxp.argmin,
axis=axis,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand Down
19 changes: 5 additions & 14 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@
from cubed.core import reduction


def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def max(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in max")
return reduction(
x,
nxp.max,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
Expand All @@ -53,7 +52,6 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
Expand Down Expand Up @@ -108,23 +106,20 @@ def _numel(x, **kwargs):
return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape)


def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def min(x, /, *, axis=None, keepdims=False, split_every=None):
if x.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in min")
return reduction(
x,
nxp.min,
axis=axis,
dtype=x.dtype,
use_new_impl=use_new_impl,
split_every=split_every,
keepdims=keepdims,
)


def prod(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
Expand All @@ -148,15 +143,12 @@ def prod(
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def sum(
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
):
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
Expand All @@ -180,7 +172,6 @@ def sum(
axis=axis,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
6 changes: 2 additions & 4 deletions cubed/array_api/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from cubed.core import reduction


def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def all(x, /, *, axis=None, keepdims=False, split_every=None):
if x.size == 0:
return asarray(True, dtype=x.dtype)
return reduction(
Expand All @@ -12,12 +12,11 @@ def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)


def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
def any(x, /, *, axis=None, keepdims=False, split_every=None):
if x.size == 0:
return asarray(False, dtype=x.dtype)
return reduction(
Expand All @@ -26,6 +25,5 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
axis=axis,
dtype=bool,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)
4 changes: 2 additions & 2 deletions cubed/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_blocks, map_direct, reduction_new
from cubed.core.ops import map_blocks, map_direct, reduction
from cubed.utils import array_memory, get_item
from cubed.vendor.dask.array.core import normalize_chunks

Expand Down Expand Up @@ -105,7 +105,7 @@ def wrapper(a, by, **kwargs):
out = expand_dims(out, axis=dummy_axis)

# then reduce across blocks
return reduction_new(
return reduction(
out,
func=None,
combine_func=combine_func,
Expand Down
121 changes: 1 addition & 120 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,122 +1056,6 @@ def key_function(out_key):


def reduction(
x: "Array",
func,
combine_func=None,
aggregate_func=None,
axis=None,
intermediate_dtype=None,
dtype=None,
keepdims=False,
use_new_impl=True,
split_every=None,
extra_func_kwargs=None,
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if use_new_impl:
return reduction_new(
x,
func,
combine_func=combine_func,
aggregate_func=aggregate_func,
axis=axis,
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)
if combine_func is None:
combine_func = func
if axis is None:
axis = tuple(range(x.ndim))
if isinstance(axis, Integral):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
if intermediate_dtype is None:
intermediate_dtype = dtype

inds = tuple(range(x.ndim))

result = x
allowed_mem = x.spec.allowed_mem
max_mem = allowed_mem - x.spec.reserved_mem

# reduce initial chunks
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
}
result = blockwise(
func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
# merge along axis
target_chunks = list(result.chunksize)
chunk_mem = array_memory(intermediate_dtype, result.chunksize)
for i, s in enumerate(result.shape):
if i in axis:
assert result.chunksize[i] == 1 # result of reduction
if len(axis) > 1:
# multi-axis: don't exceed original chunksize in any reduction axis
# TODO: improve to use up to max_mem
target_chunks[i] = min(s, x.chunksize[i])
else:
# single axis: see how many result chunks fit in max_mem
# factor of 4 is memory for {compressed, uncompressed} x {input, output}
target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
if target_chunk_size <= 1:
raise ValueError(
f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size"
)
target_chunks[i] = min(s, target_chunk_size)
_target_chunks = tuple(target_chunks)
result = merge_chunks(result, _target_chunks)

# reduce chunks (if any axis chunksize is > 1)
if any(s > 1 for i, s in enumerate(result.chunksize) if i in axis):
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c
for i, c in enumerate(result.chunks)
}
result = blockwise(
combine_func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

if aggregate_func is not None:
result = map_blocks(aggregate_func, result, dtype=dtype)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
if len(axis_to_squeeze) > 0:
result = squeeze(result, axis_to_squeeze)

from cubed.array_api import astype

result = astype(result, dtype, copy=False)

return result


def reduction_new(
x: "Array",
func,
combine_func=None,
Expand Down Expand Up @@ -1426,9 +1310,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
return result


def arg_reduction(
x, /, arg_func, axis=None, *, keepdims=False, use_new_impl=True, split_every=None
):
def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None):
"""A reduction that returns the array indexes, not the values."""
dtype = nxp.int64 # index data type
intermediate_dtype = [("i", dtype), ("v", x.dtype)]
Expand All @@ -1454,7 +1336,6 @@ def arg_reduction(
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
)

Expand Down
Loading
Loading