Skip to content

Commit

Permalink
Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
Browse files Browse the repository at this point in the history
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.

PiperOrigin-RevId: 657717875
  • Loading branch information
Jake VanderPlas authored and jax authors committed Jul 30, 2024
1 parent b996612 commit 8bcd288
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 53 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.

* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.

## jaxlib 0.4.32

## jax 0.4.31 (July 29, 2024)
Expand Down
11 changes: 3 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2316,7 +2316,6 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

deprecations.register("jax-numpy-clip-complex")

@jit
def clip(
Expand Down Expand Up @@ -2377,15 +2376,11 @@ def clip(

util.check_arraylike("clip", arr)
if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
deprecations.warn(
"jax-numpy-clip-complex",
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
stacklevel=2)
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively.")
if min is not None:
arr = ufuncs.maximum(min, arr)
if max is not None:
Expand Down
14 changes: 3 additions & 11 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import numpy as np

from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
Expand Down Expand Up @@ -1132,9 +1131,6 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))


deprecations.register("jax-numpy-hypot-complex")


@implements(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Expand All @@ -1143,13 +1139,9 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# TODO(micky774): Promote to ValueError when deprecation is complete
# (began 2024-4-14).
if dtypes.issubdtype(x1.dtype, np.complexfloating):
deprecations.warn(
"jax-numpy-hypot-complex",
"Passing complex-valued inputs to hypot is deprecated and will raise a "
"ValueError in the future. Please convert to real values first, such as "
"by using jnp.real or jnp.imag to take the real or imaginary components "
"respectively.",
stacklevel=2)
raise ValueError(
"jnp.hypot is not well defined for complex-valued inputs. "
"Please convert to real values first, such as by using abs(x)")
x1, x2 = lax.abs(x1), lax.abs(x2)
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
x1, x2 = maximum(x1, x2), minimum(x1, x2)
Expand Down
46 changes: 12 additions & 34 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,58 +896,36 @@ def testClipNone(self, shape, dtype):
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)

# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
def testClipComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
x = rng((5,), dtype=jnp.complex64)
msg = ".*Complex values have no ordering and cannot be clipped.*"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-clip-complex"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x)

with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=x)

x = rng(shape, dtype=jnp.int32)
with assert_warns_or_errors():
x = rng((5,), dtype=jnp.int32)
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, min=-1+5j)

with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=jnp.array([-1+5j]))

# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testHypotComplexInputDeprecation(self, shape):
def testHypotComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Passing complex-valued inputs to hypot.*"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-hypot-complex"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
x = rng((5,), dtype=jnp.complex64)
msg = "jnp.hypot is not well defined for complex-valued inputs.*"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, x)

y = jnp.ones_like(x)
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, y)

@jtu.sample_product(
Expand Down

0 comments on commit 8bcd288

Please sign in to comment.