Skip to content

Commit

Permalink
Add support for arr.to_device()
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 14, 2023
1 parent eb4e4e6 commit fc61a0a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
7 changes: 1 addition & 6 deletions array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@
# JAX doesn't yet support scalar boolean indexing
array_api_tests/test_array_object.py::test_getitem_masking

# JAX arrays don't have a to_device() method
array_api_tests/test_signatures.py::test_array_method_signature[to_device]

# Hypothesis warning
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices

# JAX arrays don't yet support to_device
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_binary
array_api_tests/test_special_cases.py::test_iop
Expand All @@ -38,6 +32,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]

# JAX's NaN sorting doesn't match specification
Expand Down
19 changes: 4 additions & 15 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

__array_api_version__ = '2022.12'
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import linalg as linalg

Expand Down Expand Up @@ -190,17 +190,6 @@
vecdot as vecdot,
)

def _array_namespace(self, /, *, api_version: None | str = None):
import sys
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
return sys.modules[__name__]

def _setup_array_type():
# TODO(jakevdp): set on tracers as well?
from jax._src.array import ArrayImpl
setattr(ArrayImpl, "__array_namespace__", _array_namespace)

_setup_array_type()
del _setup_array_type
from jax.experimental.array_api import _array_methods
_array_methods.add_array_object_methods()
del _array_methods
45 changes: 45 additions & 0 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Callable, Optional, Union

import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__

from jax._src.lib import xla_extension as xe


def _array_namespace(self, /, *, api_version: None | str = None):
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
return jax.experimental.array_api


def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
stream: Optional[Union[int, Any]] = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that
# returns a device, so we must handle this case to satisfy the API spec.
return jax.device_put(self, device() if callable(device) else device)


def add_array_object_methods():
# TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
setattr(ArrayImpl, "to_device", _to_device)
15 changes: 15 additions & 0 deletions jax/experimental/array_api/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__array_api_version__ = '2022.12'

0 comments on commit fc61a0a

Please sign in to comment.