Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 11, 2024
1 parent 71ec6e3 commit d71ac98
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 55 deletions.
158 changes: 109 additions & 49 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
from __future__ import annotations

import enum
from typing import Any
from typing import Any, Optional
import warnings

from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array

from jax._src.sharding import Sharding

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
Expand Down Expand Up @@ -82,16 +83,108 @@ def to_dlpack(x: Array, take_ownership: bool = False,
x.addressable_data(0), stream=stream
) # type: ignore

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr

def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
                        copy: bool | None = None):
  """Unpacks a legacy DLPack tensor object into an array.

  Args:
    dlpack: A DLPack tensor passed directly (legacy API), i.e. an object
      without a ``__dlpack__`` method.
    device: Optional device on which the result should be placed. Its
      ``platform`` attribute, when present, selects which GPU backend is
      handed to the unpacking routine.
    copy: Forwarded to ``_place_array``; ``True`` always copies, ``False``
      never copies, ``None`` copies only if a device transfer requires it.

  Returns:
    The result of ``_place_array`` applied to the unpacked array, the
    requested ``device``, one of the devices the unpacked array lives on, and
    ``copy``.

  Raises:
    TypeError: If a CUDA or ROCm device was requested but the corresponding
      backend could not be obtained.
  """
  preferred_platform = getattr(device, "platform", None)
  if device and preferred_platform == "gpu":
    # A generic "gpu" platform is narrowed to "cuda" or "rocm" by inspecting
    # the client's platform_version string.
    preferred_platform = (
        "cuda" if "cuda" in device.client.platform_version else "rocm")

  cpu_backend = xla_bridge.get_backend("cpu")
  gpu_backend = None

  if preferred_platform in {"cuda", "rocm"}:
    try:
      gpu_backend = xla_bridge.get_backend(preferred_platform)
    except RuntimeError as err:
      raise TypeError(
          f"A {preferred_platform.upper()} device was specified, however no "
          f"{preferred_platform.upper()} backend was found."
      ) from err
  elif preferred_platform is None:
    # No device was requested: use whichever GPU backend is available, trying
    # CUDA first and ROCm second. Having neither is fine (CPU-only).
    for platform in ("cuda", "rocm"):
      try:
        gpu_backend = xla_bridge.get_backend(platform)
      except RuntimeError:
        continue
      if gpu_backend is not None:
        break

  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, cpu_backend, gpu_backend))  # type: ignore

  return _place_array(_arr, device, _arr.devices().pop(), copy)

def _from_dlpack(external_array, device: xla_client.Device | None = None,
                 copy: bool | None = None):
  """Unpacks an object implementing the ``__dlpack__`` protocol into an array.

  Args:
    external_array: An object providing ``__dlpack__`` and
      ``__dlpack_device__`` methods.
    device: Optional device on which the result should be placed.
    copy: Forwarded to ``_place_array``; ``True`` always copies, ``False``
      never copies, ``None`` copies only if a device transfer requires it.

  Returns:
    The result of ``_place_array`` applied to the unpacked array, the
    requested ``device``, the device reported by ``external_array``, and
    ``copy``.

  Raises:
    TypeError: If ``external_array`` reports a DLPack device type other than
      CPU, CUDA or ROCm.
    xla_client.XlaRuntimeError: If querying the device's stream fails for any
      reason other than the operation being unimplemented.
  """
  dl_device_type, device_id = external_array.__dlpack_device__()
  try:
    dl_device_platform = {
        DLDeviceType.kDLCPU: "cpu",
        DLDeviceType.kDLCUDA: "cuda",
        DLDeviceType.kDLROCM: "rocm",
    }[dl_device_type]
  except (KeyError, TypeError):
    # A failed dict lookup raises KeyError (TypeError only for an unhashable
    # key), so both must be caught for the conversion below to take effect.
    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
    # TypeError.
    raise TypeError(
        "Array passed to from_dlpack is on unsupported device type "
        f"(DLDeviceType: {dl_device_type}, array: {external_array})"
    ) from None

  backend = xla_bridge.get_backend(dl_device_platform)
  dlpack_device = backend.device_from_local_hardware_id(device_id)
  try:
    stream = dlpack_device.get_stream_for_external_ready_events()
  except xla_client.XlaRuntimeError as err:  # type: ignore
    # Devices that do not implement external-ready-event streams fall back to
    # requesting the capsule without a stream; any other failure is re-raised.
    if "UNIMPLEMENTED" in str(err):
      stream = None
    else:
      raise
  dlpack = external_array.__dlpack__(stream=stream)

  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, dlpack_device, stream))

  return _place_array(_arr, device, dlpack_device, copy)

