From ceeb97573551f26a87c0431fa5546fc49b8904ff Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Apr 2024 19:57:55 +0000 Subject: [PATCH] Add new cumulative_sum function to numpy and array_api --- CHANGELOG.md | 5 +- docs/jax.numpy.rst | 1 + jax/_src/numpy/reductions.py | 38 +++++++++++- jax/experimental/array_api/__init__.py | 1 + .../array_api/_statistical_functions.py | 4 ++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 3 + tests/array_api_test.py | 1 + tests/lax_numpy_reducers_test.py | 59 +++++++++++++++++++ 9 files changed, 110 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 331554c67e23..a30f333d961e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.27 * New Functionality - * Added {func}`jax.numpy.unstack`, following the addition of this function in - the array API 2023 standard, soon to be adopted by NumPy. + * Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`, + following their addition in the array API 2023 standard, soon to be + adopted by NumPy. * Changes * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index d866df52d16a..5ccf043d2282 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -138,6 +138,7 @@ namespace; they are listed below. csingle cumprod cumsum + cumulative_sum deg2rad degrees delete diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 0e628fc5e5b8..8c5594cb4541 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -26,7 +26,7 @@ from jax import lax from jax._src import api -from jax._src import core +from jax._src import core, config from jax._src import dtypes from jax._src.numpy import ufuncs from jax._src.numpy.util import ( @@ -708,6 +708,42 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None, nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, fill_nan=True, fill_value=1) +@implements(getattr(np, 'cumulative_sum', None)) +def cumulative_sum( + x: ArrayLike, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + check_arraylike("cumulative_sum", x) + x = lax_internal.asarray(x) + if x.ndim == 0: + raise ValueError( + "The input must be non-scalar to take a cumulative sum, however a " + "scalar value or scalar array was given." + ) + if axis is None: + axis = 0 + if x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + + axis = _canonicalize_axis(axis, x.ndim) + dtypes.check_user_dtype_supported(dtype) + kind = x.dtype.kind + if (dtype is None and kind in {'i', 'u'} + and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): + dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind]) + x = x.astype(dtype=dtype or x.dtype) + out = cumsum(x, axis=axis) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = lax_internal.concatenate( + [lax_internal.full(zeros_shape, 0, dtype=out.dtype), out], + dimension=axis) + return out + # Quantiles @implements(np.quantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index dfb2ff98878d..7050a4e3cd35 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -204,6 +204,7 @@ ) from jax.experimental.array_api._statistical_functions import ( + cumulative_sum as cumulative_sum, max as max, mean as mean, min as min, diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index 2e1333317605..141b80abfb14 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -18,6 +18,10 @@ ) +def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False): + """Calculates the cumulative sum of elements in the input array x.""" + return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial) + def max(x, /, *, axis=None, keepdims=False): """Calculates the maximum value of the input array x.""" return jax.numpy.max(x, axis=axis, keepdims=keepdims) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 0b8a8cdd0277..1b9a990f3a0d 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -296,6 +296,7 @@ count_nonzero as count_nonzero, cumsum as cumsum, cumprod as cumprod, + cumulative_sum as cumulative_sum, max as max, mean as mean, median as median, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 7c7e68a06cde..2740638041cd 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -241,6 +241,9 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., cumproduct = cumprod def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... +def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., + dtype: DTypeLike | None = ..., + include_initial: bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 91c64e74954e..5667c3459dad 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -68,6 +68,7 @@ 'copysign', 'cos', 'cosh', + 'cumulative_sum', 'divide', 'e', 'empty', diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index b72e47d9d179..73100352c544 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -770,5 +770,64 @@ def test_f16_mean(self, dtype): self.assertAllClose(expected, actual, atol=0) + @jtu.sample_product( + [dict(shape=shape, axis=axis) + for shape in all_shapes + for axis in list( + range(-len(shape), len(shape)) + ) + ([None] if len(shape) == 1 else [])], + dtype=all_dtypes + [None], + out_dtype=all_dtypes, + include_initial=[False, True], + ) + @jtu.ignore_warning(category=NumpyComplexWarning) + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + + def np_mock_op(x, axis=None, dtype=None, include_initial=False): + kind = x.dtype.kind + if (dtype is None and kind in {'i', 'u'} + and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)): + dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind]) + axis = axis or 0 + x = x.astype(dtype=dtype or x.dtype) + out = jnp.cumsum(x, axis=axis) + if include_initial: + zeros_shape = list(x.shape) + zeros_shape[axis] = 1 + out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis) + return out + + + # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as + # input because we rely on JAX-specific casting behavior + args_maker = lambda: [jnp.array(rng(shape, dtype))] + np_op = getattr(np, "cumulative_sum", np_mock_op) + kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial) + + np_fun = lambda x: np_op(x, **kwargs) + jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + + @jtu.sample_product( + shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes, + include_initial=[False, True]) + def testCumulativeSumErrors(self, shape, dtype, include_initial): + rng = jtu.rand_some_zero(self.rng()) + x = rng(shape, dtype) + rank = jnp.asarray(x).ndim + if rank == 0: + msg = r"The input must be non-scalar to take" + with self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + elif rank > 1: + msg = r"The input array has rank \d*, however" + with self.assertRaisesRegex(ValueError, msg): + jnp.cumulative_sum(x, include_initial=include_initial) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())