From 8b7aae586b8aa83cef1f5c66e81fc43f017dcd18 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 4 Apr 2024 22:55:10 +0000 Subject: [PATCH] Update `jnp.clip` to Array API 2023 standard --- CHANGELOG.md | 3 + jax/_src/numpy/array_methods.py | 5 +- jax/_src/numpy/lax_numpy.py | 74 +++++++++++++++---- jax/_src/scipy/special.py | 2 +- jax/_src/typing.py | 6 ++ jax/experimental/array_api/__init__.py | 1 + .../array_api/_elementwise_functions.py | 16 ++++ .../jax2tf/tests/jax2tf_limitations.py | 2 +- jax/experimental/ode.py | 4 +- jax/numpy/__init__.pyi | 16 +++- tests/array_api_test.py | 23 ++++++ tests/lax_metal_test.py | 2 +- tests/lax_numpy_test.py | 44 +++++++++-- 13 files changed, 168 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35d93eeae3c4..786020d7d548 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list * Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect. + * {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and + `a_max` are deprecated in favor of `x` (positonal only), `min`, and + `max` ({jax-issue}`20550`). ## jaxlib 0.4.27 diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index fb1e52bd1be9..98eea8887198 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int: def _clip(number: ArrayLike, - min: ArrayLike | None = None, max: ArrayLike | None = None, - out: None = None) -> Array: + min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: """Return an array whose values are limited to a specified range. Refer to :func:`jax.numpy.clip` for full documentation.""" - return lax_numpy.clip(number, a_min=min, a_max=max, out=out) + return lax_numpy.clip(number, min=min, max=max) def _transpose(a: Array, *args: Any) -> Array: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 826251c82945..ea721581247a 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -66,7 +66,10 @@ from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.numpy.vectorize import vectorize -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape +from jax._src.typing import ( + Array, ArrayLike, DimSize, DuckTypedArray, + DType, DTypeLike, Shape, DeprecatedArg +) from jax._src.util import (unzip2, subvals, safe_zip, ceil_of_ratio, partition_list, canonicalize_axis as _canonicalize_axis, @@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array axis: int = 0) -> list[Array]: return _split("array_split", ary, indices_or_sections, axis=axis) -@util.implements(np.clip, skip_params=['out']) + +_DEPRECATED_CLIP_ARG = DeprecatedArg() +@util.implements( + np.clip, + skip_params=['a', 'a_min'], + extra_params=_dedent(""" + x : array_like + Array containing elements to clip. + min : array_like, optional + Minimum value. If ``None``, clipping is not performed on the + corresponding edge. The value of ``min`` is broadcast against x. + max : array_like, optional + Maximum value. If ``None``, clipping is not performed on the + corresponding edge. The value of ``max`` is broadcast against x. +""") +) @jit -def clip(a: ArrayLike, a_min: ArrayLike | None = None, - a_max: ArrayLike | None = None, out: None = None) -> Array: - util.check_arraylike("clip", a) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.clip is not supported.") - if a_min is None and a_max is None: - raise ValueError("At most one of a_min and a_max may be None") - if a_min is not None: - a = ufuncs.maximum(a_min, a) - if a_max is not None: - a = ufuncs.minimum(a_max, a) - return asarray(a) +def clip( + x: ArrayLike | None = None, # Default to preserve backwards compatability + /, + min: ArrayLike | None = None, + max: ArrayLike | None = None, + *, + a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG, + a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG, + a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG +) -> Array: + # TODO(micky774): deprecated 2024-4-2, remove after deprecation expires. + x = a if not isinstance(a, DeprecatedArg) else x + if x is None: + raise ValueError("No input was provided to the clip function.") + min = a_min if not isinstance(a_min, DeprecatedArg) else min + max = a_max if not isinstance(a_max, DeprecatedArg) else max + if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)): + warnings.warn( + "Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is " + "deprecated. Please use 'x', 'min', and 'max' respectively instead.", + DeprecationWarning, + stacklevel=2, + ) + + util.check_arraylike("clip", x) + if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): + # TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires. + warnings.warn( + "Clip received a complex value either through the input or the min/max " + "keywords. Complex values have no ordering and cannot be clipped. " + "Attempting to clip using complex numbers is deprecated and will soon " + "raise a ValueError. Please convert to a real value or array by taking " + "the real or imaginary components via jax.numpy.real/imag respectively.", + DeprecationWarning, stacklevel=2, + ) + if min is not None: + x = ufuncs.maximum(min, x) + if max is not None: + x = ufuncs.minimum(max, x) + return asarray(x) @util.implements(np.around, skip_params=['out']) @partial(jit, static_argnames=('decimals',)) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index d4aced143016..e973bb30c0a5 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim))) s_over_a = (s_ + m) / (a_ + N) T1 = jnp.cumprod(s_over_a, -1)[..., ::2] - T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max) + T1 = jnp.clip(T1, max=jnp.finfo(dtype).max) coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype), tuple(range(a.ndim))) T1 = T1 / coefs diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 8ae276b37457..6cd466500f71 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -77,3 +77,9 @@ def shape(self) -> Shape: ... # JAX array (i.e. not including future non-standard array types like KeyArray and BInt). # It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences, # nor does it accept string data. + +# We use a class for deprecated args to avoid using Any/object types which can +# introduce complications and mistakes in static analysis +class DeprecatedArg: + def __repr__(self): + return "Deprecated" diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 3169f9667256..ab54f8c09dfd 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -112,6 +112,7 @@ bitwise_right_shift as bitwise_right_shift, bitwise_xor as bitwise_xor, ceil as ceil, + clip as clip, conj as conj, cos as cos, cosh as cosh, diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 1352cd5b0b3e..6d5a4ee7fe2b 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -125,6 +125,22 @@ def ceil(x, /): return jax.numpy.ceil(x) +def clip(x, /, min=None, max=None): + """Returns the complex conjugate for each element x_i of the input array x.""" + x, = _promote_dtypes("clip", x) + + # TODO(micky774): Remove when jnp.clip deprecation is completed + # (began 2024-4-2) and default behavior is Array API 2023 compliant + if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): + raise ValueError( + "Clip received a complex value either through the input or the min/max " + "keywords. Complex values have no ordering and cannot be clipped. " + "Please convert to a real value or array by taking the real or " + "imaginary components via jax.numpy.real/imag respectively." + ) + return jax.numpy.clip(x, min=min, max=max) + + def conj(x, /): """Returns the complex conjugate for each element x_i of the input array x.""" x, = _promote_dtypes("conj", x) diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 8265a38e29d2..2c314a3e9218 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1283,7 +1283,7 @@ def dot_column_wise(a, b): # values like 1.0000001 on float32, which are clipped to 1.0. It is # possible that anything other than `cos_angular_diff` can be outside # the interval [0, 1] due to roundoff. - cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0) + cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0) angular_diff = jnp.arccos(cos_angular_diff) diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index 81a9630cbdfe..b8e3daee48c8 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -201,7 +201,7 @@ def body_fun(state): next_t = t + dt error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y) new_interp_coeff = interp_fit_dopri(y, next_y, k, dt) - dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax) + dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax) new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff] old = [i + 1, y, f, t, dt, last_t, interp_coeff] @@ -214,7 +214,7 @@ def body_fun(state): return carry, y_target f0 = func_(y0, ts[0]) - dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax) + dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax) interp_coeff = jnp.array([y0] * 5) init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff] _, ys = lax.scan(scan_fun, init_carry, ts[1:]) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 9ed5f39b393e..706bd35335d5 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes from jax._src.lax.lax import PrecisionLike from jax._src.lax.slicing import GatherScatterMode from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape +from jax._src.typing import ( + Array, ArrayLike, DType, DTypeLike, + DimSize, DuckTypedArray, Shape, DeprecatedArg +) from jax.numpy import fft as fft, linalg as linalg from jax.sharding import Sharding as _Sharding import numpy as _np @@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ... character = _np.character def choose(a: ArrayLike, choices: Sequence[ArrayLike], out: None = ..., mode: str = ...) -> Array: ... -def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ..., - a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ... +def clip( + x: ArrayLike | None = ..., + /, + min: Optional[ArrayLike] = ..., + max: Optional[ArrayLike] = ..., + a: ArrayLike | DeprecatedArg | None = ..., + a_min: ArrayLike | DeprecatedArg | None = ..., + a_max: ArrayLike | DeprecatedArg | None = ... +) -> Array: ... def column_stack( tup: Union[_np.ndarray, Array, Sequence[ArrayLike]] ) -> Array: ... diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 0d4893e4939e..75ac74cfe924 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -58,6 +58,7 @@ 'broadcast_to', 'can_cast', 'ceil', + 'clip', 'complex128', 'complex64', 'concat', @@ -233,5 +234,27 @@ def test_array_namespace_method(self): self.assertIs(x.__array_namespace__(), array_api) +class ArrayAPIErrors(absltest.TestCase): + """Test that our array API implementations raise errors where required""" + + # TODO(micky774): Remove when jnp.clip deprecation is completed + # (began 2024-4-2) and default behavior is Array API 2023 compliant + def test_clip_complex(self): + x = array_api.arange(5, dtype=array_api.complex64) + complex_msg = "Complex values have no ordering and cannot be clipped" + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x) + + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, max=x) + + x = array_api.arange(5, dtype=array_api.int32) + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, min=-1+5j) + + with self.assertRaisesRegex(ValueError, complex_msg): + array_api.clip(x, max=-1+5j) + + if __name__ == '__main__': absltest.main() diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index c84d3b1b66d6..0f7b5d16d626 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -877,7 +877,7 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max): a_max = None if a_max is None else abs(a_max) rng = jtu.rand_default(self.rng()) np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) + jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ab52cd6ed1c9..0451122490ca 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -872,14 +872,45 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max): a_max = None if a_max is None else abs(a_max) rng = jtu.rand_default(self.rng()) np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max) + jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - def testClipError(self): - with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): - jnp.clip(jnp.zeros((3,))) + + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes + unsigned_dtypes, + ) + def testClipNone(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + self.assertArraysEqual(jnp.clip(x), x) + + + # TODO(micky774): Check for ValueError instead of DeprecationWarning when + # jnp.clip deprecation is completed (began 2024-4-2) and default behavior is + # Array API 2023 compliant + @jtu.sample_product(shape=all_shapes) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testClipComplexInputDeprecation(self, shape): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype=jnp.complex64) + msg = "Complex values have no ordering and cannot be clipped" + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x) + + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, max=x) + + x = rng(shape, dtype=jnp.int32) + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, min=-1+5j) + + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.clip(x, max=jnp.array([-1+5j])) + @jtu.sample_product( [dict(shape=shape, dtype=dtype) @@ -5772,7 +5803,7 @@ def testWrappedSignaturesMatch(self): 'argpartition': ['kind', 'order'], 'asarray': ['like'], 'broadcast_to': ['subok'], - 'clip': ['kwargs'], + 'clip': ['kwargs', 'out'], 'copy': ['subok'], 'corrcoef': ['ddof', 'bias', 'dtype'], 'cov': ['dtype'], @@ -5809,6 +5840,9 @@ def testWrappedSignaturesMatch(self): } extra_params = { + # TODO(micky774): Remove when np.clip has adopted the Array API 2023 + # standard + 'clip': ['x', 'max', 'min'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], 'take_along_axis': ['mode'],