Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix compatibility with nightly numpy
Numpy recently merged support for the 2023.12 revision of the Array API: numpy/numpy#26724 This breaks two of our tests: 1. The first breakage was caused by differences in how numpy and JAX cast negative floats to `uint8`. Specifically `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`. We don't make any promises about consistency with casting floats to ints, noting that this can even be backend dependent. To fix our test, we now only generate positive inputs when the output dtype is unsigned. 2. The second failure was caused by the fact that the approach we took in jax-ml#20550 to support backwards compatibility and the Array API for `clip` differs from the one used in numpy/numpy#26724. Again, the behavior is consistent, but it produces a different signature. I've skipped checking `clip`'s signature, but we should revisit it once the `a_min` and `a_max` parameters have been removed from JAX. Fixes jax-ml#22251
- Loading branch information