Skip to content

Commit

Permalink
Merge pull request #22764 from jakevdp:array-api-methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657697622
  • Loading branch information
jax authors committed Jul 30, 2024
2 parents d7c2b49 + c2f2b0e commit b1066ee
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The JAX Authors.
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,23 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains metadata related to the `Python array API`_.
.. _Python array API: https://data-apis.org/array-api/
"""
from __future__ import annotations

import importlib

import jax
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config

# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api
# deprecation
class __array_namespace_info__:

def __init__(self):
self._capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}
# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
__array_api_version__ = '2023.12'


# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
def __array_namespace_info__() -> ArrayNamespaceInfo:
return ArrayNamespaceInfo()


def _array_namespace_property(self):
# TODO(jakevdp): clean this up once numpy fully supports the array API.
# In some environments, jax.experimental.array_api is not available.
# We return an AttributeError in this case, because some callers use
# hasattr checks to check for array API compatibility.
if not importlib.util.find_spec('jax.experimental.array_api'):
raise AttributeError("__array_namespace__ requires jax.experimental.array_api")
return __array_namespace__


def __array_namespace__(*, api_version: None | str = None):
"""Return the `Python array API`_ namespace for JAX.
.. _Python array API: https://data-apis.org/array-api/
"""
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
# TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete.
import jax.experimental.array_api
return jax.experimental.array_api # pytype: disable=module-attr


class ArrayNamespaceInfo:
"""Metadata for the `Python array API`_
.. _Python array API: https://data-apis.org/array-api/
"""
_capabilities = {
"boolean indexing": True,
"data-dependent shapes": False,
}

def _build_dtype_dict(self):
array_api_types = {
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src.array import ArrayImpl
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.numpy import array_api_metadata
from jax._src.numpy import lax_numpy
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
Expand Down Expand Up @@ -718,6 +719,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
}

_array_properties = {
"__array_namespace__": array_api_metadata._array_namespace_property,
"flat": _notimplemented_flat,
"T": lax_numpy.transpose,
"mT": lax_numpy.matrix_transpose,
Expand Down
15 changes: 5 additions & 10 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@

from __future__ import annotations

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import fft as fft
from jax.experimental.array_api import linalg as linalg

from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,
)

from jax.numpy import (
abs as abs,
acos as acos,
Expand Down Expand Up @@ -197,11 +200,3 @@
clip as clip,
hypot as hypot,
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
)

from jax.experimental.array_api import _array_methods
_array_methods.add_array_object_methods()
del _array_methods
31 changes: 0 additions & 31 deletions jax/experimental/array_api/_array_methods.py

This file was deleted.

4 changes: 4 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@
except ImportError:
pass

from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__
)

from jax._src.numpy.index_tricks import (
c_ as c_,
index_exp as index_exp,
Expand Down

0 comments on commit b1066ee

Please sign in to comment.