Skip to content

Commit

Permalink
Merge pull request jax-ml#20175 from Micky774:array_api
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622040353
  • Loading branch information
jax authors committed Apr 5, 2024
2 parents 8111f38 + 2b1c3de commit f37e503
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 57 deletions.
159 changes: 111 additions & 48 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from typing import Any
import warnings

from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array

from jax._src.sharding import Sharding

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
Expand Down Expand Up @@ -82,16 +83,111 @@ def to_dlpack(x: Array, take_ownership: bool = False,
x.addressable_data(0), stream=stream
) # type: ignore

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr

def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
copy: bool | None = None):
preferred_platform = getattr(device, "platform", None)
if device and preferred_platform == "gpu":
preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"

cpu_backend = xla_bridge.get_backend("cpu")
gpu_backend = None

if preferred_platform in {"cuda", "rocm"}:
try:
gpu_backend = xla_bridge.get_backend(preferred_platform)
except RuntimeError:
raise TypeError(
f"A {str.upper(preferred_platform)} device was specified, however no "
f"{str.upper(preferred_platform)} backend was found."
)

def from_dlpack(external_array):
if preferred_platform is None:
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
pass
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
pass

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)) # type: ignore
dlpack_device, = _arr.devices()
return _place_array(_arr, device, dlpack_device, copy)

def _from_dlpack(external_array, device: xla_client.Device | None = None,
copy: bool | None = None):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")

backend = xla_bridge.get_backend(dl_device_platform)
dlpack_device = backend.device_from_local_hardware_id(device_id)
try:
stream = dlpack_device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, dlpack_device, stream))
return _place_array(_arr, device, dlpack_device, copy)

def from_dlpack(external_array,
device: xla_client.Device | Sharding | None = None,
copy: bool | None = None):
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
The returned :class:`~jax.Array` shares memory with ``external_array``.
The returned :class:`~jax.Array` shares memory with ``external_array`` if no
device transfer or copy was requested.
Args:
external_array: an array object that has __dlpack__ and __dlpack_device__
external_array: An array object that has __dlpack__ and __dlpack_device__
methods, or a DLPack tensor on either CPU or GPU (legacy API).
device: The (optional) :py:class:`Device`, representing the device on which
the returned array should be placed. If given, then the result is committed
to the device. If unspecified, the resulting array will be unpacked onto the
same device it originated from. Setting ``device`` to a device different from
the source of ``external_array`` will require a copy, meaning ``copy`` must be
set to either ``True`` or ``None``.
copy: An (optional) boolean, controlling whether or not to a copy is performed.
If ``copy=True`` then a copy is always performed, even if unpacked onto the
same device. If ``copy=False`` then the copy is never peformed and will raise
an error if necessary. When ``copy=None`` then a copy may be performed if
needed for a device transfer.
Returns:
A jax.Array
Expand All @@ -102,49 +198,16 @@ def from_dlpack(external_array):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
"""
if isinstance(device, Sharding):
device_set = device.device_set
if len(device_set) > 1:
raise ValueError(
"from_dlpack can only unpack a dlpack tensor onto a singular device, but "
f"a Sharding with {len(device_set)} devices was provided."
)
device, = device_set
if hasattr(external_array, "__dlpack__"):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")

backend = xla_bridge.get_backend(device_platform)
device = backend.device_from_local_hardware_id(device_id)
try:
stream = device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)

return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, device, stream))
else:
# Legacy path
dlpack = external_array
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
gpu_backend = None

# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
gpu_backend = None
return _from_dlpack(external_array, device, copy)

return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
# Legacy path
return _legacy_from_dlpack(external_array, device, copy)
5 changes: 3 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,9 +2442,10 @@ def fromiter(*args, **kwargs):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
""")
def from_dlpack(x: Any) -> Array:
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
copy: bool | None = None) -> Array:
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x)
return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,
Expand Down
9 changes: 6 additions & 3 deletions jax/experimental/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import jax
import jax.numpy as jnp

from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding

def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
Expand All @@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

def from_dlpack(x, /):
return jnp.from_dlpack(x)
def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
return jnp.from_dlpack(x, device=device, copy=copy)

def full(shape, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
Expand Down
3 changes: 2 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
def from_dlpack(x: Any) -> Array: ...
def from_dlpack(x: Any, /, *, device: _Device | None = None,
copy: builtins.bool | None = None) -> Array: ...
def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
count: int = ..., offset: int = ...) -> Array: ...
def fromfile(*args, **kwargs): ...
Expand Down
15 changes: 12 additions & 3 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,21 @@ def testTensorFlowToJaxInt64(self):
@jtu.sample_product(
shape=all_shapes,
dtype=numpy_dtypes,
copy=[False, True],
)
def testNumpyToJax(self, shape, dtype):
def testNumpyToJax(self, shape, dtype, copy):
rng = jtu.rand_default(self.rng())
x_np = rng(shape, dtype)
x_jax = jnp.from_dlpack(x_np)
self.assertAllClose(x_np, x_jax)
device = jax.devices()[0]
_from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
if jax.default_backend() == 'gpu' and not copy:
self.assertRaisesRegex(
ValueError,
r"Specified .* which requires a copy",
_from_dlpack
)
else:
self.assertAllClose(x_np, _from_dlpack())

@jtu.sample_product(
shape=all_shapes,
Expand Down

0 comments on commit f37e503

Please sign in to comment.