From 8d5cc7f05d30337f6efac108df6098c7fca540ec Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 25 Jun 2024 10:38:37 +0200
Subject: [PATCH] Added device kwargs to jnp.linspace, jnp.array

---
 jax/_src/numpy/lax_numpy.py | 48 ++++++++++++++++++++++++++++++-------
 jax/numpy/__init__.pyi      | 15 ++++++++----
 tests/lax_numpy_test.py     | 27 +++++++++++++++------
 3 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index ffcfc617c1f0..66c04ba2ed81 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3157,8 +3157,24 @@ def _supports_buffer_protocol(obj):
 
 deprecations.register("jax-numpy-array-none")
 
-@util.implements(np.array, lax_description=_ARRAY_DOC)
+@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
+device : :py:class:`Device`, :py:class:`Sharding`, optional
+  The (optional) :py:class:`Device` or :py:class:`Sharding`
+  representing the device(s) to which the created array should be
+  transferred. If given, then the result is committed to the device(s).
+""")
 def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
+          order: str | None = "K", ndmin: int = 0,
+          *, device: xc.Device | Sharding | None = None) -> Array:
+  # TODO(vfdev-5): optimize by putting the array directly on the specified device
+  # instead of putting it on the default device and then transferring it
+  output = _array(object, dtype=dtype, copy=copy, order=order, ndmin=ndmin)
+  if device is not None:
+    return jax.device_put(output, device=device)
+  return output
+
+
+def _array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
           order: str | None = "K", ndmin: int = 0) -> Array:
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
@@ -3693,30 +3709,46 @@ def _arange_dynamic(
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: Literal[False] = False,
              dtype: DTypeLike | None = None,
-             axis: int = 0) -> Array: ...
+             axis: int = 0,
+             *, device: xc.Device | Sharding | None = None) -> Array: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int,
              endpoint: bool, retstep: Literal[True],
              dtype: DTypeLike | None = None,
-             axis: int = 0) -> tuple[Array, Array]: ...
+             axis: int = 0,
+             *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, *, retstep: Literal[True],
              dtype: DTypeLike | None = None,
-             axis: int = 0) -> tuple[Array, Array]: ...
+             axis: int = 0,
+             device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: bool = False,
              dtype: DTypeLike | None = None,
-             axis: int = 0) -> Array | tuple[Array, Array]: ...
-@util.implements(np.linspace)
+             axis: int = 0,
+             *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
+@util.implements(np.linspace, extra_params="""
+device : :py:class:`Device`, :py:class:`Sharding`, optional
+  The (optional) :py:class:`Device` or :py:class:`Sharding`
+  representing the device(s) to which the created array should be
+  transferred. If given, then the result is committed to the device(s).
+""")
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: bool = False,
              dtype: DTypeLike | None = None,
-             axis: int = 0) -> Array | tuple[Array, Array]:
+             axis: int = 0,
+             *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
   num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
   axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
-  return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
+
+  # TODO(vfdev-5): optimize by putting the array directly on the specified device
+  # instead of putting it on the default device and then transferring it
+  output = _linspace(start, stop, num, endpoint, retstep, dtype, axis)
+  if device is not None:
+    return jax.device_put(output, device=device)
+  return output
 
 @partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
 def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 6d7f48408b57..04b5f2af20a9 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -98,7 +98,8 @@ def argwhere(
 ) -> Array: ...
 around = round
 def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True,
-          order: str | None = ..., ndmin: int = ...) -> Array: ...
+          order: str | None = ..., ndmin: int = ...,
+          *, device: _Device | _Sharding | None = ...) -> Array: ...
 def array_equal(
     a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ...
 ) -> Array: ...
@@ -521,22 +522,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: builtins.bool = True, retstep: Literal[False] = False,
              dtype: Optional[DTypeLike] = ...,
-             axis: int = 0) -> Array: ...
+             axis: int = 0,
+             *, device: _Device | _Sharding | None = ...) -> Array: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int,
              endpoint: builtins.bool, retstep: Literal[True],
              dtype: Optional[DTypeLike] = ...,
-             axis: int = 0) -> tuple[Array, Array]: ...
+             axis: int = 0,
+             *, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: builtins.bool = True, *, retstep: Literal[True],
             dtype: Optional[DTypeLike] = ...,
-             axis: int = 0) -> tuple[Array, Array]: ...
+             axis: int = 0,
+             device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
 @overload
 def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: builtins.bool = True, retstep: builtins.bool = False,
             dtype: Optional[DTypeLike] = ...,
-             axis: int = 0) -> Union[Array, tuple[Array, Array]]: ...
+             axis: int = 0,
+             *, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...
 def load(*args: Any, **kwargs: Any) -> Array: ...
 def log(x: ArrayLike, /) -> Array: ...
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 40fa1f287955..452e166fbc33 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -2992,25 +2992,39 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
     func=[
       lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
       lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
+      lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
+      lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
+      lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
     ],
     dtype=default_dtypes,
   )
-  def testArangeEyeWithDevice(self, func, dtype):
+  def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
     device = jax.devices()[-1]
-    out = func(dtype=dtype, device=device)
-    self.assertEqual(out.devices(), {device})
+    output = func(dtype=dtype, device=device)
+    if isinstance(output, tuple):
+      for out in output:
+        self.assertEqual(out.devices(), {device})
+    else:
+      self.assertEqual(output.devices(), {device})
 
   @jtu.sample_product(
     func=[
       lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
       lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
+      lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
+      lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
+      lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
     ],
     dtype=default_dtypes,
   )
-  def testArangeEyeWithSharding(self, func, dtype):
+  def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
     sharding = SingleDeviceSharding(jax.devices()[-1])
-    out = func(dtype=dtype, device=sharding)
-    self.assertEqual(out.sharding, sharding)
+    output = func(dtype=dtype, device=sharding)
+    if isinstance(output, tuple):
+      for out in output:
+        self.assertEqual(out.sharding, sharding)
+    else:
+      self.assertEqual(output.sharding, sharding)
 
   @jtu.sample_product(
     func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
@@ -5983,7 +5997,6 @@ def testWrappedSignaturesMatch(self):
       'histogram': ['normed'],
       'histogram2d': ['normed'],
       'histogramdd': ['normed'],
-      'linspace': ['device'],
       'nanpercentile': ['weights'],
       'nanquantile': ['weights'],
       'nanstd': ['correction', 'mean'],
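
A minimal usage sketch of the new `device` keyword (illustrative only, not part
of the diff above). It assumes at least one JAX device is available and mirrors
what the added tests exercise: the result of jnp.array / jnp.linspace is
committed to the given Device or Sharding, and with retstep=True both returned
arrays are committed.

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    dev = jax.devices()[-1]

    # Commit the created array to an explicit device.
    x = jnp.array([1, 2, 3], device=dev)
    assert x.devices() == {dev}

    # A Sharding is accepted as well.
    y = jnp.linspace(0.0, 1.0, 5, device=SingleDeviceSharding(dev))
    assert y.devices() == {dev}

    # With retstep=True, both the samples and the step land on the requested device.
    samples, step = jnp.linspace(0.0, 1.0, 5, retstep=True, device=dev)
    assert samples.devices() == step.devices() == {dev}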