def from_dlpack(external_array, device: xla_client.Device | Sharding | None = None , copy: bool | None = None):
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
The returned :class:`~jax.Array` shares memory with ``external_array``.
The returned :class:`~jax.Array` shares memory with ``external_array`` if no
device transfer or copy was requested.
Args:
external_array: an array object that has __dlpack__ and __dlpack_device__
external_array: An array object that has __dlpack__ and __dlpack_device__
methods, or a DLPack tensor on either CPU or GPU (legacy API).
device: The (optional) :py:class:`Device`, representing the device on which
the returned array should be placed. If given, then the result is committed
to the device. If unspecified, the resulting array will be unpacked onto the
same device it originated from. Setting ``device`` to a device different from
the source of ``external_array`` will require a copy, meaning ``copy`` must be
set to either ``True`` or ``None``.
copy: An (optional) boolean, controlling whether or not a copy is performed.
If ``copy=True`` then a copy is always performed, even if unpacked onto the
same device. If ``copy=False`` then the copy is never performed and will raise
an error if necessary. When ``copy=None`` then a copy may be performed if
needed for a device transfer.
Returns:
A jax.Array
Expand All @@ -102,49 +195,16 @@ def from_dlpack(external_array):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
"""
if isinstance(device, Sharding):
device_set = device.device_set
if len(device_set) > 1:
raise ValueError(
"from_dlpack can only unpack a dlpack tensor onto a singular device, but "
f"a Sharding with {len(device_set)} devices was provided."
)
device = device_set.pop()
if hasattr(external_array, "__dlpack__"):
dl_device_type, device_id = external_array.__dlpack_device__()
try:
device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")

backend = xla_bridge.get_backend(device_platform)
device = backend.device_from_local_hardware_id(device_id)
try:
stream = device.get_stream_for_external_ready_events()
except xla_client.XlaRuntimeError as err: # type: ignore
if "UNIMPLEMENTED" in str(err):
stream = None
else:
raise
dlpack = external_array.__dlpack__(stream=stream)

return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, device, stream))
else:
# Legacy path
dlpack = external_array
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
gpu_backend = None

# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
gpu_backend = None
return _from_dlpack(external_array, device, copy)

return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
# Legacy path
return _legacy_from_dlpack(external_array, device, copy)
4 changes: 2 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2436,9 +2436,9 @@ def fromiter(*args, **kwargs):
is later modified in-place, it may lead to undefined behavior when using
the associated JAX array.
""")
def from_dlpack(x: Any) -> Array:
def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None) -> Array:
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x)
return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)
def fromfunction(function: Callable[..., Array], shape: Any,
Expand Down
9 changes: 6 additions & 3 deletions jax/experimental/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import jax
import jax.numpy as jnp

from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding

def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
  """Builds ``jnp.arange(start, stop, step, dtype=dtype)`` and device_puts it.

  Args:
    start: Positional-only first argument forwarded to ``jnp.arange``.
    stop: Forwarded to ``jnp.arange``.
    step: Forwarded to ``jnp.arange``.
    dtype: Forwarded to ``jnp.arange`` as ``dtype``.
    device: Forwarded to ``jax.device_put`` as ``device``.

  Returns:
    The result of ``jax.device_put`` applied to the created array.
  """
  values = jnp.arange(start, stop, step, dtype=dtype)
  return jax.device_put(values, device=device)
Expand All @@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
  """Builds ``jnp.eye(n_rows, n_cols, k=k, dtype=dtype)`` and device_puts it.

  Args:
    n_rows: Positional-only first argument forwarded to ``jnp.eye``.
    n_cols: Positional-only second argument forwarded to ``jnp.eye``.
    k: Forwarded to ``jnp.eye`` as ``k``.
    dtype: Forwarded to ``jnp.eye`` as ``dtype``.
    device: Forwarded to ``jax.device_put`` as ``device``.

  Returns:
    The result of ``jax.device_put`` applied to the created array.
  """
  values = jnp.eye(n_rows, n_cols, k=k, dtype=dtype)
  return jax.device_put(values, device=device)

def from_dlpack(x, /):
return jnp.from_dlpack(x)
def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
return jnp.from_dlpack(x, device=device, copy=copy)

def full(shape, fill_value, *, dtype=None, device=None):
  """Builds ``jnp.full(shape, fill_value, dtype=dtype)`` and device_puts it.

  Args:
    shape: Forwarded to ``jnp.full`` as the array shape.
    fill_value: Forwarded to ``jnp.full`` as the fill value.
    dtype: Forwarded to ``jnp.full`` as ``dtype``.
    device: Forwarded to ``jax.device_put`` as ``device``.

  Returns:
    The result of ``jax.device_put`` applied to the created array.
  """
  values = jnp.full(shape, fill_value, dtype=dtype)
  return jax.device_put(values, device=device)
Expand Down
3 changes: 2 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeV

from jax._src import core as _core
from jax._src import dtypes as _dtypes
from jax._src.lib import xla_client as xc
from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
Expand Down Expand Up @@ -353,7 +354,7 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
def from_dlpack(x: Any) -> Array: ...
def from_dlpack(x: Any, /, *, device: _Sharding | xc.Device | None = None, copy: builtins.bool | None = None) -> Array: ...
def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
count: int = ..., offset: int = ...) -> Array: ...
def fromfile(*args, **kwargs): ...
Expand Down

0 comments on commit d71ac98

Please sign in to comment.