From 29a8cce66cfe35f216b958779e975910d280bb5b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 5 Dec 2024 09:27:19 -0800 Subject: [PATCH] jax.numpy: require boolean dtype for where argument --- jax/_src/deprecations.py | 1 + jax/_src/numpy/reductions.py | 33 +++++++++++++++++++++++++++++--- tests/lax_numpy_reducers_test.py | 28 +++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index c7a956068981..778a084e807a 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -130,5 +130,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') +register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') register('pallas-gpu-triton') diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 69d6843f5155..eea734420176 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: return dtypes.int_ return dtype +def check_where(name: str, where: ArrayLike | None) -> Array | None: + if where is None: + return where + check_arraylike(name, where) + where_arr = lax_internal.asarray(where) + if where_arr.dtype != bool: + # Deprecation added 2024-12-05 + deprecations.warn( + 'jax-numpy-reduction-non-boolean-where', + f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.", + stacklevel=2) + return where_arr.astype(bool) + return where_arr + ReductionOp = Callable[[Any, Any], Any] @@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") check_arraylike(name, a) + where_ = check_where(name, where_) dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") @@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") + check_arraylike("logsumexp", a) + where = check_where("logsumexp", where) a_arr, = promote_dtypes_inexact(a) pos_dims, dims = _reduction_dims(a_arr, axis) amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf) @@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") + check_arraylike("logsumexp2", a) + where = check_where("logsumexp2", where) ln2 = float(np.log(2)) if initial is not None: initial *= ln2 @@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: check_arraylike("mean", a) + where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) + where = check_where("var", where) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: raise NotImplementedError("The 'out' argument to jnp.var is not supported.") @@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) + where = check_where("std", where) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") @@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, - axis: Axis = None, keepdims: bool = False, **kwargs) -> Array: + axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None, + **kwargs) -> Array: check_arraylike(name, a) + where = check_where(name, where) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): - return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) + return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a), - axis=axis, keepdims=keepdims, **kwargs) + axis=axis, keepdims=keepdims, where=where, **kwargs) if nan_if_all_nan: return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims), _lax_const(a, np.nan), out) @@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out Array([[nan, nan, nan, nan]], dtype=float32) """ check_arraylike("nanmean", a) + where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): @@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [4. ]], dtype=float32) """ check_arraylike("nanvar", a) + where = check_where("nanvar", where) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") @@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ check_arraylike("nanstd", a) + where = check_where("nanstd", where) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index be6208e6e305..2bef35fbdcef 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -448,6 +448,34 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where, initial=jnp.array(0, dtype=dtype)) + + @jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS) + def testReducerWhereNonBooleanErrorNoInitial(self, rec): + dtype = rec.dtypes[0] + x = jnp.zeros((10,), dtype) + where = jnp.ones(10, dtype=int) + func = getattr(jnp, rec.name) + def assert_warns_or_errors(msg): + if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"): + func(x, where=where) + @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( [dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,