From 8bcd2886212240148d4c3088e074848eac45a6c3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 30 Jul 2024 13:48:51 -0700 Subject: [PATCH] Raise ValueError for complex inputs to jnp.clip and jnp.hypot. Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases. PiperOrigin-RevId: 657717875 --- CHANGELOG.md | 4 ++++ jax/_src/numpy/lax_numpy.py | 11 +++------ jax/_src/numpy/ufuncs.py | 14 +++-------- tests/lax_numpy_test.py | 46 ++++++++++--------------------------- 4 files changed, 22 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8f4645926aa..a7e83f0a4854 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the `stablehlo` dialect instead. +* Deprecations + * Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are + no longer allowed, after being deprecated since JAX v0.4.27. + ## jaxlib 0.4.32 ## jax 0.4.31 (July 29, 2024) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3ec6359fe43a..c4a9ee17fe40 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2316,7 +2316,6 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array axis: int = 0) -> list[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) -deprecations.register("jax-numpy-clip-complex") @jit def clip( @@ -2377,15 +2376,11 @@ def clip( util.check_arraylike("clip", arr) if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)): - # TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires. - deprecations.warn( - "jax-numpy-clip-complex", + raise ValueError( "Clip received a complex value either through the input or the min/max " "keywords. Complex values have no ordering and cannot be clipped. " - "Attempting to clip using complex numbers is deprecated and will soon " - "raise a ValueError. Please convert to a real value or array by taking " - "the real or imaginary components via jax.numpy.real/imag respectively.", - stacklevel=2) + "Please convert to a real value or array by taking the real or " + "imaginary components via jax.numpy.real/imag respectively.") if min is not None: arr = ufuncs.maximum(min, arr) if max is not None: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 891d79fb5128..531d0bec813f 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -25,7 +25,6 @@ import numpy as np from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src.api import jit from jax._src.custom_derivatives import custom_jvp @@ -1132,9 +1131,6 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) -deprecations.register("jax-numpy-hypot-complex") - - @implements(np.hypot, module='numpy') @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -1143,13 +1139,9 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: # TODO(micky774): Promote to ValueError when deprecation is complete # (began 2024-4-14). if dtypes.issubdtype(x1.dtype, np.complexfloating): - deprecations.warn( - "jax-numpy-hypot-complex", - "Passing complex-valued inputs to hypot is deprecated and will raise a " - "ValueError in the future. Please convert to real values first, such as " - "by using jnp.real or jnp.imag to take the real or imaginary components " - "respectively.", - stacklevel=2) + raise ValueError( + "jnp.hypot is not well defined for complex-valued inputs. " + "Please convert to real values first, such as by using abs(x)") x1, x2 = lax.abs(x1), lax.abs(x2) idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2)) x1, x2 = maximum(x1, x2), minimum(x1, x2) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c9c7e3e15548..d4b5d2bb00ed 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -896,58 +896,36 @@ def testClipNone(self, shape, dtype): x = rng(shape, dtype) self.assertArraysEqual(jnp.clip(x), x) - # TODO(micky774): Check for ValueError instead of DeprecationWarning when - # jnp.clip deprecation is completed (began 2024-4-2) and default behavior is - # Array API 2023 compliant - @jtu.sample_product(shape=all_shapes) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testClipComplexInputDeprecation(self, shape): + def testClipComplexInputError(self): rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype=jnp.complex64) + x = rng((5,), dtype=jnp.complex64) msg = ".*Complex values have no ordering and cannot be clipped.*" - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-clip-complex"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) # jit is disabled so we don't miss warnings due to caching. with jax.disable_jit(): - with assert_warns_or_errors(): + with self.assertRaisesRegex(ValueError, msg): jnp.clip(x) - with assert_warns_or_errors(): + with self.assertRaisesRegex(ValueError, msg): jnp.clip(x, max=x) - x = rng(shape, dtype=jnp.int32) - with assert_warns_or_errors(): + x = rng((5,), dtype=jnp.int32) + with self.assertRaisesRegex(ValueError, msg): jnp.clip(x, min=-1+5j) - with assert_warns_or_errors(): + with self.assertRaisesRegex(ValueError, msg): jnp.clip(x, max=jnp.array([-1+5j])) - # TODO(micky774): Check for ValueError instead of DeprecationWarning when - # jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is - # Array API 2023 compliant - @jtu.sample_product(shape=all_shapes) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testHypotComplexInputDeprecation(self, shape): + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype=jnp.complex64) - msg = "Passing complex-valued inputs to hypot.*" - def assert_warns_or_errors(msg=msg): - if deprecations.is_accelerated("jax-numpy-hypot-complex"): - return self.assertRaisesRegex(ValueError, msg) - else: - return self.assertWarnsRegex(DeprecationWarning, msg) + x = rng((5,), dtype=jnp.complex64) + msg = "jnp.hypot is not well defined for complex-valued inputs.*" # jit is disabled so we don't miss warnings due to caching. with jax.disable_jit(): - with assert_warns_or_errors(): + with self.assertRaisesRegex(ValueError, msg): jnp.hypot(x, x) y = jnp.ones_like(x) - with assert_warns_or_errors(): + with self.assertRaisesRegex(ValueError, msg): jnp.hypot(x, y) @jtu.sample_product(