From 7837e90ceb6e94b5e233183c5144ac1f1c9c6a65 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 19 Apr 2024 13:53:48 +0000 Subject: [PATCH] Add support for copy kwarg in astype to match Array API --- CHANGELOG.md | 7 +++ jax/_src/numpy/array_methods.py | 6 ++- jax/_src/numpy/lax_numpy.py | 48 ++++++++++++++++--- .../array_api/_data_type_functions.py | 20 +++++++- jax/numpy/__init__.pyi | 4 +- tests/lax_numpy_test.py | 20 ++++++++ 6 files changed, 94 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 690c50baf993..3ca7df459433 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,13 @@ Remember to align the itemized text with the first line of an item within a list * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and related functions now raise an error, following a similar change in NumPy. +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + To avoid copying when possible, set `copy=False`; a new array is still + returned whenever the dtype changes. + ## jaxlib 0.4.27 ## jax 0.4.26 (April 3, 2024) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 98eea8887198..831fef33627c 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. 
-def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3767633deefc..5103bb522275 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2272,17 +2272,53 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: In particular, the details of float-to-int and int-to-float casts are implementation dependent. """) -def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: +def astype(x: ArrayLike, dtype: DTypeLike | None, + /, *, copy: bool | DeprecatedArg = DeprecatedArg(), + device: xc.Device | Sharding | None = None) -> Array: util.check_arraylike("astype", x) x_arr = asarray(x) - del copy # unused in JAX + + # TODO(micky774): Deprecated 2024-4-19, remove after deprecation completed. + if isinstance(copy, DeprecatedArg): + warnings.warn( + "The copy keyword of astype was previously ignored but is now " + "implemented. The default of copy=True will lead to the expected " + "behavior of creating a copy, which may potentially lead to extra " + "memory usage. To preserve previous behavior, use copy=False. 
To " + "suppress this warning, please explicitly set copy.", + DeprecationWarning, stacklevel=2) + copy = False + if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") - # convert_element_type(complex, bool) has the wrong semantics. - if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating): - return (x_arr != _lax_const(x_arr, 0)) - return lax.convert_element_type(x_arr, dtype) + if issubdtype(x_arr.dtype, complexfloating): + if dtypes.isdtype(dtype, ("integral", "real floating")): + warnings.warn( + "Casting from complex to real dtypes will soon raise a ValueError. " + "Please first use jnp.real or jnp.imag to take the real/imaginary " + "component of your input.", + DeprecationWarning, stacklevel=2 + ) + elif np.dtype(dtype) == bool: + # convert_element_type(complex, bool) has the wrong semantics. + x_arr = (x_arr != _lax_const(x_arr, 0)) + + # We offer a more specific warning than the usual ComplexWarning so we prefer + # to issue our warning. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ComplexWarning) + return _place_array( + lax.convert_element_type(x_arr, dtype), + device=device, copy=copy, + ) + +def _place_array(x, device=None, copy=None): + # TODO(micky774): Implement in future PRs as we formalize device placement + # semantics + if copy: + return _array_copy(x) + return x @util.implements(np.asarray, lax_description=_ARRAY_DOC) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 4f72fcba29d0..770d264c1c07 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import builtins import functools from typing import NamedTuple @@ -19,6 +21,9 @@ import jax.numpy as jnp +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src import dtypes as _dtypes from jax.experimental.array_api._dtypes import ( bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -124,8 +129,19 @@ def _promote_types(t1, t2): raise ValueError("No promotion path for {t1} & {t2}") -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) +def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): + src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) + if ( + src_dtype is not None + and _dtypes.isdtype(src_dtype, "complex floating") + and _dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." + ) + return jnp.astype(x, dtype, copy=copy, device=device) def can_cast(from_, to, /): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 2740638041cd..75c97a481f8b 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -13,6 +13,8 @@ from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape, DeprecatedArg ) +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc from jax.numpy import fft as fft, linalg as linalg from jax.sharding import Sharding as _Sharding import numpy as _np @@ -115,7 +117,7 @@ def asarray( ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ... 
+def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ddc599792e63..707366d99271 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3870,6 +3870,26 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy): + dtype = 'float32' if change_dtype else 'int32' + expect_copy = change_dtype or copy + x = jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy) + + assert y.dtype == dtype + y.delete() + assert x.is_deleted() != expect_copy + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + msg = "Casting from complex to non-complex dtypes will soon raise " + with self.assertWarns(DeprecationWarning, msg=msg): + x.astype('float32') + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)