Skip to content

Commit

Permalink
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jul 25, 2024
1 parent f17d0f3 commit eeb71ef
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 61 deletions.
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def convert_element_type(operand: ArrayLike,
Similar to a C++ `static_cast`.
Args:
operand: an array or scalar value to be cast
operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type.
Returns:
Expand Down
50 changes: 32 additions & 18 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3351,7 +3351,10 @@ def _supports_buffer_protocol(obj):

deprecations.register("jax-numpy-array-none")

@util.implements(np.array, lax_description=_ARRAY_DOC)
@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array:
Expand Down Expand Up @@ -3453,7 +3456,6 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
out = np.array(object) if copy else np.asarray(object)
else:
raise TypeError(f"Unexpected input type for array: {type(object)}")

out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type, sharding=sharding)
if ndmin > ndim(out_array):
Expand Down Expand Up @@ -3544,9 +3546,13 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
return _array_copy(result) if copy else result


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
*, copy: bool | None = None) -> Array:
*, copy: bool | None = None,
device: xc.Device | Sharding | None = None) -> Array:
# For copy=False, the array API specifies that we raise a ValueError if the input supports
# the buffer protocol but a copy is required. Since array() supports the buffer protocol
# via numpy, this is only the case when the default device is not 'cpu'
Expand All @@ -3559,7 +3565,7 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
dtypes.check_user_dtype_supported(dtype, "asarray")
if dtype is not None:
dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment]
return array(a, dtype=dtype, copy=bool(copy), order=order)
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)


@util.implements(np.copy, lax_description=_ARRAY_DOC)
Expand Down Expand Up @@ -4329,36 +4335,45 @@ def _arange_dynamic(
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: bool, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = None,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace)
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ...
@util.implements(np.linspace, extra_params="""
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
""")
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)

@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device'))
def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: bool = False,
dtype: DTypeLike | None = None,
axis: int = 0) -> Array | tuple[Array, Array]:
axis: int = 0,
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
"""Implementation of linspace differentiable in start and stop args."""
dtypes.check_user_dtype_supported(dtype, "linspace")
if num < 0:
Expand Down Expand Up @@ -4406,10 +4421,9 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
out = lax.floor(out)

if retstep:
return lax.convert_element_type(out, dtype), delta
else:
return lax.convert_element_type(out, dtype)
sharding = canonicalize_device_to_sharding(device)
result = lax_internal._convert_element_type(out, dtype, sharding=sharding)
return (result, delta) if retstep else result


@util.implements(np.logspace)
Expand Down
7 changes: 2 additions & 5 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
argmax as argmax,
argmin as argmin,
argsort as argsort,
asarray as asarray,
asin as asin,
asinh as asinh,
atan as atan,
Expand Down Expand Up @@ -109,6 +110,7 @@
isnan as isnan,
less as less,
less_equal as less_equal,
linspace as linspace,
log as log,
log10 as log10,
log1p as log1p,
Expand Down Expand Up @@ -187,11 +189,6 @@
reshape as reshape,
)

from jax.experimental.array_api._creation_functions import (
asarray as asarray,
linspace as linspace,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
)
Expand Down
25 changes: 0 additions & 25 deletions jax/experimental/array_api/_creation_functions.py

This file was deleted.

15 changes: 10 additions & 5 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def array_split(
array_str = _np.array_str
def asarray(
a: Any, dtype: DTypeLike | None = ..., order: str | None = ...,
*, copy: builtins.bool | None = ...
*, copy: builtins.bool | None = ...,
device: _Device | _Sharding | None = ...,
) -> Array: ...
def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -523,22 +524,26 @@ def lexsort(keys: Sequence[ArrayLike], axis: int = ...) -> Array: ...
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: Literal[False] = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Array: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int,
endpoint: builtins.bool, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, *, retstep: Literal[True],
dtype: DTypeLike | None = ...,
axis: int = 0) -> tuple[Array, Array]: ...
axis: int = 0,
device: _Device | _Sharding | None = ...) -> tuple[Array, Array]: ...
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: builtins.bool = True, retstep: builtins.bool = False,
dtype: DTypeLike | None = ...,
axis: int = 0) -> Array | tuple[Array, Array]: ...
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...

def load(*args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...
Expand Down
19 changes: 12 additions & 7 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2998,25 +2998,31 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
def testArangeEyeLinspaceArrayWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})
output = func(dtype=dtype, device=device)
jax.tree.map(lambda x: self.assertEqual(x.devices(), {device}), output)

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, dtype=dtype, device=device),
lambda dtype, device: jnp.linspace(5, 6, 7, retstep=True, dtype=dtype, device=device),
lambda dtype, device: jnp.array([1, 2, 3, 4, 5], dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
def testArangeEyeLinspaceArrayWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)
output = func(dtype=dtype, device=sharding)
jax.tree.map(lambda x: self.assertEqual(x.sharding, sharding), output)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
Expand Down Expand Up @@ -6066,7 +6072,6 @@ def testWrappedSignaturesMatch(self):
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],
'linspace': ['device'],
'nanpercentile': ['weights'],
'nanquantile': ['weights'],
'nanstd': ['correction', 'mean'],
Expand Down

0 comments on commit eeb71ef

Please sign in to comment.