Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise ValueError for complex inputs to jnp.clip and jnp.hypot. #22765

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.

* Deprecations
* Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are
no longer allowed, after being deprecated since JAX v0.4.27.

## jaxlib 0.4.32

## jax 0.4.31 (July 29, 2024)
Expand Down
11 changes: 3 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2316,7 +2316,6 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

deprecations.register("jax-numpy-clip-complex")

@jit
def clip(
Expand Down Expand Up @@ -2377,15 +2376,11 @@ def clip(

util.check_arraylike("clip", arr)
if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
deprecations.warn(
"jax-numpy-clip-complex",
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
stacklevel=2)
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively.")
if min is not None:
arr = ufuncs.maximum(min, arr)
if max is not None:
Expand Down
14 changes: 3 additions & 11 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import numpy as np

from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.api import jit
from jax._src.custom_derivatives import custom_jvp
Expand Down Expand Up @@ -1132,9 +1131,6 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_where(lax.gt(x1, zero), _lax_const(x1, 1), x2))


deprecations.register("jax-numpy-hypot-complex")


@implements(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Expand All @@ -1143,13 +1139,9 @@ def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
# TODO(micky774): Promote to ValueError when deprecation is complete
# (began 2024-4-14).
if dtypes.issubdtype(x1.dtype, np.complexfloating):
deprecations.warn(
"jax-numpy-hypot-complex",
"Passing complex-valued inputs to hypot is deprecated and will raise a "
"ValueError in the future. Please convert to real values first, such as "
"by using jnp.real or jnp.imag to take the real or imaginary components "
"respectively.",
stacklevel=2)
raise ValueError(
"jnp.hypot is not well defined for complex-valued inputs. "
"Please convert to real values first, such as by using abs(x)")
x1, x2 = lax.abs(x1), lax.abs(x2)
idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2))
x1, x2 = maximum(x1, x2), minimum(x1, x2)
Expand Down
46 changes: 12 additions & 34 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,58 +896,36 @@ def testClipNone(self, shape, dtype):
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)

# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
def testClipComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
x = rng((5,), dtype=jnp.complex64)
msg = ".*Complex values have no ordering and cannot be clipped.*"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-clip-complex"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x)

with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=x)

x = rng(shape, dtype=jnp.int32)
with assert_warns_or_errors():
x = rng((5,), dtype=jnp.int32)
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, min=-1+5j)

with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.clip(x, max=jnp.array([-1+5j]))

# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testHypotComplexInputDeprecation(self, shape):
def testHypotComplexInputError(self):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Passing complex-valued inputs to hypot.*"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-hypot-complex"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
x = rng((5,), dtype=jnp.complex64)
msg = "jnp.hypot is not well defined for complex-valued inputs.*"
# jit is disabled so we don't miss warnings due to caching.
with jax.disable_jit():
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, x)

y = jnp.ones_like(x)
with assert_warns_or_errors():
with self.assertRaisesRegex(ValueError, msg):
jnp.hypot(x, y)

@jtu.sample_product(
Expand Down
Loading