From 47a6c4c3e48cab2bcd75643f72e2629365d0ab91 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 30 Sep 2024 10:56:23 +0100 Subject: [PATCH] Remove old reduction implementation --- cubed/array_api/linear_algebra_functions.py | 15 +-- cubed/array_api/searching_functions.py | 6 +- cubed/array_api/statistical_functions.py | 19 +-- cubed/array_api/utility_functions.py | 6 +- cubed/core/groupby.py | 4 +- cubed/core/ops.py | 121 +------------------- cubed/nan_functions.py | 8 +- cubed/tests/test_array_api.py | 5 +- cubed/tests/test_core.py | 12 +- 9 files changed, 23 insertions(+), 173 deletions(-) diff --git a/cubed/array_api/linear_algebra_functions.py b/cubed/array_api/linear_algebra_functions.py index 272e06b4..c568a6f6 100644 --- a/cubed/array_api/linear_algebra_functions.py +++ b/cubed/array_api/linear_algebra_functions.py @@ -12,7 +12,7 @@ from cubed.core import blockwise, reduction, squeeze -def matmul(x1, x2, /, use_new_impl=True, split_every=None): +def matmul(x1, x2, /, split_every=None): if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in matmul") @@ -51,9 +51,7 @@ def matmul(x1, x2, /, use_new_impl=True, split_every=None): dtype=dtype, ) - out = _sum_wo_cat( - out, axis=-2, dtype=dtype, use_new_impl=use_new_impl, split_every=split_every - ) + out = _sum_wo_cat(out, axis=-2, dtype=dtype, split_every=split_every) if x1_is_1d: out = squeeze(out, -2) @@ -68,7 +66,7 @@ def _matmul(a, b): return chunk[..., nxp.newaxis, :] -def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None): +def _sum_wo_cat(a, axis=None, dtype=None, split_every=None): if a.shape[axis] == 1: return squeeze(a, axis) @@ -78,7 +76,6 @@ def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None): _chunk_sum, axis=axis, dtype=dtype, - use_new_impl=use_new_impl, split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) @@ -99,7 +96,7 @@ def matrix_transpose(x, /): return permute_dims(x, axes) -def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None): +def tensordot(x1, x2, /, *, axes=2, split_every=None): from cubed.array_api.statistical_functions import sum if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: @@ -147,7 +144,6 @@ def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None): out, axis=x1_axes, dtype=dtype, - use_new_impl=use_new_impl, split_every=split_every, ) @@ -161,7 +157,7 @@ def _tensordot(a, b, axes): return x -def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None): +def vecdot(x1, x2, /, *, axis=-1, split_every=None): # based on the implementation in array-api-compat if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in vecdot") @@ -176,7 +172,6 @@ def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None): res = matmul( x1_[..., None, :], x2_[..., None], - use_new_impl=use_new_impl, split_every=split_every, ) return res[..., 0, 0] diff --git a/cubed/array_api/searching_functions.py b/cubed/array_api/searching_functions.py index 9f4590a3..d01cc7df 100644 --- a/cubed/array_api/searching_functions.py +++ b/cubed/array_api/searching_functions.py @@ -5,7 +5,7 @@ from cubed.core.ops import arg_reduction, elemwise -def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def argmax(x, /, *, axis=None, keepdims=False, split_every=None): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmax") if axis is None: @@ -17,12 +17,11 @@ def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No nxp.argmax, axis=axis, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) -def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def argmin(x, /, *, axis=None, keepdims=False, split_every=None): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmin") if axis is None: @@ -34,7 +33,6 @@ def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No nxp.argmin, axis=axis, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 7ee6525e..ab1f8c2a 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -18,7 +18,7 @@ from cubed.core import reduction -def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def max(x, /, *, axis=None, keepdims=False, split_every=None): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in max") return reduction( @@ -26,13 +26,12 @@ def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) nxp.max, axis=axis, dtype=x.dtype, - use_new_impl=use_new_impl, split_every=split_every, keepdims=keepdims, ) -def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def mean(x, /, *, axis=None, keepdims=False, split_every=None): if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in mean") # This implementation uses NumPy and Zarr's structured arrays to store a @@ -53,7 +52,6 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None intermediate_dtype=intermediate_dtype, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) @@ -108,7 +106,7 @@ def _numel(x, **kwargs): return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape) -def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def min(x, /, *, axis=None, keepdims=False, split_every=None): if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in min") return reduction( @@ -116,15 +114,12 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) nxp.min, axis=axis, dtype=x.dtype, - use_new_impl=use_new_impl, split_every=split_every, keepdims=keepdims, ) -def prod( - x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None -): +def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): # boolean is allowed by numpy if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in prod") @@ -148,15 +143,12 @@ def prod( axis=axis, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) -def sum( - x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None -): +def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): # boolean is allowed by numpy if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in sum") @@ -180,7 +172,6 @@ def sum( axis=axis, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) diff --git a/cubed/array_api/utility_functions.py b/cubed/array_api/utility_functions.py index 9825dd9b..f076f754 100644 --- a/cubed/array_api/utility_functions.py +++ b/cubed/array_api/utility_functions.py @@ -3,7 +3,7 @@ from cubed.core import reduction -def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def all(x, /, *, axis=None, keepdims=False, split_every=None): if x.size == 0: return asarray(True, dtype=x.dtype) return reduction( @@ -12,12 +12,11 @@ def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) axis=axis, dtype=bool, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) -def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def any(x, /, *, axis=None, keepdims=False, split_every=None): if x.size == 0: return asarray(False, dtype=x.dtype) return reduction( @@ -26,6 +25,5 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) axis=axis, dtype=bool, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) diff --git a/cubed/core/groupby.py b/cubed/core/groupby.py index fece0853..8c4aa9b7 100644 --- a/cubed/core/groupby.py +++ b/cubed/core/groupby.py @@ -2,7 +2,7 @@ from cubed.array_api.manipulation_functions import broadcast_to, expand_dims from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import map_blocks, map_direct, reduction_new +from cubed.core.ops import map_blocks, map_direct, reduction from cubed.utils import array_memory, get_item from cubed.vendor.dask.array.core import normalize_chunks @@ -105,7 +105,7 @@ def wrapper(a, by, **kwargs): out = expand_dims(out, axis=dummy_axis) # then reduce across blocks - return reduction_new( + return reduction( out, func=None, combine_func=combine_func, diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 627e7916..6224c112 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1056,122 +1056,6 @@ def key_function(out_key): def reduction( - x: "Array", - func, - combine_func=None, - aggregate_func=None, - axis=None, - intermediate_dtype=None, - dtype=None, - keepdims=False, - use_new_impl=True, - split_every=None, - extra_func_kwargs=None, -) -> "Array": - """Apply a function to reduce an array along one or more axes.""" - if use_new_impl: - return reduction_new( - x, - func, - combine_func=combine_func, - aggregate_func=aggregate_func, - axis=axis, - intermediate_dtype=intermediate_dtype, - dtype=dtype, - keepdims=keepdims, - split_every=split_every, - extra_func_kwargs=extra_func_kwargs, - ) - if combine_func is None: - combine_func = func - if axis is None: - axis = tuple(range(x.ndim)) - if isinstance(axis, Integral): - axis = (axis,) - axis = validate_axis(axis, x.ndim) - if intermediate_dtype is None: - intermediate_dtype = dtype - - inds = tuple(range(x.ndim)) - - result = x - allowed_mem = x.spec.allowed_mem - max_mem = allowed_mem - x.spec.reserved_mem - - # reduce initial chunks - args = (result, inds) - adjust_chunks = { - i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks) - } - result = blockwise( - func, - inds, - *args, - axis=axis, - keepdims=True, - dtype=intermediate_dtype, - adjust_chunks=adjust_chunks, - extra_func_kwargs=extra_func_kwargs, - ) - - # merge/reduce along axis in multiple rounds until there's a single block in each reduction axis - while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis): - # merge along axis - target_chunks = list(result.chunksize) - chunk_mem = array_memory(intermediate_dtype, result.chunksize) - for i, s in enumerate(result.shape): - if i in axis: - assert result.chunksize[i] == 1 # result of reduction - if len(axis) > 1: - # multi-axis: don't exceed original chunksize in any reduction axis - # TODO: improve to use up to max_mem - target_chunks[i] = min(s, x.chunksize[i]) - else: - # single axis: see how many result chunks fit in max_mem - # factor of 4 is memory for {compressed, uncompressed} x {input, output} - target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4) - if target_chunk_size <= 1: - raise ValueError( - f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size" - ) - target_chunks[i] = min(s, target_chunk_size) - _target_chunks = tuple(target_chunks) - result = merge_chunks(result, _target_chunks) - - # reduce chunks (if any axis chunksize is > 1) - if any(s > 1 for i, s in enumerate(result.chunksize) if i in axis): - args = (result, inds) - adjust_chunks = { - i: (1,) * len(c) if i in axis else c - for i, c in enumerate(result.chunks) - } - result = blockwise( - combine_func, - inds, - *args, - axis=axis, - keepdims=True, - dtype=intermediate_dtype, - adjust_chunks=adjust_chunks, - extra_func_kwargs=extra_func_kwargs, - ) - - if aggregate_func is not None: - result = map_blocks(aggregate_func, result, dtype=dtype) - - if not keepdims: - axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1) - if len(axis_to_squeeze) > 0: - result = squeeze(result, axis_to_squeeze) - - from cubed.array_api import astype - - result = astype(result, dtype, copy=False) - - return result - - -def reduction_new( x: "Array", func, combine_func=None, @@ -1426,9 +1310,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None): return result -def arg_reduction( - x, /, arg_func, axis=None, *, keepdims=False, use_new_impl=True, split_every=None -): +def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None): """A reduction that returns the array indexes, not the values.""" dtype = nxp.int64 # index data type intermediate_dtype = [("i", dtype), ("v", x.dtype)] @@ -1454,7 +1336,6 @@ def arg_reduction( intermediate_dtype=intermediate_dtype, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 2acd308b..928a2f41 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -18,7 +18,7 @@ # https://github.com/data-apis/array-api/issues/621 -def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +def nanmean(x, /, *, axis=None, keepdims=False, split_every=None): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" dtype = x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] @@ -31,7 +31,6 @@ def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=N intermediate_dtype=intermediate_dtype, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) @@ -61,9 +60,7 @@ def _nannumel(x, **kwargs): return nxp.sum(~(nxp.isnan(x)), **kwargs) -def nansum( - x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None -): +def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): """Return the sum of array elements over a given axis treating NaNs as zero.""" if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in nansum") @@ -84,6 +81,5 @@ def nansum( axis=axis, dtype=dtype, keepdims=keepdims, - use_new_impl=use_new_impl, split_every=split_every, ) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index b7764caa..b674781a 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -702,12 +702,11 @@ def test_argmin_axis_0(spec): # Statistical functions -@pytest.mark.parametrize("use_new_impl", [False, True]) -def test_mean_axis_0(spec, executor, use_new_impl): +def test_mean_axis_0(spec, executor): a = xp.asarray( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec ) - b = xp.mean(a, axis=0, use_new_impl=use_new_impl) + b = xp.mean(a, axis=0) assert_array_equal( b.compute(executor=executor), np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).mean(axis=0), diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 1e415092..2892843b 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -429,14 +429,6 @@ def test_reduction_multiple_rounds(tmp_path, executor): assert_array_equal(b.compute(executor=executor), np.ones((100, 10)).sum(axis=0)) -def test_reduction_not_enough_memory(tmp_path): - spec = cubed.Spec(tmp_path, allowed_mem=50) - a = xp.ones((100, 10), dtype=np.uint8, chunks=(1, 10), spec=spec) - with pytest.raises(ValueError, match=r"Not enough memory for reduction"): - # only a problem with the old implementation, so set use_new_impl=False - xp.sum(a, axis=0, dtype=np.uint8, use_new_impl=False) - - def test_partial_reduce(spec): a = xp.asarray(np.arange(242).reshape((11, 22)), chunks=(3, 4), spec=spec) b = partial_reduce(a, np.sum, split_every={0: 2}) @@ -612,7 +604,7 @@ def test_plan_quad_means(tmp_path, t_length): u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) uv = u * v - m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True) + m = xp.mean(uv, axis=0, split_every=10) assert m.plan._finalize().num_tasks() > 0 m.visualize( @@ -674,7 +666,7 @@ def test_quad_means_zarr(tmp_path, t_length=50): u = cubed.from_zarr(f"{tmp_path}/u_{t_length}.zarr", spec=spec) v = cubed.from_zarr(f"{tmp_path}/v_{t_length}.zarr", spec=spec) uv = u * v - m = xp.mean(uv, axis=0, use_new_impl=True, split_every=10) + m = xp.mean(uv, axis=0, split_every=10) opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=40)