From 2395e764dad11a5a339a029d3b1ba66168532107 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 20 Mar 2024 19:07:02 +0000 Subject: [PATCH] Add support for device kwarg in astype to match Array API --- CHANGELOG.md | 9 +++++++++ jax/_src/numpy/array_methods.py | 6 ++++-- jax/_src/numpy/lax_numpy.py | 24 ++++++++++++++++++---- tests/lax_numpy_test.py | 35 +++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e318e12f1bd..925c1327cbec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.26 +* New Features + * {func}`jax.numpy.astype` supports new `device` keyword argument. + * Deprecations & Removals * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. @@ -28,6 +31,12 @@ Remember to align the itemized text with the first line of an item within a list `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config` and `jax.extend.source_info_util` instead. +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + To prevent copying when possible, set `copy=False`. + ## jaxlib 0.4.26 ## jax 0.4.25 (Feb 26, 2024) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index fb1e52bd1be9..3611ed4aa6e5 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 18a182948a7d..8b75f802ebff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2216,13 +2216,29 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: implementation dependent. """) def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: - src_devices = x.devices() if hasattr(x, "devices") and not isinstance(x, core.Tracer) else None - if device is not None and src_devices != {device}: - return device_put(x, device) - arr = _array_copy(x) if copy else x if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") + src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x) + if ( + src_dtype is not None + and dtypes.isdtype(src_dtype, "complex floating") + and dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." + ) + src_devices = ( + x.devices() if hasattr(x, "devices") + and not isinstance(x, core.Tracer) else None + ) + arr = x + if device is not None and src_devices != {device}: + arr = device_put(x, device) + elif copy: + arr = _array_copy(x) return lax.convert_element_type(arr, dtype) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ed3a16eda43a..f123fd714ffb 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + change_device=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy, change_device): + if jax.device_count() == 1 and change_device: + raise unittest.SkipTest( + "Testing device transfer requires at least two available devices." + ) + + dtype = 'float32' if change_dtype else 'int32' + device = jax.devices()[-1] if change_device else None + expect_copy = change_dtype or copy or change_device + x = jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy, device=device) + + assert y.dtype == dtype + if change_device: + assert y.devices() == {device} + else: + y.delete() + get_val = lambda: np.array(x) + err_msg = "Array has been deleted" + if expect_copy: + get_val() + else: + jtu.check_raises(get_val, RuntimeError, err_msg) + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + complex_downcast = lambda: x.astype('float32') + err_msg = "Casting from complex to non-complex " + jtu.check_raises(complex_downcast, ValueError, err_msg) + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)