Skip to content

Commit

Permalink
Add support for device kwarg in astype
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed May 17, 2024
1 parent 5e2710c commit 33d2df7
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 37 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.28 (May 9, 2024)

* New Functionality
* {func}`jax.numpy.astype` supports a new `device` keyword argument.

* Bug fixes
* Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).

Expand Down
19 changes: 3 additions & 16 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax._src.lib import xla_client
from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding
from jax._src.numpy.util import _place_array

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)
Expand Down Expand Up @@ -148,19 +149,6 @@ def to_dlpack(x: Array, stream: int | Any | None = None,
f"version ({max_version}) was requested."
)

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
return device_put(_arr, device)
if copy:
return jnp.array(_arr, copy=True)
return _arr

def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
copy: bool | None = None):
Expand Down Expand Up @@ -194,8 +182,7 @@ def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)) # type: ignore
dlpack_device, = _arr.devices()
return _place_array(_arr, device, dlpack_device, copy)
return _place_array(_arr, device, copy)

def _from_dlpack(external_array, device: xla_client.Device | None = None,
copy: bool | None = None):
Expand Down Expand Up @@ -226,7 +213,7 @@ def _from_dlpack(external_array, device: xla_client.Device | None = None,

_arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, dlpack_device, stream))
return _place_array(_arr, device, dlpack_device, copy)
return _place_array(_arr, device, copy)

def from_dlpack(external_array,
device: xla_client.Device | Sharding | None = None,
Expand Down
14 changes: 5 additions & 9 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2853,18 +2853,14 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return _place_array(
return util._place_array(
lax.convert_element_type(x_arr, dtype),
device=device, copy=copy,
device=device,
# We translate between array API semantics of copy in _place_array, and
# the NumPy semantics of copy in astype.
copy=True if copy else None,
)

def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
if copy:
return _array_copy(x)
return x


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
Expand Down
25 changes: 24 additions & 1 deletion jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
import re
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar

import warnings

from jax.sharding import Sharding

from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape

Expand Down Expand Up @@ -117,6 +119,27 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


def _place_array(x: Array, device: xc.Device | Sharding | None = None, copy=None) -> Array:
"""Helper utility for copying an array, or placing it on a device or sharding.
This utility uses `jax.device_put` for device placement.
"""
out = x
if device is not None:
# TODO(micky774): Add check to avoid error if no actual device transfer is
# necessary
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy, however copy=False. Set "
"copy=True or copy=None to perform the requested operation."
)
out = api.device_put(out, device)

# TODO(micky774): Avoid copy if data has already been copied via device
# transfer
return lax._array_copy(out) if copy else out


def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
Expand Down
6 changes: 6 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
and implements most of the API listed in the standard.
.. _Python array API standard: https://data-apis.org/array-api/latest/
Note that JAX may not always strictly adhere to array API device semantics when
using ``jax.jit``. In particular, specifying the ``device`` argument is
equivalent to calling ``jax.device_put(x, device)``. For up-to-date details on
device placement, see the documentation of ``jax.device_put``.
"""

from __future__ import annotations
Expand Down
7 changes: 4 additions & 3 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ def testTensorFlowToJaxInt64(self):
shape=all_shapes,
dtype=numpy_dtypes,
copy=[False, True],
device_transfer=[False, True],
)
def testNumpyToJax(self, shape, dtype, copy):
def testNumpyToJax(self, shape, dtype, copy, device_transfer):
rng = jtu.rand_default(self.rng())
x_np = rng(shape, dtype)
device = jax.devices()[0]
device = jax.devices()[0] if device_transfer else None
_from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
if jax.default_backend() == 'gpu' and not copy:
if device_transfer and not copy:
self.assertRaisesRegex(
ValueError,
r"Specified .* which requires a copy",
Expand Down
40 changes: 32 additions & 8 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import jax.ops
from jax import lax
from jax import numpy as jnp
from jax.sharding import SingleDeviceSharding
from jax.sharding import SingleDeviceSharding, PartitionSpec as P
from jax.test_util import check_grads

from jax._src import array
Expand Down Expand Up @@ -3931,19 +3931,43 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
self._CompileAndCheck(jnp_op, args_maker)

@jtu.sample_product(
change_dtype=[True, False],
[dict(dtype=dtype, new_dtype=new_dtype)
for dtype in all_dtypes
for new_dtype in (
complex_dtypes
if np.issubdtype(dtype, np.complexfloating)
else all_dtypes
)
],
shape=array_shapes,
copy=[True, False],
device_type=[None, "single", "shard"],
)
def testAstypeCopy(self, change_dtype, copy):
dtype = 'float32' if change_dtype else 'int32'
expect_copy = change_dtype or copy
x = jnp.arange(5, dtype='int32')
y = x.astype(dtype, copy=copy)
@jtu.run_on_devices("gpu")
def testAstypePlacement(self, shape, dtype, new_dtype, copy, device_type):
rng = jtu.rand_default(self.rng())
x = jnp.asarray(rng(shape, dtype))

if device_type is None:
device = None
expected_sharding = x.sharding
elif device_type == "single":
device = jax.devices("cpu")[0]
expected_sharding = SingleDeviceSharding(device)
else:
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
device = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
expected_sharding = device

expect_copy = (dtype != new_dtype) or copy or device

self.assertEqual(y.dtype, dtype)
y = x.astype(new_dtype, copy=copy, device=device)
self.assertEqual(y.dtype, new_dtype)
self.assertEqual(y.sharding, expected_sharding)
y.delete()
self.assertNotEqual(x.is_deleted(), expect_copy)


def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
msg = "Casting from complex to non-complex dtypes will soon raise "
Expand Down

0 comments on commit 33d2df7

Please sign in to comment.