diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6b6671449531..ffcfc617c1f0 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3555,8 +3555,25 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) -@util.implements(np.eye) +@util.implements(np.eye, extra_params=""" +device : :py:class:`Device`, :py:class:`Sharding`, optional + The (optional) :py:class:`Device`, :py:class:`Sharding`, + representing the device(s) to which created array should be + transferred. If given, then the result is committed to the device(s). +""") def eye(N: DimSize, M: DimSize | None = None, + k: int | ArrayLike = 0, + dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _eye(N, M=M, k=k, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _eye(N: DimSize, M: DimSize | None = None, k: int | ArrayLike = 0, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "eye") @@ -3581,7 +3598,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: return eye(n, dtype=dtype) -@util.implements(np.arange,lax_description= """ +@util.implements(np.arange, lax_description= """ .. note:: Using ``arange`` with the ``step`` argument can lead to precision errors, @@ -3590,8 +3607,25 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array: To avoid precision errors, consider using an expression like ``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision and then convert it to the desired lower precision. -""") +""", extra_params=""" +device : :py:class:`Device`, :py:class:`Sharding`, optional + The (optional) :py:class:`Device`, :py:class:`Sharding`, + representing the device(s) to which created array should be + transferred. If given, then the result is committed to the device(s). +""" +) def arange(start: DimSize, stop: DimSize | None = None, + step: DimSize | None = None, dtype: DTypeLike | None = None, + *, device: xc.Device | Sharding | None = None) -> Array: + # TODO(vfdev-5): optimize putting the array directly on the device specified + # instead of putting it on default device and then on the specific device + output = _arange(start, stop=stop, step=step, dtype=dtype) + if device is not None: + return jax.device_put(output, device=device) + return output + + +def _arange(start: DimSize, stop: DimSize | None = None, step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "arange") if not config.dynamic_shapes.value: diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 3e306a24bb23..6d7f48408b57 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -58,7 +58,8 @@ def arange( start: DimSize, stop: Optional[DimSize] = ..., step: Optional[DimSize] = ..., - dtype: Optional[DTypeLike] = ..., + dtype: Optional[DTypeLike] = ..., *, + device: _Device | _Sharding | None = ..., ) -> Array: ... def arccos(x: ArrayLike, /) -> Array: ... def arccosh(x: ArrayLike, /) -> Array: ... @@ -352,7 +353,8 @@ def expm1(x: ArrayLike, /) -> Array: ... def extract(condition: ArrayLike, arr: ArrayLike, *, size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... def eye(N: DimSize, M: Optional[DimSize] = ..., k: int | ArrayLike = ..., - dtype: Optional[DTypeLike] = ...) -> Array: ... + dtype: Optional[DTypeLike] = ..., *, + device: _Device | _Sharding | None = ...) -> Array: ... def fabs(x: ArrayLike, /) -> Array: ... finfo = _dtypes.finfo def fix(x: ArrayLike, out: None = ...) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 77a2df9151e5..40fa1f287955 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2988,6 +2988,30 @@ def testArrayCreationWithSharding(self, func, shape, dtype): out = func(**kwds, shape=shape, dtype=dtype, device=sharding) self.assertEqual(out.sharding, sharding) + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithDevice(self, func, dtype): + device = jax.devices()[-1] + out = func(dtype=dtype, device=device) + self.assertEqual(out.devices(), {device}) + + @jtu.sample_product( + func=[ + lambda dtype, device: jnp.arange(5, dtype=dtype, device=device), + lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device), + ], + dtype=default_dtypes, + ) + def testArangeEyeWithSharding(self, func, dtype): + sharding = SingleDeviceSharding(jax.devices()[-1]) + out = func(dtype=dtype, device=sharding) + self.assertEqual(out.sharding, sharding) + @jtu.sample_product( func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], shape=array_shapes, @@ -4796,11 +4820,19 @@ def testArangeJit(self): expected = jtu.with_jax_dtype_defaults(np.arange)(5) self.assertAllClose(ans, expected) - @jtu.sample_product(args=[(5,), (0, 5)]) + @jtu.sample_product( + args=[(5,), (0, 5)], + ) def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + for device in [None, jax.devices()[-1]]: + kwargs = {"device": device} + jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))() + # We have 2 statements in jaxpr: + # [a:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)], + # a:i32[5] = device_put[devices=[None] srcs=[None]] b] + num_eqs = 2 if device is not None else 1 + self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) def testIssue830(self): a = jnp.arange(4, dtype=jnp.complex64) @@ -5941,7 +5973,7 @@ def testWrappedSignaturesMatch(self): 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], 'einsum_path': ['einsum_call'], - 'eye': ['device', 'order', 'like'], + 'eye': ['order', 'like'], 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'],