Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new cumulative_sum function to numpy and array_api namespaces #20756

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.27

* New Functionality
* Added {func}`jax.numpy.unstack`, following the addition of this function in
the array API 2023 standard, soon to be adopted by NumPy.
* Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
following their addition in the array API 2023 standard, soon to be
adopted by NumPy.

* Changes
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
Expand Down
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ namespace; they are listed below.
csingle
cumprod
cumsum
cumulative_sum
deg2rad
degrees
delete
Expand Down
38 changes: 37 additions & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from jax import lax
from jax._src import api
from jax._src import core
from jax._src import core, config
from jax._src import dtypes
from jax._src.numpy import ufuncs
from jax._src.numpy.util import (
Expand Down Expand Up @@ -708,6 +708,42 @@ def _cumulative_reduction(a: ArrayLike, axis: Axis = None,
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)

@implements(getattr(np, 'cumulative_sum', None))
def cumulative_sum(
x: ArrayLike, /, *, axis: int | None = None,
dtype: DTypeLike | None = None,
include_initial: bool = False) -> Array:
check_arraylike("cumulative_sum", x)
x = lax_internal.asarray(x)
if x.ndim == 0:
raise ValueError(
"The input must be non-scalar to take a cumulative sum, however a "
"scalar value or scalar array was given."
)
if axis is None:
axis = 0
if x.ndim > 1:
raise ValueError(
f"The input array has rank {x.ndim}, however axis was not set to an "
"explicit value. The axis argument is only optional for one-dimensional "
"arrays.")

axis = _canonicalize_axis(axis, x.ndim)
dtypes.check_user_dtype_supported(dtype)
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side-note: referencing the discussion in #20195, this is an example of a place where we wouldn't want copy in astype to default to True!

out = cumsum(x, axis=axis)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = lax_internal.concatenate(
[lax_internal.full(zeros_shape, 0, dtype=out.dtype), out],
dimension=axis)
return out

# Quantiles
@implements(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
Expand Down
4 changes: 4 additions & 0 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
)


def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False):
"""Calculates the cumulative sum of elements in the input array x."""
return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial)

def max(x, /, *, axis=None, keepdims=False):
"""Calculates the maximum value of the input array x."""
return jax.numpy.max(x, axis=axis, keepdims=keepdims)
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@
count_nonzero as count_nonzero,
cumsum as cumsum,
cumprod as cumprod,
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
median as median,
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
cumproduct = cumprod
def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ...) -> Array: ...
def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ...,
dtype: DTypeLike | None = ...,
include_initial: bool = ...) -> Array: ...

def deg2rad(x: ArrayLike, /) -> Array: ...
degrees = rad2deg
Expand Down
1 change: 1 addition & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'copysign',
'cos',
'cosh',
'cumulative_sum',
'divide',
'e',
'empty',
Expand Down
59 changes: 59 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,5 +770,64 @@ def test_f16_mean(self, dtype):
self.assertAllClose(expected, actual, atol=0)


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
for axis in list(
range(-len(shape), len(shape))
) + ([None] if len(shape) == 1 else [])],
dtype=all_dtypes + [None],
out_dtype=all_dtypes,
include_initial=[False, True],
)
@jtu.ignore_warning(category=NumpyComplexWarning)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())

def np_mock_op(x, axis=None, dtype=None, include_initial=False):
kind = x.dtype.kind
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
axis = axis or 0
x = x.astype(dtype=dtype or x.dtype)
out = jnp.cumsum(x, axis=axis)
if include_initial:
zeros_shape = list(x.shape)
zeros_shape[axis] = 1
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
return out


# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
# input because we rely on JAX-specific casting behavior
args_maker = lambda: [jnp.array(rng(shape, dtype))]
np_op = getattr(np, "cumulative_sum", np_mock_op)
kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial)

np_fun = lambda x: np_op(x, **kwargs)
jnp_fun = lambda x: jnp.cumulative_sum(x, **kwargs)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@jtu.sample_product(
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
include_initial=[False, True])
def testCumulativeSumErrors(self, shape, dtype, include_initial):
rng = jtu.rand_some_zero(self.rng())
x = rng(shape, dtype)
rank = jnp.asarray(x).ndim
if rank == 0:
msg = r"The input must be non-scalar to take"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)
elif rank > 1:
msg = r"The input array has rank \d*, however"
with self.assertRaisesRegex(ValueError, msg):
jnp.cumulative_sum(x, include_initial=include_initial)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())