Skip to content

Commit

Permalink
Added device kwargs to jnp.linspace, jnp.array
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 25, 2024
1 parent 737246a commit 8d5cc7f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 20 deletions.
48 changes: 40 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3157,8 +3157,24 @@ def _supports_buffer_protocol(obj):

deprecations.register("jax-numpy-array-none")

@util.implements(np.array, lax_description=_ARRAY_DOC)
@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
""")
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _array(object, dtype=dtype, copy=copy, order=order, ndmin=ndmin)
if device is not None:
return jax.device_put(output, device=device)
return output


def _array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
Expand Down Expand Up @@ -3693,30 +3709,46 @@ def _arange_dynamic(
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace)
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace, extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
""")
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)

# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _linspace(start, stop, num, endpoint, retstep, dtype, axis)
if device is not None:
return jax.device_put(output, device=device)
return output

@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
Expand Down
15 changes: 10 additions & 5 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def argwhere(
) -> Array: ...
around = round
def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True,
order: str | None = ..., ndmin: int = ...) -> Array: ...
order: str | None = ..., ndmin: int = ...,
*, device: _Device | _Sharding | None = ...) -> Array: ...
def array_equal(
a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ...
) -> Array: ...
Expand Down Expand Up @@ -521,22 +522,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: Literal[False] = False,
dtype: Optional[DTypeLike] = ...,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: builtins.bool, retstep: Literal[True],
dtype: Optional[DTypeLike] = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, *, retstep: Literal[True],
dtype: Optional[DTypeLike] = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: builtins.bool = False,
dtype: Optional[DTypeLike] = ...,
axis: int = 0) -> Union[Array, tuple[Array, Array]]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...

def load(*args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...
Expand Down
27 changes: 20 additions & 7 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2992,25 +2992,39 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})
output = func(dtype=dtype, device=device)
if isinstance(output, tuple):
for out in output:
self.assertEqual(out.devices(), {device})
else:
self.assertEqual(output.devices(), {device})

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)
output = func(dtype=dtype, device=sharding)
if isinstance(output, tuple):
for out in output:
self.assertEqual(out.sharding, sharding)
else:
self.assertEqual(output.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
Expand Down Expand Up @@ -5983,7 +5997,6 @@ def testWrappedSignaturesMatch(self):
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'linspace': ['device'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
Expand Down

0 comments on commit 8d5cc7f

Please sign in to comment.