Skip to content

Commit

Permalink
Add new cumulative_sum function to numpy and array_api namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 15, 2024
1 parent 2c85ca6 commit 1bfd146
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 0 deletions.
24 changes: 24 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5567,3 +5567,27 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
else:
raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
return arr.at[unravel_index(ind_arr, arr.shape)].set(v_arr, mode=scatter_mode)


@util.implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
x: Array, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
if axis is None and x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")
util.check_arraylike("cumulative_sum", x)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
default_dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize < default_dtype.itemsize):
dtype = default_dtype

out = reductions.cumsum(x, axis=axis, dtype=dtype)
zeros_shape = list(x.shape)
zeros_shape[axis if axis else 0] = 1
return append(zeros(zeros_shape, dtype=out.dtype), out, axis=axis) if include_initial else out
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
)


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
"""Calculates the cumulative sum of elements in the input array x."""
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)

def max(x, /, *, axis=None, keepdims=False):
"""Calculates the maximum value of the input array x."""
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
cov as cov,
cross as cross,
csingle as csingle,
cumulative_sum as cumulative_sum,
delete as delete,
diag as diag,
diagflat as diagflat,
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
'conj',
'cos',
'cosh',
'cumulative_sum',
'divide',
'e',
'empty',
Expand Down

0 comments on commit 1bfd146

Please sign in to comment.