diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 9a0be504f81a..aa7dcc58d7d1 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -72,6 +72,7 @@ broadcast_to as broadcast_to, can_cast as can_cast, ceil as ceil, + clip as clip, complex128 as complex128, complex64 as complex64, concat as concat, @@ -100,6 +101,7 @@ full_like as full_like, greater as greater, greater_equal as greater_equal, + hypot as hypot, iinfo as iinfo, imag as imag, inf as inf, @@ -195,8 +197,3 @@ from jax.experimental.array_api._data_type_functions import ( astype as astype, ) - -from jax.experimental.array_api._elementwise_functions import ( - clip as clip, - hypot as hypot, -) diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py deleted file mode 100644 index 103f8ab7d1ef..000000000000 --- a/jax/experimental/array_api/_elementwise_functions.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax.numpy import isdtype -from jax._src.dtypes import issubdtype -from jax._src.numpy.util import promote_args - - -# TODO(micky774): Remove when jnp.clip deprecation is completed -# (began 2024-4-2) and default behavior is Array API 2023 compliant -def clip(x, /, min=None, max=None): - """Returns the complex conjugate for each element x_i of the input array x.""" - x, = promote_args("clip", x) - - if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): - raise ValueError( - "Clip received a complex value either through the input or the min/max " - "keywords. Complex values have no ordering and cannot be clipped. " - "Please convert to a real value or array by taking the real or " - "imaginary components via jax.numpy.real/imag respectively." - ) - return jax.numpy.clip(x, min=min, max=max) - - -# TODO(micky774): Remove when jnp.hypot deprecation is completed -# (began 2024-4-14) and default behavior is Array API 2023 compliant -def hypot(x1, x2, /): - """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = promote_args("hypot", x1, x2) - - if issubdtype(x1.dtype, jax.numpy.complexfloating): - raise ValueError( - "hypot does not support complex-valued inputs. Please convert to real " - "values first, such as by using jnp.real or jnp.imag to take the real " - "or imaginary components respectively.") - return jax.numpy.hypot(x1, x2)