From f5fa5c1a036a5b4ed7a7f67b48344935435735a8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 15 Apr 2024 14:51:53 +0000 Subject: [PATCH] Add new cumulative_sum function to numpy and array_api namespaces --- CHANGELOG.md | 4 ++ jax/_src/numpy/lax_numpy.py | 31 +++++++++++ jax/experimental/array_api/__init__.py | 1 + .../array_api/_statistical_functions.py | 4 ++ jax/numpy/__init__.py | 1 + tests/array_api_test.py | 1 + tests/lax_numpy_test.py | 53 +++++++++++++++++++ 7 files changed, 95 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2e8794f2df1..c3941dd48f46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.27 +* New Functionality + * Added {func}`jax.numpy.cumulative_sum`, following the addition of this + function in the array API 2023 standard, soon to be adopted by NumPy. + * Changes * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b23ca210b88..2645339bc555 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5567,3 +5567,34 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, else: raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}") return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode) + + +@util.implements(getattr(np, 'cumulative_sum', None)) +def cumulative_sum( + x: Array, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + if isscalar(x) or x.ndim == 0: + raise ValueError( + "The input must be non-scalar to take a cumulative sum, however a " + "scalar value or scalar array was given." + ) + if axis is None and x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + axis = axis or 0 + util.check_arraylike("cumulative_sum", x) + dtypes.check_user_dtype_supported(dtype) + kind = x.dtype.kind + if (dtype is None and kind in {'i', 'u'} + and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): + dtype = dtypes.dtype(dtypes._default_types[kind]) + + out = reductions.cumsum(x, axis=axis, dtype=dtype) + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + if include_initial: + out = concat([zeros(zeros_shape, dtype=out.dtype), out], axis=axis) + return out diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index d59a9eec751e..2bd69ca09d0f 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -195,6 +195,7 @@ ) from jax.experimental.array_api._statistical_functions import ( + cumulative_sum as cumulative_sum, max as max, mean as mean, min as min, diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index 2e1333317605..141b80abfb14 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -18,6 +18,10 @@ ) +def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False): + """Calculates the cumulative sum of elements in the input array x.""" + return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial) + def max(x, /, *, axis=None, keepdims=False): """Calculates the maximum value of the input array x.""" return jax.numpy.max(x, axis=axis, keepdims=keepdims) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 273c5a2aa80f..94687dae3eaa 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -78,6 +78,7 @@ cov as cov, cross as cross, csingle as csingle, + cumulative_sum as cumulative_sum, delete as delete, diag as diag, diagflat as diagflat, diff --git a/tests/array_api_test.py b/tests/array_api_test.py index b11fc35845ff..2f004cb2137e 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -67,6 +67,7 @@ 'conj', 'cos', 'cosh', + 'cumulative_sum', 'divide', 'e', 'empty', diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c957ed669b3a..c31264d9f21f 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -286,6 +286,59 @@ def np_fun(x): atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) + + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list(range(-len(shape), len(shape))) + [None] if len(shape) == 1], + dtype=all_dtypes, + out_dtype=all_dtypes + [None], + include_initial=[False, True], + ) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + x = rng(shape, dtype) + out = jnp.cumulative_sum(x, dtype=out_dtype, include_initial=include_initial) + + target_dtype = out_dtype or x.dtype + kind = x.dtype.kind + if (out_dtype is None and kind in {'i', 'u'} + and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): + target_dtype = dtypes.dtype(dtypes._default_types[kind]) + assert out.dtype == target_dtype + + _axis = axis or 0 + target_shape = list(x.shape) + if include_initial: + target_shape[_axis] += 1 + assert out.shape == tuple(target_shape) + + target = jnp.cumsum(x, axis=_axis, dtype=out.dtype) + if include_initial: + zeros_shape = target_shape + zeros_shape[_axis] = 1 + target = jnp.concat([jnp.zeros(target_shape, dtype=out.dtype), target]) + self.assertArraysEqual(out, target) + + + @jtu.sample_product( + shape=all_shapes, dtype=all_dtypes, + include_initial=[False, True]) + def testCumulativeSumErrors(self, shape, dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + x = rng(shape, dtype) + if jnp.isscalar(x) or x.ndim == 0: + msg = r"The input must be non-scalar to take" + with self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + elif x.ndim > 1: + msg = r"The input array has rank \d*, however" + with self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + + + @jtu.sample_product( [dict(shape=shape, axis=axis) for shape in all_shapes