Skip to content

Commit

Permalink
test_util: avoid overflow errors in NumPy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 7, 2024
1 parent 7004497 commit 77f030c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,9 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
to rand but scaled, converted to the appropriate dtype, and post-processed.
"""
if _dtypes.issubdtype(dtype, np.unsignedinteger):
r = lambda: np.asarray(scale * abs(rand(*_dims_of_shape(shape))), dtype)
r = lambda: np.asarray(scale * abs(rand(*_dims_of_shape(shape)))).astype(dtype)
else:
r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)).astype(dtype)
if _dtypes.issubdtype(dtype, np.complexfloating):
vals = r() + 1.0j * r()
else:
Expand Down

0 comments on commit 77f030c

Please sign in to comment.