Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new cumulative_sum function to numpy and array_api namespaces #20756

Merged
merged 1 commit into from
Apr 16, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Apr 15, 2024

Towards #20200

Adds a new cumulative_sum function to the jax.numpy and jax.experimental.array_api namespaces in compliance with the array API 2023 standard.

@Micky774 Micky774 force-pushed the array-api-cumulative-sum branch from 1bfd146 to b6f48c0 Compare April 15, 2024 14:44
@Micky774 Micky774 marked this pull request as ready for review April 15, 2024 14:45
@Micky774 Micky774 force-pushed the array-api-cumulative-sum branch from b6f48c0 to f5fa5c1 Compare April 15, 2024 14:52
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this!

Also please add the new function to the docs by listing it here: https://github.com/google/jax/blob/main/docs/jax.numpy.rst

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
@jakevdp jakevdp self-assigned this Apr 15, 2024
@Micky774 Micky774 force-pushed the array-api-cumulative-sum branch from f5fa5c1 to 10d2ede Compare April 15, 2024 23:41
CHANGELOG.md Outdated Show resolved Hide resolved
tests/lax_numpy_reducers_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_reducers_test.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array-api-cumulative-sum branch 3 times, most recently from f111aa3 to ff1a526 Compare April 16, 2024 19:35
@Micky774 Micky774 force-pushed the array-api-cumulative-sum branch from ff1a526 to ceeb975 Compare April 16, 2024 19:58
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 16, 2024
if (dtype is None and kind in {'i', 'u'}
and x.dtype.itemsize*8 < int(config.default_dtype_bits.value)):
dtype = dtypes.canonicalize_dtype(dtypes._default_types[kind])
x = x.astype(dtype=dtype or x.dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side-note: referencing the discussion in #20195, this is an example of a place where we wouldn't want copy in astype to default to True!

@copybara-service copybara-service bot merged commit 47815c5 into jax-ml:main Apr 16, 2024
14 checks passed
@Micky774 Micky774 deleted the array-api-cumulative-sum branch April 17, 2024 15:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants