From 14fa06298ef82ead0313ba3b24a303ace7edfdf5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 11:19:17 -0700 Subject: [PATCH] [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api --- .github/workflows/ci-build.yaml | 2 +- .github/workflows/jax-array-api.yml | 4 +- CHANGELOG.md | 7 + docs/jax.experimental.array_api.rst | 26 ++- docs/jax.numpy.rst | 33 +++ jax/BUILD | 10 +- jax/_src/basearray.pyi | 3 + jax/_src/numpy/array_api_metadata.py | 28 +-- jax/_src/numpy/array_methods.py | 2 +- jax/experimental/__init__.py | 17 ++ jax/experimental/array_api/__init__.py | 193 ------------------ jax/experimental/array_api/_version.py | 15 -- jax/experimental/array_api/fft.py | 30 --- jax/experimental/array_api/linalg.py | 39 ---- jax/numpy/__init__.py | 3 +- jax/numpy/__init__.pyi | 4 + .../skips.txt => tests/array_api_skips.txt | 0 tests/array_api_test.py | 37 ++-- 18 files changed, 128 insertions(+), 325 deletions(-) delete mode 100644 jax/experimental/array_api/__init__.py delete mode 100644 jax/experimental/array_api/_version.py delete mode 100644 jax/experimental/array_api/fft.py delete mode 100644 jax/experimental/array_api/linalg.py rename jax/experimental/array_api/skips.txt => tests/array_api_skips.txt (100%) diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 6b8912d61558..9869d0256f8d 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -140,7 +140,7 @@ jobs: JAX_ARRAY: 1 PY_COLORS: 1 run: | - pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md + pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 29cef9c856fa..78cddb411feb 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -38,8 +38,8 @@ jobs: python -m pip install -r array-api-tests/requirements.txt - name: Run the test suite env: - ARRAY_API_TESTS_MODULE: jax.experimental.array_api + ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt + pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 348972f23a2a..5045de76c84c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.32 +* Changes + * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. + See {ref}`python-array-api` for more information. + * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the `stablehlo` dialect instead. @@ -23,6 +27,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly. * `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`. * `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`. + * The `jax.experimental.array_api` module is deprecated, and importing it is no + longer required to use the Array API. `jax.numpy` supports the array API + directly; see {ref}`python-array-api` for more information. ## jaxlib 0.4.32 diff --git a/docs/jax.experimental.array_api.rst b/docs/jax.experimental.array_api.rst index e5fa25f90b18..8661fae7c1c6 100644 --- a/docs/jax.experimental.array_api.rst +++ b/docs/jax.experimental.array_api.rst @@ -1,4 +1,28 @@ ``jax.experimental.array_api`` module ===================================== -.. automodule:: jax.experimental.array_api +.. note:: + The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and + importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy` + implements the array API standard directly by default. See :ref:`python-array-api` + for details. + +This module includes experimental JAX support for the `Python array API standard`_. +Support for this is currently experimental and not fully complete. + +Example Usage:: + + >>> from jax.experimental import array_api as xp + + >>> xp.__array_api_version__ + '2023.12' + + >>> arr = xp.arange(1000) + + >>> arr.sum() + Array(499500, dtype=int32) + +The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`, +and implements most of the API listed in the standard. + +.. _Python array API standard: https://data-apis.org/array-api/ diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index b96dfcdfb208..d6b7d74bd429 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -528,3 +528,36 @@ This is because in general, pickling and unpickling may take place in different environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If :mod:`pickle` is used in traced/JIT-compiled code, it will result in a :class:`~jax.errors.ConcretizationTypeError`. + +.. _python-array-api: + +Python Array API standard +------------------------- + +.. note:: + + Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order + to enable the array API for JAX arrays. After JAX v0.4.32, importing this + module is no longer required, and will raise a deprecation warning. + +Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible +with the `Python Array API Standard`_. You can access the Array API namespace via +:meth:`jax.Array.__array_namespace__`:: + + >>> def f(x): + ... nx = x.__array_namespace__() + ... return nx.sin(x) ** 2 + nx.cos(x) ** 2 + + >>> import jax.numpy as jnp + >>> x = jnp.arange(5) + >>> f(x).round() + Array([1., 1., 1., 1., 1.], dtype=float32) + +JAX departs from the standard in a few places, namely because JAX arrays are +immutable, in-place updates are not supported. Some of these incompatibilities +are being addressed via the `array-api-compat`_ module. + +For more information, refer to the `Python Array API Standard`_ documentation. + +.. _Python Array API Standard: https://data-apis.org/array-api +.. _array-api-compat: https://github.com/data-apis/array-api-compat \ No newline at end of file diff --git a/jax/BUILD b/jax/BUILD index 4556fa285220..69be704ba57c 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -987,13 +987,11 @@ pytype_library( pytype_library( name = "experimental_array_api", - srcs = glob( - [ - "experimental/array_api/*.py", - ], - ), visibility = [":internal"] + jax_visibility("array_api"), - deps = [":jax"], + deps = [ + ":experimental", + ":jax", + ], ) pytype_library( diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index eda69d5f4ba2..32e1d27dcdf5 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -13,6 +13,7 @@ # limitations under the License. import abc from collections.abc import Callable, Sequence +from types import ModuleType from typing import Any, Union import numpy as np @@ -48,6 +49,8 @@ class Array(abc.ABC): raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly." " Use jax.numpy.array, or jax.numpy.zeros instead.") + def __array_namespace__(self, *, api_version: None | str = ...) -> ModuleType: ... + def __getitem__(self, key) -> Array: ... def __setitem__(self, key, value) -> None: ... def __len__(self) -> int: ... diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index a196556b8e08..4a01f579a67e 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -19,7 +19,7 @@ """ from __future__ import annotations -import importlib +from types import ModuleType import jax from jax._src.sharding import Sharding @@ -27,26 +27,10 @@ from jax._src import dtypes as _dtypes, config -# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. __array_api_version__ = '2023.12' -# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. -def __array_namespace_info__() -> ArrayNamespaceInfo: - return ArrayNamespaceInfo() - - -def _array_namespace_property(self): - # TODO(jakevdp): clean this up once numpy fully supports the array API. - # In some environments, jax.experimental.array_api is not available. - # We return an AttributeError in this case, because some callers use - # hasattr checks to check for array API compatibility. - if not importlib.util.find_spec('jax.experimental.array_api'): - raise AttributeError("__array_namespace__ requires jax.experimental.array_api") - return __array_namespace__ - - -def __array_namespace__(*, api_version: None | str = None): +def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: """Return the `Python array API`_ namespace for JAX. .. _Python array API: https://data-apis.org/array-api/ @@ -54,9 +38,11 @@ def __array_namespace__(*, api_version: None | str = None): if api_version is not None and api_version != __array_api_version__: raise ValueError(f"{api_version=!r} is not available; " f"available versions are: {[__array_api_version__]}") - # TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete. - import jax.experimental.array_api - return jax.experimental.array_api # pytype: disable=module-attr + return jax.numpy + + +def __array_namespace_info__() -> ArrayNamespaceInfo: + return ArrayNamespaceInfo() class ArrayNamespaceInfo: diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 6a9e6e0ff4f8..87635be37c84 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -666,6 +666,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, } _array_methods = { + "__array_namespace__": array_api_metadata.__array_namespace__, "all": reductions.all, "any": reductions.any, "argmax": lax_numpy.argmax, @@ -719,7 +720,6 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, } _array_properties = { - "__array_namespace__": array_api_metadata._array_namespace_property, "flat": _notimplemented_flat, "T": lax_numpy.transpose, "mT": lax_numpy.matrix_transpose, diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index caf27ec7a8ca..dcca44773921 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -25,3 +25,20 @@ from jax._src.earray import ( EArray as EArray ) + +from jax import numpy as _array_api + + +_deprecations = { + # Deprecated 01 Aug 2024 + "array_api": ( + "jax.experimental.array_api import is no longer required as of JAX v0.4.32; " + "jax.numpy supports the array API by default.", + _array_api + ), +} + +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr +del _array_api diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py deleted file mode 100644 index c3bc83112f67..000000000000 --- a/jax/experimental/array_api/__init__.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module includes experimental JAX support for the `Python array API standard`_. -Support for this is currently experimental and not fully complete. - -Example Usage:: - - >>> from jax.experimental import array_api as xp - - >>> xp.__array_api_version__ - '2023.12' - - >>> arr = xp.arange(1000) - - >>> arr.sum() - Array(499500, dtype=int32) - -The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`, -and implements most of the API listed in the standard. - -.. _Python array API standard: https://data-apis.org/array-api/latest/ -""" - -from __future__ import annotations - -from jax.experimental.array_api import fft as fft -from jax.experimental.array_api import linalg as linalg - -from jax._src.numpy.array_api_metadata import ( - __array_api_version__ as __array_api_version__, - __array_namespace_info__ as __array_namespace_info__, -) - -from jax.numpy import ( - abs as abs, - acos as acos, - acosh as acosh, - add as add, - all as all, - any as any, - arange as arange, - argmax as argmax, - argmin as argmin, - argsort as argsort, - asarray as asarray, - asin as asin, - asinh as asinh, - astype as astype, - atan as atan, - atan2 as atan2, - atanh as atanh, - bitwise_and as bitwise_and, - bitwise_invert as bitwise_invert, - bitwise_left_shift as bitwise_left_shift, - bitwise_or as bitwise_or, - bitwise_right_shift as bitwise_right_shift, - bitwise_xor as bitwise_xor, - bool as bool, - broadcast_arrays as broadcast_arrays, - broadcast_to as broadcast_to, - can_cast as can_cast, - ceil as ceil, - clip as clip, - complex128 as complex128, - complex64 as complex64, - concat as concat, - conj as conj, - copysign as copysign, - cos as cos, - cosh as cosh, - cumulative_sum as cumulative_sum, - divide as divide, - e as e, - empty as empty, - empty_like as empty_like, - equal as equal, - exp as exp, - expand_dims as expand_dims, - expm1 as expm1, - eye as eye, - finfo as finfo, - flip as flip, - float32 as float32, - float64 as float64, - floor as floor, - floor_divide as floor_divide, - from_dlpack as from_dlpack, - full as full, - full_like as full_like, - greater as greater, - greater_equal as greater_equal, - hypot as hypot, - iinfo as iinfo, - imag as imag, - inf as inf, - int16 as int16, - int32 as int32, - int64 as int64, - int8 as int8, - isdtype as isdtype, - isfinite as isfinite, - isinf as isinf, - isnan as isnan, - less as less, - less_equal as less_equal, - linspace as linspace, - log as log, - log10 as log10, - log1p as log1p, - log2 as log2, - logaddexp as logaddexp, - logical_and as logical_and, - logical_not as logical_not, - logical_or as logical_or, - logical_xor as logical_xor, - matmul as matmul, - matrix_transpose as matrix_transpose, - max as max, - maximum as maximum, - mean as mean, - meshgrid as meshgrid, - min as min, - minimum as minimum, - moveaxis as moveaxis, - multiply as multiply, - nan as nan, - negative as negative, - newaxis as newaxis, - nonzero as nonzero, - not_equal as not_equal, - ones as ones, - ones_like as ones_like, - permute_dims as permute_dims, - pi as pi, - positive as positive, - pow as pow, - prod as prod, - real as real, - remainder as remainder, - repeat as repeat, - reshape as reshape, - result_type as result_type, - roll as roll, - round as round, - searchsorted as searchsorted, - sign as sign, - signbit as signbit, - sin as sin, - sinh as sinh, - sort as sort, - sqrt as sqrt, - square as square, - squeeze as squeeze, - stack as stack, - std as std, - subtract as subtract, - sum as sum, - take as take, - tan as tan, - tanh as tanh, - tensordot as tensordot, - tile as tile, - tril as tril, - triu as triu, - trunc as trunc, - uint16 as uint16, - uint32 as uint32, - uint64 as uint64, - uint8 as uint8, - unique_all as unique_all, - unique_counts as unique_counts, - unique_inverse as unique_inverse, - unique_values as unique_values, - unstack as unstack, - var as var, - vecdot as vecdot, - where as where, - zeros as zeros, - zeros_like as zeros_like, -) diff --git a/jax/experimental/array_api/_version.py b/jax/experimental/array_api/_version.py deleted file mode 100644 index 104df73c77b9..000000000000 --- a/jax/experimental/array_api/_version.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__array_api_version__ = '2023.12' diff --git a/jax/experimental/array_api/fft.py b/jax/experimental/array_api/fft.py deleted file mode 100644 index 0354aa41a764..000000000000 --- a/jax/experimental/array_api/fft.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax.numpy.fft import ( - fft as fft, - fftfreq as fftfreq, - fftn as fftn, - fftshift as fftshift, - hfft as hfft, - ifft as ifft, - ifftn as ifftn, - ifftshift as ifftshift, - ihfft as ihfft, - irfft as irfft, - irfftn as irfftn, - rfft as rfft, - rfftfreq as rfftfreq, - rfftn as rfftn, -) diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py deleted file mode 100644 index 98ede18b8b2e..000000000000 --- a/jax/experimental/array_api/linalg.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax.numpy.linalg import ( - cholesky as cholesky, - cross as cross, - det as det, - diagonal as diagonal, - eigh as eigh, - eigvalsh as eigvalsh, - inv as inv, - matmul as matmul, - matrix_norm as matrix_norm, - matrix_power as matrix_power, - matrix_rank as matrix_rank, - matrix_transpose as matrix_transpose, - outer as outer, - pinv as pinv, - qr as qr, - slogdet as slogdet, - solve as solve, - svd as svd, - svdvals as svdvals, - tensordot as tensordot, - trace as trace, - vecdot as vecdot, - vector_norm as vector_norm, -) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 9f78dd0a8224..88e1840ef1c0 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -275,7 +275,8 @@ pass from jax._src.numpy.array_api_metadata import ( - __array_api_version__ as __array_api_version__ + __array_api_version__ as __array_api_version__, + __array_namespace_info__ as __array_namespace_info__, ) from jax._src.numpy.index_tricks import ( diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 4ead2abb4ece..dfea8a8ddd74 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -11,6 +11,7 @@ from jax._src.lax.lax import PrecisionLike from jax._src.lax.slicing import GatherScatterMode from jax._src.lib import Device from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass +from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar, @@ -27,6 +28,9 @@ _Device = Device ComplexWarning: type +__array_api_version__: str +def __array_namespace_info__() -> ArrayNamespaceInfo: ... + _deprecations: dict[str, tuple[str, Any]] def abs(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ... diff --git a/jax/experimental/array_api/skips.txt b/tests/array_api_skips.txt similarity index 100% rename from jax/experimental/array_api/skips.txt rename to tests/array_api_skips.txt diff --git a/tests/array_api_test.py b/tests/array_api_test.py index dcb33b9bc57f..41951fea381d 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Smoketest for jax.experimental.array_api +"""Smoketest for JAX's array API. The full test suite for the array API is run via the array-api-tests CI; this is just a minimal smoke test to catch issues early. @@ -26,7 +26,8 @@ import jax.numpy as jnp from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype -from jax.experimental import array_api + +ARRAY_API_NAMESPACE = jnp config.parse_flags_with_absl() @@ -36,7 +37,6 @@ 'acosh', 'add', 'all', - 'annotations', 'any', 'arange', 'argmax', @@ -233,22 +233,29 @@ class ArrayAPISmokeTest(absltest.TestCase): """Smoke test for the array API.""" def test_main_namespace(self): - self.assertContainsSubset(MAIN_NAMESPACE, names(array_api)) + self.assertContainsSubset(MAIN_NAMESPACE, names(ARRAY_API_NAMESPACE)) def test_linalg_namespace(self): - self.assertContainsSubset(LINALG_NAMESPACE, names(array_api.linalg)) + self.assertContainsSubset(LINALG_NAMESPACE, names(ARRAY_API_NAMESPACE.linalg)) def test_fft_namespace(self): - self.assertContainsSubset(FFT_NAMESPACE, names(array_api.fft)) + self.assertContainsSubset(FFT_NAMESPACE, names(ARRAY_API_NAMESPACE.fft)) def test_array_namespace_method(self): - x = array_api.arange(20) + x = ARRAY_API_NAMESPACE.arange(20) self.assertIsInstance(x, jax.Array) - self.assertIs(x.__array_namespace__(), array_api) + self.assertIs(x.__array_namespace__(), ARRAY_API_NAMESPACE) + + def test_deprecated_import(self): + msg = "jax.experimental.array_api import is no longer required" + with self.assertWarnsRegex(DeprecationWarning, msg): + from jax.experimental import array_api + self.assertIs(array_api, ARRAY_API_NAMESPACE) + class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase): - info = array_api.__array_namespace_info__() + info = ARRAY_API_NAMESPACE.__array_namespace_info__() def setUp(self): super().setUp() @@ -333,20 +340,20 @@ class ArrayAPIErrors(absltest.TestCase): # TODO(micky774): Remove when jnp.clip deprecation is completed # (began 2024-4-2) and default behavior is Array API 2023 compliant def test_clip_complex(self): - x = array_api.arange(5, dtype=array_api.complex64) + x = ARRAY_API_NAMESPACE.arange(5, dtype=ARRAY_API_NAMESPACE.complex64) complex_msg = "Complex values have no ordering and cannot be clipped" with self.assertRaisesRegex(ValueError, complex_msg): - array_api.clip(x) + ARRAY_API_NAMESPACE.clip(x) with self.assertRaisesRegex(ValueError, complex_msg): - array_api.clip(x, max=x) + ARRAY_API_NAMESPACE.clip(x, max=x) - x = array_api.arange(5, dtype=array_api.int32) + x = ARRAY_API_NAMESPACE.arange(5, dtype=ARRAY_API_NAMESPACE.int32) with self.assertRaisesRegex(ValueError, complex_msg): - array_api.clip(x, min=-1+5j) + ARRAY_API_NAMESPACE.clip(x, min=-1+5j) with self.assertRaisesRegex(ValueError, complex_msg): - array_api.clip(x, max=-1+5j) + ARRAY_API_NAMESPACE.clip(x, max=-1+5j) if __name__ == '__main__':