diff --git a/array-api-skips.txt b/array-api-skips.txt index 2b38370dda55..b3c79e413cf3 100644 --- a/array-api-skips.txt +++ b/array-api-skips.txt @@ -3,15 +3,9 @@ # JAX doesn't yet support scalar boolean indexing array_api_tests/test_array_object.py::test_getitem_masking -# JAX arrays don't have a to_device() method -array_api_tests/test_signatures.py::test_array_method_signature[to_device] - # Hypothesis warning array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# JAX arrays don't yet support to_device -array_api_tests/test_has_names.py::test_has_names[array_method-to_device] - # Test suite attempts in-place mutation: array_api_tests/test_special_cases.py::test_binary array_api_tests/test_special_cases.py::test_iop @@ -38,6 +32,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)] # JAX's NaN sorting doesn't match specification diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 416fc6bf1ea6..9e11eefb993d 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -14,7 +14,7 @@ from __future__ import annotations -__array_api_version__ = '2022.12' +from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ from jax.experimental.array_api import linalg as linalg @@ -190,17 +190,6 @@ vecdot as vecdot, ) -def _array_namespace(self, /, *, api_version: None | str = None): - import sys - if api_version is not None and api_version != __array_api_version__: - raise ValueError(f"{api_version=!r} is not available; " - f"available versions are: {[__array_api_version__]}") - return sys.modules[__name__] - -def _setup_array_type(): - # TODO(jakevdp): set on tracers as well? - from jax._src.array import ArrayImpl - setattr(ArrayImpl, "__array_namespace__", _array_namespace) - -_setup_array_type() -del _setup_array_type +from jax.experimental.array_api import _array_methods +_array_methods.add_array_object_methods() +del _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py new file mode 100644 index 000000000000..5eedd6a151f2 --- /dev/null +++ b/jax/experimental/array_api/_array_methods.py @@ -0,0 +1,45 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Callable, Optional, Union + +import jax +from jax._src.array import ArrayImpl +from jax.experimental.array_api._version import __array_api_version__ + +from jax._src.lib import xla_extension as xe + + +def _array_namespace(self, /, *, api_version: None | str = None): + if api_version is not None and api_version != __array_api_version__: + raise ValueError(f"{api_version=!r} is not available; " + f"available versions are: {[__array_api_version__]}") + return jax.experimental.array_api + + +def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *, + stream: Optional[Union[int, Any]] = None): + if stream is not None: + raise NotImplementedError("stream argument of array.to_device()") + # The type of device is defined by Array.device. In JAX, this is a callable that + # returns a device, so we must handle this case to satisfy the API spec. + return jax.device_put(self, device() if callable(device) else device) + + +def add_array_object_methods(): + # TODO(jakevdp): set on tracers as well? + setattr(ArrayImpl, "__array_namespace__", _array_namespace) + setattr(ArrayImpl, "to_device", _to_device) diff --git a/jax/experimental/array_api/_version.py b/jax/experimental/array_api/_version.py new file mode 100644 index 000000000000..4936af86da4c --- /dev/null +++ b/jax/experimental/array_api/_version.py @@ -0,0 +1,15 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__array_api_version__ = '2022.12'