Skip to content

Commit

Permalink
Fix compatibility with nightly numpy
Browse files Browse the repository at this point in the history
Numpy recently merged support for the 2023.12 revision of the Array API:
numpy/numpy#26724

This breaks two of our tests:

1. The first breakage was caused by differences in how numpy and JAX
   cast negative floats to `uint8`. Specifically
   `np.float32(-1).astype(np.uint8)` returns `np.uint8(255)` whereas
   `jnp.float32(-1).astype(jnp.uint8)` produces `Array(0, dtype=uint8)`.
   We don't make any promises about consistency with casting floats to
   ints, noting that this can even be backend dependent. To fix our
   test, we now only generate positive inputs when the output dtype is
   unsigned.

2. The second failure was caused by the fact that the approach we took
   in jax-ml#20550 to support backwards compatibility and the Array API for
   `clip` differs from the one used in numpy/numpy#26724. Again, the
   behavior is consistent, but it produces a different signature. I've
   skipped checking `clip`'s signature, but we should revisit it once
   the `a_min` and `a_max` parameters have been removed from JAX.

Fixes jax-ml#22251
  • Loading branch information
dfm committed Jul 3, 2024
1 parent ade76f0 commit 9e9acc9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,6 @@ def test_f16_mean(self, dtype):
actual = jnp.mean(x)
self.assertAllClose(expected, actual, atol=0)


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
Expand Down Expand Up @@ -815,10 +814,14 @@ def np_mock_op(x, axis=None, dtype=None, include_initial=False):
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
return out


# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
# input because we rely on JAX-specific casting behavior
args_maker = lambda: [jnp.array(rng(shape, dtype))]
def args_maker():
x = jnp.array(rng(shape, dtype))
if out_dtype in unsigned_dtypes:
x = 10 * jnp.abs(x)
return [x]

np_op = getattr(np, "cumulative_sum", np_mock_op)
kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial)

Expand All @@ -827,7 +830,6 @@ def np_mock_op(x, axis=None, dtype=None, include_initial=False):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@jtu.sample_product(
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
include_initial=[False, True])
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5971,6 +5971,7 @@ def testWrappedSignaturesMatch(self):
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
'cumulative_sum': ['out'],
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
Expand Down Expand Up @@ -6021,6 +6022,15 @@ def testWrappedSignaturesMatch(self):
# numpy 1.24 re-orders the density and weights arguments.
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
continue
if name == "clip":
# JAX's support of the Array API spec for clip, and the way it handles
# backwards compatibility was introduced in
# https://github.com/google/jax/pull/20550 with a different signature
# from the one in numpy, introduced in
# https://github.com/numpy/numpy/pull/26724
# TODO(dfm): After our deprecation period for the clip arguments ends
# it should be possible to reintroduce the check.
continue
# Note: can't use inspect.getfullargspec due to numpy issue
# https://github.com/numpy/numpy/issues/12225
try:
Expand Down

0 comments on commit 9e9acc9

Please sign in to comment.