diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e318e12f1bd..925c1327cbec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.26 +* New Features + * {func}`jax.numpy.astype` supports new `device` keyword argument. + * Deprecations & Removals * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. @@ -28,6 +31,12 @@ Remember to align the itemized text with the first line of an item within a list `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config` and `jax.extend.source_info_util` instead. +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + To prevent copying when possible, set `copy=False`. + ## jaxlib 0.4.26 ## jax 0.4.25 (Feb 26, 2024) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index fb1e52bd1be9..3611ed4aa6e5 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 8bda835f2d0b..8b75f802ebff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -41,7 +41,7 @@ import opt_einsum import jax -from jax import jit +from jax import jit, device_put from jax import errors from jax import lax from jax.sharding import Sharding, SingleDeviceSharding @@ -2209,19 +2209,38 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: else: return x - @util.implements(getattr(np, "astype", None), lax_description=""" This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :func:`numpy.astype` in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """) -def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: - del copy # unused in JAX +def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") - return lax.convert_element_type(x, dtype) + src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x) + if ( + src_dtype is not None + and dtypes.isdtype(src_dtype, "complex floating") + and dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." + ) + src_devices = ( + x.devices() if hasattr(x, "devices") + and not isinstance(x, core.Tracer) else None + ) + arr = x + if device is not None and src_devices != {device}: + arr = device_put(x, device) + elif copy: + arr = _array_copy(x) + return lax.convert_element_type(arr, dtype) + @util.implements(np.asarray, lax_description=_ARRAY_DOC) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index d2bb032b85ab..a34f81244dc3 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations +import builtins import functools from typing import NamedTuple import jax import jax.numpy as jnp +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding from jax.experimental.array_api._dtypes import ( bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -124,8 +128,8 @@ def _promote_types(t1, t2): raise ValueError("No promotion path for {t1} & {t2}") -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) +def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): + return jnp.astype(x, dtype, copy=copy, device=device) def can_cast(from_, to, /): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index a618f457016d..8c9a1f3c6105 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -9,6 +9,8 @@ from jax._src import dtypes as _dtypes from jax._src.lax.lax import PrecisionLike from jax._src.lax.slicing import GatherScatterMode from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape from jax.numpy import fft as fft, linalg as linalg from jax.sharding import Sharding as _Sharding @@ -112,7 +114,7 @@ def asarray( ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ... +def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ed3a16eda43a..f123fd714ffb 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + change_device=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy, change_device): + if jax.device_count() == 1 and change_device: + raise unittest.SkipTest( + "Testing device transfer requires at least two available devices." + ) + + dtype = 'float32' if change_dtype else 'int32' + device = jax.devices()[-1] if change_device else None + expect_copy = change_dtype or copy or change_device + x = jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy, device=device) + + assert y.dtype == dtype + if change_device: + assert y.devices() == {device} + else: + y.delete() + get_val = lambda: np.array(x) + err_msg = "Array has been deleted" + if expect_copy: + get_val() + else: + jtu.check_raises(get_val, RuntimeError, err_msg) + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + complex_downcast = lambda: x.astype('float32') + err_msg = "Casting from complex to non-complex " + jtu.check_raises(complex_downcast, ValueError, err_msg) + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)