Skip to content

Commit

Permalink
Added device to jnp.arange, jnp.eye and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jun 25, 2024
1 parent 543621e commit 737246a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 10 deletions.
40 changes: 37 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3555,8 +3555,25 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))


@util.implements(np.eye)
@util.implements(np.eye, extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
""")
def eye(N: DimSize, M: DimSize | None = None,
k: int | ArrayLike = 0,
dtype: DTypeLike | None = None,
*, device: xc.Device | Sharding | None = None) -> Array:
# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _eye(N, M=M, k=k, dtype=dtype)
if device is not None:
return jax.device_put(output, device=device)
return output


def _eye(N: DimSize, M: DimSize | None = None,
k: int | ArrayLike = 0,
dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "eye")
Expand All @@ -3581,7 +3598,7 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
return eye(n, dtype=dtype)


@util.implements(np.arange,lax_description= """
@util.implements(np.arange, lax_description= """
.. note::
Using ``arange`` with the ``step`` argument can lead to precision errors,
Expand All @@ -3590,8 +3607,25 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
To avoid precision errors, consider using an expression like
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""")
""", extra_params="""
device : :py:class:`Device`, :py:class:`Sharding`, optional
The (optional) :py:class:`Device`, :py:class:`Sharding`,
representing the device(s) to which created array should be
transferred. If given, then the result is committed to the device(s).
"""
)
def arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None,
*, device: xc.Device | Sharding | None = None) -> Array:
# TODO(vfdev-5): optimize putting the array directly on the device specified
# instead of putting it on default device and then on the specific device
output = _arange(start, stop=stop, step=step, dtype=dtype)
if device is not None:
return jax.device_put(output, device=device)
return output


def _arange(start: DimSize, stop: DimSize | None = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")
if not config.dynamic_shapes.value:
Expand Down
6 changes: 4 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def arange(
start: DimSize,
stop: Optional[DimSize] = ...,
step: Optional[DimSize] = ...,
dtype: Optional[DTypeLike] = ...,
dtype: Optional[DTypeLike] = ..., *,
device: _Device | _Sharding | None = ...,
) -> Array: ...
def arccos(x: ArrayLike, /) -> Array: ...
def arccosh(x: ArrayLike, /) -> Array: ...
Expand Down Expand Up @@ -352,7 +353,8 @@ def expm1(x: ArrayLike, /) -> Array: ...
def extract(condition: ArrayLike, arr: ArrayLike, *,
size: int | None = None, fill_value: ArrayLike = 0) -> Array: ...
def eye(N: DimSize, M: Optional[DimSize] = ..., k: int | ArrayLike = ...,
dtype: Optional[DTypeLike] = ...) -> Array: ...
dtype: Optional[DTypeLike] = ..., *,
device: _Device | _Sharding | None = ...) -> Array: ...
def fabs(x: ArrayLike, /) -> Array: ...
finfo = _dtypes.finfo
def fix(x: ArrayLike, out: None = ...) -> Array: ...
Expand Down
42 changes: 37 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,30 @@ def testArrayCreationWithSharding(self, func, shape, dtype):
out = func(**kwds, shape=shape, dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithDevice(self, func, dtype):
device = jax.devices()[-1]
out = func(dtype=dtype, device=device)
self.assertEqual(out.devices(), {device})

@jtu.sample_product(
func=[
lambda dtype, device: jnp.arange(5, dtype=dtype, device=device),
lambda dtype, device: jnp.eye(5, 6, dtype=dtype, device=device),
],
dtype=default_dtypes,
)
def testArangeEyeWithSharding(self, func, dtype):
sharding = SingleDeviceSharding(jax.devices()[-1])
out = func(dtype=dtype, device=sharding)
self.assertEqual(out.sharding, sharding)

@jtu.sample_product(
func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like],
shape=array_shapes,
Expand Down Expand Up @@ -4796,11 +4820,19 @@ def testArangeJit(self):
expected = jtu.with_jax_dtype_defaults(np.arange)(5)
self.assertAllClose(ans, expected)

@jtu.sample_product(args=[(5,), (0, 5)])
@jtu.sample_product(
args=[(5,), (0, 5)],
)
def testArangeJaxpr(self, args):
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))()
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)
for device in [None, jax.devices()[-1]]:
kwargs = {"device": device}
jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args, **kwargs))()
# We have 2 statements in jaxpr:
# [a:i32[5] = iota[dimension=0 dtype=int32 shape=(5,)],
# a:i32[5] = device_put[devices=[None] srcs=[None]] b]
num_eqs = 2 if device is not None else 1
self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p)

def testIssue830(self):
a = jnp.arange(4, dtype=jnp.complex64)
Expand Down Expand Up @@ -5941,7 +5973,7 @@ def testWrappedSignaturesMatch(self):
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
'eye': ['device', 'order', 'like'],
'eye': ['order', 'like'],
'hstack': ['casting'],
'identity': ['like'],
'isin': ['kind'],
Expand Down

0 comments on commit 737246a

Please sign in to comment.