diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 607e89fdb8db..682b7599b7d2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2382,7 +2382,7 @@ def _convert_elt_type_folding_rule(consts, eqn): isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended)): with warnings.catch_warnings(): - warnings.simplefilter('ignore', np.ComplexWarning) + warnings.simplefilter('ignore', util.NumpyComplexWarning) out = np.array(c).astype(eqn.params['new_dtype']) if not o.aval.weak_type: return [out], None