-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for device kwarg in astype, and add matching utility func #21086
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,14 +18,16 @@ | |
import re | ||
import textwrap | ||
from typing import Any, Callable, NamedTuple, TypeVar | ||
|
||
import warnings | ||
|
||
from jax.sharding import Sharding | ||
|
||
from jax._src import api | ||
from jax._src import config | ||
from jax._src import core | ||
from jax._src import dtypes | ||
from jax._src.lax import lax | ||
from jax._src.lib import xla_client as xc | ||
from jax._src.util import safe_zip, safe_map | ||
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape | ||
|
||
|
@@ -117,6 +119,27 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]: | |
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} | ||
|
||
|
||
def _place_array(x: Array, device: xc.Device | Sharding | None = None, copy=None) -> Array: | ||
Micky774 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Helper utility for copying an array, or placing it on a device or sharding. | ||
|
||
This utility uses `jax.device_put` for device placement. | ||
""" | ||
out = x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why out=x? |
||
if device is not None: | ||
Micky774 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# TODO(micky774): Add check to avoid error if no actual device transfer is | ||
# necessary | ||
if copy is not None and not copy: | ||
raise ValueError( | ||
f"Specified {device=} which requires a copy, however copy=False. Set " | ||
"copy=True or copy=None to perform the requested operation." | ||
) | ||
out = api.device_put(out, device) | ||
|
||
# TODO(micky774): Avoid copy if data has already been copied via device | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This todo doesn't make sense? In this branch the device is None. Remove the todo? |
||
# transfer | ||
return lax._array_copy(out) if copy else out | ||
|
||
|
||
def implements( | ||
original_fun: Callable[..., Any] | None, | ||
update_doc: bool = True, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,12 @@ | |
and implements most of the API listed in the standard. | ||
|
||
.. _Python array API standard: https://data-apis.org/array-api/latest/ | ||
|
||
|
||
Note that JAX may not always strictly adhere to array API device semantics when | ||
using ``jax.jit``. In particular, specifying the ``device`` argument is | ||
equivalent to calling ``jax.device_put(x, device)``. For up-to-date details on | ||
device placement, see the documentation of ``jax.device_put`` for more details. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Usually references go at the bottom of the docstring. You can put this paragraph above the |
||
""" | ||
|
||
from __future__ import annotations | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,7 +41,7 @@ | |
import jax.ops | ||
from jax import lax | ||
from jax import numpy as jnp | ||
from jax.sharding import SingleDeviceSharding | ||
from jax.sharding import SingleDeviceSharding, PartitionSpec as P | ||
from jax.test_util import check_grads | ||
|
||
from jax._src import array | ||
|
@@ -3931,19 +3931,43 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): | |
self._CompileAndCheck(jnp_op, args_maker) | ||
|
||
@jtu.sample_product( | ||
change_dtype=[True, False], | ||
[dict(dtype=dtype, new_dtype=new_dtype) | ||
for dtype in all_dtypes | ||
for new_dtype in ( | ||
complex_dtypes | ||
if np.issubdtype(dtype, np.complexfloating) | ||
else all_dtypes | ||
) | ||
], | ||
shape=array_shapes, | ||
copy=[True, False], | ||
device_type=[None, "single", "shard"], | ||
) | ||
def testAstypeCopy(self, change_dtype, copy): | ||
dtype = 'float32' if change_dtype else 'int32' | ||
expect_copy = change_dtype or copy | ||
x = jnp.arange(5, dtype='int32') | ||
y = x.astype(dtype, copy=copy) | ||
@jtu.run_on_devices("gpu") | ||
def testAstypePlacement(self, shape, dtype, new_dtype, copy, device_type): | ||
rng = jtu.rand_default(self.rng()) | ||
x = jnp.asarray(rng(shape, dtype)) | ||
|
||
if device_type is None: | ||
device = None | ||
expected_sharding = x.sharding | ||
elif device_type == "single": | ||
device = jax.devices("cpu")[0] | ||
expected_sharding = SingleDeviceSharding(device) | ||
else: | ||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use (2, 2) mesh to coverage across more hardware? |
||
device = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to try other PartitionSpecs too (can you parameterize the test on that too) |
||
expected_sharding = device | ||
|
||
expect_copy = (dtype != new_dtype) or copy or device | ||
|
||
self.assertEqual(y.dtype, dtype) | ||
y = x.astype(new_dtype, copy=copy, device=device) | ||
self.assertEqual(y.dtype, new_dtype) | ||
self.assertEqual(y.sharding, expected_sharding) | ||
y.delete() | ||
self.assertNotEqual(x.is_deleted(), expect_copy) | ||
|
||
|
||
def testAstypeComplexDowncast(self): | ||
x = jnp.array(2.0+1.5j, dtype='complex64') | ||
msg = "Casting from complex to non-complex dtypes will soon raise " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move up to 0.4.29