-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update jnp.clip
to Array API 2023 standard and introduces jax.experimental.array_api.clip
#20550
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Sorry, forgot one last thing: this will need a CHANGELOG entry. |
Actually, it's going to take a bit of work to land this because it breaks pytype for a number of downstream targets. You can fix the changelog now if you'd like: just remove the note about the
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Let's run some internal tests and see if this does it...
Sorry we weren't able to get this in yesterday – it was blocked by an unrelated failure. In the meantime the 0.4.26 release happened, so the CHANGELOG needs to be updated here to reflect that this will be part of 0.4.27. |
Haha I missed that. Alright, the changelog has been updated :) |
Re |
Numpy is beginning to address this in #26724: it looks like they plan to continue supporting both names for the min/max arguments. |
Thanks for the info, Jake! |
Should we keep supporting the |
Numpy recently merged support for the 2023.12 revision of the Array API: numpy/numpy#26724 This breaks two of our tests and I've chosen to skip those tests for now: 1. The first breakage was caused by differences in how numpy and JAX cast negative floats to `uint8`. Specifically `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`. We don't make any promises about consistency with casting floats to ints, noting that this can even be backend dependent. I don't believe this failure is identifying any unexpected behavior, and we test many other dtypes properly so I'm not concerned about skipping this test. 2. The second failure was caused by the fact that the approach we took in jax-ml#20550 to support backwards compatibility and the Array API for `clip` differs from the one used in numpy/numpy#26724. Again, the behavior is consistent, but it produces a different signature. I've skipped checking `clip`'s signature, but we should revisit it once the `a_min` and `a_max` parameters have been removed from JAX. Fixes jax-ml#22251
Numpy recently merged support for the 2023.12 revision of the Array API: numpy/numpy#26724 This breaks two of our tests: 1. The first breakage was caused by differences in how numpy and JAX cast negative floats to `uint8`. Specifically `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`. We don't make any promises about consistency with casting floats to ints, noting that this can even be backend dependent. To fix our test, we now only generate positive inputs when the output dtype is unsigned. 2. The second failure was caused by the fact that the approach we took in jax-ml#20550 to support backwards compatibility and the Array API for `clip` differs from the one used in numpy/numpy#26724. Again, the behavior is consistent, but it produces a different signature. I've skipped checking `clip`'s signature, but we should revisit it once the `a_min` and `a_max` parameters have been removed from JAX. Fixes jax-ml#22251
See:
clip
Array API 2023 specificationSpecifically this PR:
out
keyword argument injnp.clip
to match Array API 2023 standardDeprecatedArg
class injax._src.typing
for easier deprecation with mypy static analysis, and improved documentationa, a_min, a_max
in favor ofx, min, max
per Array API 2023 standardx, min, max
clip(..., min=None, max=None)
from raising aValueError
to instead returning the input unchanged as an arrayjax.experimental.array_api.clip
as a thin wrapper aroundjax.numpy.clip
which issues aValueError
when receiving complex input (to be removed whenjnp.clip
deprecation is complete)jnp.clip
{jnp, array_api}.clip
with the intention of shifting to only testingjnp.clip
after deprecation