From 81b9db6b80cfdce989726aa1220766eff36ed787 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 25 Jul 2024 10:18:33 -0700 Subject: [PATCH] [array api] streamline astype device implementation When this was first implemented, convert_element_type did not yet have a sharding argument. Now we can simplify things by using it. --- jax/_src/numpy/lax_numpy.py | 55 +++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 66b240a402a9..76c99895165d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3479,15 +3479,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: deprecations.register("jax-numpy-astype-complex-to-real") -@util.implements(getattr(np, "astype", None), lax_description=""" -This is implemented via :func:`jax.lax.convert_element_type`, which may -have slightly different behavior than :func:`numpy.astype` in some cases. -In particular, the details of float-to-int and int-to-float casts are -implementation dependent. -""") def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: + """Convert an array to a specified dtype. + + JAX imlementation of :func:`numpy.astype`. + + This is implemented via :func:`jax.lax.convert_element_type`, which may + have slightly different behavior than :func:`numpy.astype` in some cases. + In particular, the details of float-to-int and int-to-float casts are + implementation dependent. + + Args: + x: input array to convert + dtype: output dtype + copy: if True, then always return a copy. If False (default) then only + return a copy if necessary. + device: optionally specify the device to which the output will be committed. + + Returns: + An array with the same shape as ``x``, containing values of the specified + dtype. + + See Also: + - :func:`jax.lax.convert_element_type`: lower-level function for XLA-style + dtype conversions. + + Examples: + >>> x = jnp.array([0, 1, 2, 3]) + >>> x + Array([0, 1, 2, 3], dtype=int32) + >>> x.astype('float32') + Array([0.0, 1.0, 2.0, 3.0], dtype=float32) + + >>> y = jnp.array([0.0, 0.5, 1.0]) + >>> y.astype(int) # truncates fractional values + Array([0, 0, 1], dtype=int32) + """ util.check_arraylike("astype", x) x_arr = asarray(x) @@ -3510,17 +3539,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, # to issue our warning. with warnings.catch_warnings(): warnings.simplefilter("ignore", ComplexWarning) - return _place_array( - lax.convert_element_type(x_arr, dtype), - device=device, copy=copy, - ) - -def _place_array(x, device=None, copy=None): - # TODO(micky774): Implement in future PRs as we formalize device placement - # semantics - if copy: - return _array_copy(x) - return x + result = lax_internal._convert_element_type( + x_arr, dtype, sharding=_normalize_to_sharding(device)) + return _array_copy(result) if copy else result @util.implements(np.asarray, lax_description=_ARRAY_DOC)