diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b1c40e3f410..39f935422b44 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -44,6 +44,7 @@ from jax import jit from jax import errors from jax import lax +from jax.sharding import Sharding, SingleDeviceSharding from jax.tree_util import tree_leaves, tree_flatten, tree_map from jax._src import api_util @@ -58,6 +59,7 @@ from jax._src.lax.lax import (_array_copy, _sort_lt_comparator, _sort_le_comparator, PrecisionLike) from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util @@ -2257,16 +2259,28 @@ def empty_like(prototype: ArrayLike | DuckTypedArray, return zeros_like(prototype, dtype=dtype, shape=shape) +def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array: + return arr if device is None else jax.device_put(arr, device) + +def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None: + if isinstance(device, xc.Device): + return SingleDeviceSharding(device) + else: + return device + + @util._wraps(np.full) def full(shape: Any, fill_value: ArrayLike, - dtype: DTypeLike | None = None) -> Array: + dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "full") util.check_arraylike("full", fill_value) + if ndim(fill_value) == 0: shape = canonicalize_shape(shape) - return lax.full(shape, fill_value, dtype) + return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device)) else: - return broadcast_to(asarray(fill_value, dtype=dtype), shape) + return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @util._wraps(np.full_like) @@ -2289,30 +2303,33 @@ def full_like(a: ArrayLike | DuckTypedArray, @util._wraps(np.zeros) -def zeros(shape: Any, dtype: DTypeLike | None = 
None) -> Array: +def zeros(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m) dtypes.check_user_dtype_supported(dtype, "zeros") shape = canonicalize_shape(shape) - return lax.full(shape, 0, _jnp_dtype(dtype)) + return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) @util._wraps(np.ones) -def ones(shape: Any, dtype: DTypeLike | None = None) -> Array: +def ones(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: if isinstance(shape, types.GeneratorType): raise TypeError("expected sequence object with len >= 0 or a single integer") if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m) shape = canonicalize_shape(shape) dtypes.check_user_dtype_supported(dtype, "ones") - return lax.full(shape, 1, _jnp_dtype(dtype)) + return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device)) @util._wraps(np.empty, lax_description="""\ Because XLA cannot create uninitialized arrays, the JAX version will return an array initialized with zeros.""") -def empty(shape: Any, dtype: DTypeLike | None = None) -> Array: +def empty(shape: Any, dtype: DTypeLike | None = None, *, + device: xc.Device | Sharding | None = None) -> Array: if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m) dtypes.check_user_dtype_supported(dtype, "empty") - return zeros(shape, dtype) + return zeros(shape, dtype, device=device) def _check_forgot_shape_tuple(name, shape, dtype) -> str | None: # type: ignore if isinstance(dtype, int) and isinstance(shape, int): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index bdff69ca57ad..d9411ee1a5fe 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ 
-10,12 +10,16 @@ from jax._src.lax.slicing import GatherScatterMode from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape from jax.numpy import fft as fft, linalg as linalg +from jax.sharding import Sharding as _Sharding import numpy as _np _T = TypeVar('_T') _Axis = Union[None, int, Sequence[int]] +# TODO(jakevdp): use xla_client.Device here +_Device = Any + ComplexWarning: type _deprecations: dict[str, tuple[str, Any]] @@ -300,7 +304,8 @@ def einsum( ) -> Array: ... def einsum_path(subscripts, *operands, optimize = ...): ... -def empty(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ... +def empty(shape: Any, dtype: Optional[DTypeLike] = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def empty_like(prototype: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., shape: Any = ...) -> Array: ... @@ -357,7 +362,8 @@ def fromstring( string: str, dtype: DTypeLike = ..., count: int = ..., *, sep: str ) -> Array: ... def full(shape: Any, fill_value: ArrayLike, - dtype: Optional[DTypeLike] = ...) -> Array: ... + dtype: Optional[DTypeLike] = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def full_like(a: Union[ArrayLike, DuckTypedArray], fill_value: ArrayLike, dtype: Optional[DTypeLike] = ..., shape: Any = ...) -> Array: ... @@ -599,7 +605,8 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ... +def ones(shape: Any, dtype: Optional[DTypeLike] = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def ones_like(a: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., shape: Any = ...) -> Array: ...
@@ -874,7 +881,8 @@ def where(condition: ArrayLike, x: Optional[ArrayLike] = ..., fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ... ) -> Union[Array, tuple[Array, ...]]: ... -def zeros(shape: Any, dtype: Optional[DTypeLike] = ...) -> Array: ... +def zeros(shape: Any, dtype: Optional[DTypeLike] = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def zeros_like(a: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., shape: Any = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b4aeebca2b61..f2ada4fbc228 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -41,6 +41,7 @@ import jax.ops from jax import lax from jax import numpy as jnp +from jax.sharding import SingleDeviceSharding from jax import tree_util from jax.test_util import check_grads @@ -2827,6 +2828,28 @@ def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], + shape=array_shapes, + dtype=default_dtypes, + ) + def testArrayCreationWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + kwds = {'fill_value': 1} if func is jnp.full else {} + out = func(**kwds, shape=shape, dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + def testDuckTypedLike(self): x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) @@ -5650,7
+5673,7 @@ def testWrappedSignaturesMatch(self): 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'], - 'full': ['device', 'order', 'like'], + 'full': ['order', 'like'], 'full_like': ['device', 'subok', 'order'], 'fromfunction': ['like'], 'histogram': ['normed'], @@ -5661,7 +5684,7 @@ def testWrappedSignaturesMatch(self): 'nanquantile': ['weights'], 'nanstd': ['correction', 'mean'], 'nanvar': ['correction', 'mean'], - 'ones': ['device', 'order', 'like'], + 'ones': ['order', 'like'], 'ones_like': ['device', 'subok', 'order'], 'partition': ['kind', 'order'], 'percentile': ['weights'],