Skip to content

Commit

Permalink
Use equality for extended dtypes in isclose and add test for `jax.r…
Browse files Browse the repository at this point in the history
…andom.key`.
  • Loading branch information
tillahoffmann committed Feb 23, 2024
1 parent 6348a54 commit 062dc8d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,11 @@ def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
equal_nan: bool = False) -> Array:
a, b = util.promote_args("isclose", a, b)
dtype = _dtype(a)
if dtypes.issubdtype(dtype, dtypes.extended):
return lax.eq(a, b)

a, b = util.promote_args_inexact("isclose", a, b)
dtype = _dtype(a)
if issubdtype(dtype, complexfloating):
Expand Down
2 changes: 2 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3533,6 +3533,8 @@ def testIsClose(self):
self.assertTrue(jnp.all(jnp.equal(result_np, result_jit)))

self.assertEqual(np.isclose(6, 10, rtol=0.5), jnp.isclose(6, 10, rtol=0.5))
key = jax.random.key(0)
self.assertTrue(jnp.isclose(key, key))

@jtu.sample_product(
x=[1, [1], [1, 1 + 1E-4], [1, np.nan]],
Expand Down

0 comments on commit 062dc8d

Please sign in to comment.