diff --git a/setup.py b/setup.py index 8f558ccf0434..9cbcd0d950e8 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def load_version_module(pkg_path): python_requires='>=3.10', install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', - 'ml_dtypes>=0.2.0', + 'ml_dtypes>=0.4.0', 'numpy>=1.24', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum',