From c2f2b0ed2877bd3a73c03e99a2ee0598e20ab7cd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 30 Jul 2024 12:15:24 -0700 Subject: [PATCH] [array API] move api metadata into jax.numpy namespace --- .../numpy/array_api_metadata.py} | 57 ++++++++++++++++--- jax/_src/numpy/array_methods.py | 2 + jax/experimental/array_api/__init__.py | 15 ++--- jax/experimental/array_api/_array_methods.py | 31 ---------- jax/numpy/__init__.py | 4 ++ 5 files changed, 59 insertions(+), 50 deletions(-) rename jax/{experimental/array_api/_utility_functions.py => _src/numpy/array_api_metadata.py} (57%) delete mode 100644 jax/experimental/array_api/_array_methods.py diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/_src/numpy/array_api_metadata.py similarity index 57% rename from jax/experimental/array_api/_utility_functions.py rename to jax/_src/numpy/array_api_metadata.py index f75b2e2e29af..a196556b8e08 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,23 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This module contains metadata related to the `Python array API`_. + +.. _Python array API: https://data-apis.org/array-api/ +""" from __future__ import annotations +import importlib + import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc from jax._src import dtypes as _dtypes, config -# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api -# deprecation -class __array_namespace_info__: - def __init__(self): - self._capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, - } +# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. +__array_api_version__ = '2023.12' + + +# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. +def __array_namespace_info__() -> ArrayNamespaceInfo: + return ArrayNamespaceInfo() + + +def _array_namespace_property(self): + # TODO(jakevdp): clean this up once numpy fully supports the array API. + # In some environments, jax.experimental.array_api is not available. + # We return an AttributeError in this case, because some callers use + # hasattr checks to check for array API compatibility. + if not importlib.util.find_spec('jax.experimental.array_api'): + raise AttributeError("__array_namespace__ requires jax.experimental.array_api") + return __array_namespace__ + + +def __array_namespace__(*, api_version: None | str = None): + """Return the `Python array API`_ namespace for JAX. + + .. _Python array API: https://data-apis.org/array-api/ + """ + if api_version is not None and api_version != __array_api_version__: + raise ValueError(f"{api_version=!r} is not available; " + f"available versions are: {[__array_api_version__]}") + # TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete. + import jax.experimental.array_api + return jax.experimental.array_api # pytype: disable=module-attr + + +class ArrayNamespaceInfo: + """Metadata for the `Python array API`_ + .. _Python array API: https://data-apis.org/array-api/ + """ + _capabilities = { + "boolean indexing": True, + "data-dependent shapes": False, + } def _build_dtype_dict(self): array_api_types = { diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 515f245d11d3..8c755824a0f6 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -39,6 +39,7 @@ from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal from jax._src.lib import xla_client as xc +from jax._src.numpy import array_api_metadata from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -718,6 +719,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, } _array_properties = { + "__array_namespace__": array_api_metadata._array_namespace_property, "flat": _notimplemented_flat, "T": lax_numpy.transpose, "mT": lax_numpy.matrix_transpose, diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index ba0031951432..9a0be504f81a 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -36,11 +36,14 @@ from __future__ import annotations -from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ - from jax.experimental.array_api import fft as fft from jax.experimental.array_api import linalg as linalg +from jax._src.numpy.array_api_metadata import ( + __array_api_version__ as __array_api_version__, + __array_namespace_info__ as __array_namespace_info__, +) + from jax.numpy import ( abs as abs, acos as acos, @@ -197,11 +200,3 @@ clip as clip, hypot as hypot, ) - -from jax.experimental.array_api._utility_functions import ( - __array_namespace_info__ as __array_namespace_info__, -) - -from jax.experimental.array_api import _array_methods -_array_methods.add_array_object_methods() -del _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py deleted file mode 100644 index 6a2fb09e2a6c..000000000000 --- a/jax/experimental/array_api/_array_methods.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -from jax._src.array import ArrayImpl -from jax.experimental.array_api._version import __array_api_version__ - - -def _array_namespace(self, /, *, api_version: None | str = None): - if api_version is not None and api_version != __array_api_version__: - raise ValueError(f"{api_version=!r} is not available; " - f"available versions are: {[__array_api_version__]}") - return jax.experimental.array_api - - -def add_array_object_methods(): - # TODO(jakevdp): set on tracers as well? - setattr(ArrayImpl, "__array_namespace__", _array_namespace) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index bfefb56521c9..9f78dd0a8224 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -274,6 +274,10 @@ except ImportError: pass +from jax._src.numpy.array_api_metadata import ( + __array_api_version__ as __array_api_version__ +) + from jax._src.numpy.index_tricks import ( c_ as c_, index_exp as index_exp,