diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b23ca210b88..ace8259eb7dd 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5567,3 +5567,27 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, else: raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}") return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode) + + +@util.implements(getattr(np, 'cumulative_sum', None)) +def cumulative_sum( + x: Array, /, *, axis: int | None = None, + dtype: DTypeLike | None = None, + include_initial: bool = False) -> Array: + if axis is None and x.ndim > 1: + raise ValueError( + f"The input array has rank {x.ndim}, however axis was not set to an " + "explicit value. The axis argument is only optional for one-dimensional " + "arrays.") + util.check_arraylike("cumulative_sum", x) + dtypes.check_user_dtype_supported(dtype) + kind = x.dtype.kind + default_dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind]) + if (dtype is None and kind in {'i', 'u'} + and x.dtype.itemsize < default_dtype.itemsize): + dtype = default_dtype + + out = reductions.cumsum(x, axis=axis, dtype=dtype) + zeros_shape = list(x.shape) + zeros_shape[axis if axis else 0] = 1 + return append(zeros(zeros_shape, dtype=out.dtype), out, axis=axis) if include_initial else out diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index d59a9eec751e..2bd69ca09d0f 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -195,6 +195,7 @@ ) from jax.experimental.array_api._statistical_functions import ( + cumulative_sum as cumulative_sum, max as max, mean as mean, min as min, diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index 2e1333317605..141b80abfb14 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -18,6 +18,10 @@ ) +def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False): + """Calculates the cumulative sum of elements in the input array x.""" + return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial) + def max(x, /, *, axis=None, keepdims=False): """Calculates the maximum value of the input array x.""" return jax.numpy.max(x, axis=axis, keepdims=keepdims) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 273c5a2aa80f..94687dae3eaa 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -78,6 +78,7 @@ cov as cov, cross as cross, csingle as csingle, + cumulative_sum as cumulative_sum, delete as delete, diag as diag, diagflat as diagflat, diff --git a/tests/array_api_test.py b/tests/array_api_test.py index b11fc35845ff..2f004cb2137e 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -67,6 +67,7 @@ 'conj', 'cos', 'cosh', + 'cumulative_sum', 'divide', 'e', 'empty',