Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array api] streamline astype device implementation #22664

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3479,15 +3479,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:

deprecations.register("jax-numpy-astype-complex-to-real")

@util.implements(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None,
/, *, copy: bool = False,
device: xc.Device | Sharding | None = None) -> Array:
"""Convert an array to a specified dtype.

JAX imlementation of :func:`numpy.astype`.

This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.

Args:
x: input array to convert
dtype: output dtype
copy: if True, then always return a copy. If False (default) then only
return a copy if necessary.
device: optionally specify the device to which the output will be committed.

Returns:
An array with the same shape as ``x``, containing values of the specified
dtype.

See Also:
- :func:`jax.lax.convert_element_type`: lower-level function for XLA-style
dtype conversions.

Examples:
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)

>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int) # truncates fractional values
Array([0, 0, 1], dtype=int32)
"""
util.check_arraylike("astype", x)
x_arr = asarray(x)

Expand All @@ -3510,17 +3539,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return _place_array(
lax.convert_element_type(x_arr, dtype),
device=device, copy=copy,
)

def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
if copy:
return _array_copy(x)
return x
result = lax_internal._convert_element_type(
x_arr, dtype, sharding=_normalize_to_sharding(device))
return _array_copy(result) if copy else result


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down