Skip to content

Commit

Permalink
Merge pull request jax-ml#20550 from Micky774:api_clip
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622045823
  • Loading branch information
jax authors committed Apr 5, 2024
2 parents f37e503 + 8b7aae5 commit 2512843
Show file tree
Hide file tree
Showing 13 changed files with 168 additions and 30 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).


## jaxlib 0.4.27
Expand Down
5 changes: 2 additions & 3 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:


def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None,
out: None = None) -> Array:
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.
Refer to :func:`jax.numpy.clip` for full documentation."""
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
return lax_numpy.clip(number, min=min, max=max)


def _transpose(a: Array, *args: Any) -> Array:
Expand Down
74 changes: 60 additions & 14 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis,
Expand Down Expand Up @@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

@util.implements(np.clip, skip_params=['out'])

_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
a_max: ArrayLike | None = None, out: None = None) -> Array:
util.check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = ufuncs.maximum(a_min, a)
if a_max is not None:
a = ufuncs.minimum(a_max, a)
return asarray(a)
def clip(
x: ArrayLike | None = None, # Default to preserve backwards compatability
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
) -> Array:
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)

util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)

@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
s_over_a = (s_ + m) / (a_ + N)
T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
tuple(range(a.ndim)))
T1 = T1 / coefs
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ def shape(self) -> Shape: ...
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data.

# We use a class for deprecated args to avoid using Any/object types which can
# introduce complications and mistakes in static analysis
class DeprecatedArg:
def __repr__(self):
return "Deprecated"
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
clip as clip,
conj as conj,
cos as cos,
cosh as cosh,
Expand Down
16 changes: 16 additions & 0 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ def ceil(x, /):
return jax.numpy.ceil(x)


def clip(x, /, min=None, max=None):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("clip", x)

# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively."
)
return jax.numpy.clip(x, min=min, max=max)


def conj(x, /):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("conj", x)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/jax2tf_limitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ def dot_column_wise(a, b):
# values like 1.0000001 on float32, which are clipped to 1.0. It is
# possible that anything other than `cos_angular_diff` can be outside
# the interval [0, 1] due to roundoff.
cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0)
cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0)

angular_diff = jnp.arccos(cos_angular_diff)

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def body_fun(state):
next_t = t + dt
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax)

new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
Expand All @@ -214,7 +214,7 @@ def body_fun(state):
return carry, y_target

f0 = func_(y0, ts[0])
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax)
interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
Expand Down
16 changes: 13 additions & 3 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
from jax._src.typing import (
Array, ArrayLike, DType, DTypeLike,
DimSize, DuckTypedArray, Shape, DeprecatedArg
)
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
import numpy as _np
Expand Down Expand Up @@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ...
character = _np.character
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = ..., mode: str = ...) -> Array: ...
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ...,
a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ...
def clip(
x: ArrayLike | None = ...,
/,
min: Optional[ArrayLike] = ...,
max: Optional[ArrayLike] = ...,
a: ArrayLike | DeprecatedArg | None = ...,
a_min: ArrayLike | DeprecatedArg | None = ...,
a_max: ArrayLike | DeprecatedArg | None = ...
) -> Array: ...
def column_stack(
tup: Union[_np.ndarray, Array, Sequence[ArrayLike]]
) -> Array: ...
Expand Down
23 changes: 23 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
'broadcast_to',
'can_cast',
'ceil',
'clip',
'complex128',
'complex64',
'concat',
Expand Down Expand Up @@ -233,5 +234,27 @@ def test_array_namespace_method(self):
self.assertIs(x.__array_namespace__(), array_api)


class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""

# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def test_clip_complex(self):
x = array_api.arange(5, dtype=array_api.complex64)
complex_msg = "Complex values have no ordering and cannot be clipped"
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x)

with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=x)

x = array_api.arange(5, dtype=array_api.int32)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, min=-1+5j)

with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
Expand Down
44 changes: 39 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,14 +872,45 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

def testClipError(self):
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
jnp.clip(jnp.zeros((3,)))

@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes + unsigned_dtypes,
)
def testClipNone(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)


# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Complex values have no ordering and cannot be clipped"
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x)

with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=x)

x = rng(shape, dtype=jnp.int32)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, min=-1+5j)

with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=jnp.array([-1+5j]))


@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
Expand Down Expand Up @@ -5772,7 +5803,7 @@ def testWrappedSignaturesMatch(self):
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],
'clip': ['kwargs', 'out'],
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
Expand Down Expand Up @@ -5809,6 +5840,9 @@ def testWrappedSignaturesMatch(self):
}

extra_params = {
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
# standard
'clip': ['x', 'max', 'min'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'take_along_axis': ['mode'],
Expand Down

0 comments on commit 2512843

Please sign in to comment.