diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 189795f29bea..e7f8abba06b3 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -575,9 +575,9 @@ def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x): to rand but scaled, converted to the appropriate dtype, and post-processed. """ if _dtypes.issubdtype(dtype, np.unsignedinteger): - r = lambda: np.asarray(scale * abs(rand(*_dims_of_shape(shape))), dtype) + r = lambda: np.asarray(scale * abs(rand(*_dims_of_shape(shape)))).astype(dtype) else: - r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype) + r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)).astype(dtype) if _dtypes.issubdtype(dtype, np.complexfloating): vals = r() + 1.0j * r() else: