Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jnp.clip to Array API 2023 standard and introduces jax.experimental.array_api.clip #20550

Merged
merged 1 commit into from
Apr 5, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Apr 2, 2024

See: clip Array API 2023 specification

Specifically this PR:

  1. Removes unused out keyword argument in jnp.clip to match Array API 2023 standard
  2. Introduces a new DeprecatedArg class in jax._src.typing for easier deprecation with mypy static analysis, and improved documentation
  3. Begins the deprecation of a, a_min, a_max in favor of x, min, max per Array API 2023 standard
  4. Begins the deprecation of accepting complex-valued inputs for x, min, max
  5. Changes the behavior of clip(..., min=None, max=None) from raising a ValueError to instead returning the input unchanged as an array
  6. Introduces jax.experimental.array_api.clip as a thin wrapper around jax.numpy.clip which issues a ValueError when receiving complex input (to be removed when jnp.clip deprecation is complete)
  7. Updates documentation for jnp.clip
  8. Adds tests for both {jnp, array_api}.clip with the intention of shifting to only testing jnp.clip after deprecation

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 2, 2024
@jakevdp jakevdp self-assigned this Apr 2, 2024
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 2, 2024

Sorry, forgot one last thing: this will need a CHANGELOG entry.

CHANGELOG.md Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 3, 2024

Actually, it's going to take a bit of work to land this because it breaks pytype for a number of downstream targets. You can fix the changelog now if you'd like: just remove the note about the experimental change, and make the description of the jnp.clip changes more concise. Something like this:

* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and `a_max` are deprecated
  in favor of `x` (positonal only), `min`, and `max` ({jax-issue}`20550`)

jax/numpy/__init__.pyi Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Let's run some internal tests and see if this does it...

@jakevdp
Copy link
Collaborator

jakevdp commented Apr 4, 2024

Sorry we weren't able to get this in yesterday – it was blocked by an unrelated failure. In the meantime the 0.4.26 release happened, so the CHANGELOG needs to be updated here to reflect that this will be part of 0.4.27.

@Micky774
Copy link
Collaborator Author

Micky774 commented Apr 4, 2024

Sorry we weren't able to get this in yesterday – it was blocked by an unrelated failure. In the meantime the 0.4.26 release happened, so the CHANGELOG needs to be updated here to reflect that this will be part of 0.4.27.

Haha I missed that. Alright, the changelog has been updated :)

@copybara-service copybara-service bot merged commit 2512843 into jax-ml:main Apr 5, 2024
10 checks passed
@Micky774 Micky774 deleted the api_clip branch April 7, 2024 23:01
@fehiepsi
Copy link
Contributor

Re a_min/a_max -> min/max: Currently, numpy.clip still uses a_min/a_max. Do you think it will adopt the standard Array API in the near future? The deprecation warning is a bit surprised to me.

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 21, 2024

Numpy is beginning to address this in #26724: it looks like they plan to continue supporting both names for the min/max arguments.

@fehiepsi
Copy link
Contributor

Thanks for the info, Jake!

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 25, 2024

Should we keep supporting the a_min and a_max aliases for min and max? It looks like NumPy is going that direction.

dfm added a commit to dfm/jax that referenced this pull request Jul 3, 2024
Numpy recently merged support for the 2023.12 revision of the Array API:
numpy/numpy#26724

This breaks two of our tests and I've chosen to skip those tests for
now:

1. The first breakage was caused by differences in how numpy and JAX
   cast negative floats to `uint8`. Specifically
   `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
   `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
   We don't make any promises about consistency with casting floats to
   ints, noting that this can even be backend dependent. I don't believe
   this failure is identifying any unexpected behavior, and we test many
   other dtypes properly so I'm not concerned about skipping this test.

2. The second failure was caused by the fact that the approach we took
   in jax-ml#20550 to support backwards compatibility and the Array API for
   `clip` differs from the one used in numpy/numpy#26724. Again, the
   behavior is consistent, but it produces a different signature. I've
   skipped checking `clip`'s signature, but we should revisit it once
   the `a_min` and `a_max` parameters have been removed from JAX.

Fixes jax-ml#22251
dfm added a commit to dfm/jax that referenced this pull request Jul 3, 2024
Numpy recently merged support for the 2023.12 revision of the Array API:
numpy/numpy#26724

This breaks two of our tests:

1. The first breakage was caused by differences in how numpy and JAX
   cast negative floats to `uint8`. Specifically
   `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
   `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
   We don't make any promises about consistency with casting floats to
   ints, noting that this can even be backend dependent. To fix our
   test, we now only generate positive inputs when the output dtype is
   unsigned.

2. The second failure was caused by the fact that the approach we took
   in jax-ml#20550 to support backwards compatibility and the Array API for
   `clip` differs from the one used in numpy/numpy#26724. Again, the
   behavior is consistent, but it produces a different signature. I've
   skipped checking `clip`'s signature, but we should revisit it once
   the `a_min` and `a_max` parameters have been removed from JAX.

Fixes jax-ml#22251
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants