diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0b0e8b8ead35..be9b77034016 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2226,25 +2226,27 @@ def copy(a: ArrayLike, order: str | None = None) -> Array: @util.implements(np.zeros_like) def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, - shape: Any = None) -> Array: + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") if shape is not None: shape = canonicalize_shape(shape) - return lax.full_like(a, 0, dtype, shape) + return lax.full_like(a, 0, dtype, shape, sharding=_normalize_to_sharding(device)) @util.implements(np.ones_like) def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, - shape: Any = None) -> Array: + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") if shape is not None: shape = canonicalize_shape(shape) - return lax.full_like(a, 1, dtype, shape) + return lax.full_like(a, 1, dtype, shape, sharding=_normalize_to_sharding(device)) @util.implements(np.empty_like, lax_description="""\ @@ -2252,11 +2254,12 @@ def ones_like(a: ArrayLike | DuckTypedArray, return an array initialized with zeros.""") def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = None, - shape: Any = None) -> Array: + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing util.check_arraylike("empty_like", prototype) dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape) + return zeros_like(prototype, dtype=dtype, shape=shape, device=device) def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array: @@ -2286,7 +2289,8 @@ def full(shape: Any, fill_value: ArrayLike, @util.implements(np.full_like) def full_like(a: ArrayLike | DuckTypedArray, fill_value: ArrayLike, dtype: DTypeLike | None = None, - shape: Any = None) -> Array: + shape: Any = None, *, + device: xc.Device | Sharding | None = None) -> Array: if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing util.check_arraylike("full_like", 0, fill_value) else: @@ -2295,11 +2299,11 @@ def full_like(a: ArrayLike | DuckTypedArray, if shape is not None: shape = canonicalize_shape(shape) if ndim(fill_value) == 0: - return lax.full_like(a, fill_value, dtype, shape) + return lax.full_like(a, fill_value, dtype, shape, sharding=_normalize_to_sharding(device)) else: shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type] - return broadcast_to(asarray(fill_value, dtype=dtype), shape) + return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @util.implements(np.zeros) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d9411ee1a5fe..82ab2141f69e 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -308,7 +308,8 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = ..., device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def empty_like(prototype: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., - shape: Any = ...) -> Array: ... + shape: Any = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... euler_gamma: float def exp(x: ArrayLike, /) -> Array: ... @@ -366,7 +367,8 @@ def full(shape: Any, fill_value: ArrayLike, device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def full_like(a: Union[ArrayLike, DuckTypedArray], fill_value: ArrayLike, dtype: Optional[DTypeLike] = ..., - shape: Any = ...) -> Array: ... + shape: Any = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: ... generic = _np.generic def geomspace( @@ -609,7 +611,8 @@ def ones(shape: Any, dtype: Optional[DTypeLike] = ..., device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def ones_like(a: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., - shape: Any = ...) -> Array: ... + shape: Any = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def outer(a: ArrayLike, b: Array, out: None = ...) -> Array: ... def packbits( a: ArrayLike, axis: Optional[int] = ..., bitorder: str = ... @@ -885,6 +888,7 @@ def zeros(shape: Any, dtype: Optional[DTypeLike] = ..., device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def zeros_like(a: Union[ArrayLike, DuckTypedArray], dtype: Optional[DTypeLike] = ..., - shape: Any = ...) -> Array: ... + shape: Any = ..., *, + device: Optional[Union[_Device, _Sharding]] = ...) -> Array: ... def vectorize(pyfunc, *, excluded = ..., signature = ...) -> Callable: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 048f4fbf56bb..e29fdb59c0d2 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2850,6 +2850,44 @@ def testArrayCreationWithSharding(self, func, shape, dtype): out = func(**kwds, shape=shape, dtype=dtype, device=sharding) self.assertEqual(out.sharding, sharding) + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithDevice(self, func, shape, dtype): + device = jax.devices()[-1] + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=device) + self.assertEqual(out.devices(), {device}) + + with self.subTest('device from input array'): + out2 = func(x, **kwds) + self.assertEqual(out2.devices(), out.devices()) + + @jtu.sample_product( + func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], + shape=array_shapes, + dtype=default_dtypes, + ) + def testFullLikeWithSharding(self, func, shape, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype) + kwds = {'fill_value': 1} if func is jnp.full_like else {} + + with self.subTest('device from keyword'): + out = func(x, **kwds, device=sharding) + self.assertEqual(out.sharding, sharding) + + with self.subTest('device from input array'): + out2 = func(out, **kwds) + self.assertEqual(out2.devices(), out.devices()) + def testDuckTypedLike(self): x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) @@ -5674,7 +5712,7 @@ def testWrappedSignaturesMatch(self): 'identity': ['like'], 'isin': ['kind'], 'full': ['order', 'like'], - 'full_like': ['device', 'subok', 'order'], + 'full_like': ['subok', 'order'], 'fromfunction': ['like'], 'histogram': ['normed'], 'histogram2d': ['normed'], @@ -5685,7 +5723,7 @@ def testWrappedSignaturesMatch(self): 'nanstd': ['correction', 'mean'], 'nanvar': ['correction', 'mean'], 'ones': ['order', 'like'], - 'ones_like': ['device', 'subok', 'order'], + 'ones_like': ['subok', 'order'], 'partition': ['kind', 'order'], 'percentile': ['weights'], 'quantile': ['weights'], @@ -5695,7 +5733,7 @@ def testWrappedSignaturesMatch(self): 'tri': ['like'], 'var': ['correction', 'mean'], 'vstack': ['casting'], - 'zeros_like': ['device', 'subok', 'order'] + 'zeros_like': ['subok', 'order'] } extra_params = {