diff --git a/scico/__init__.py b/scico/__init__.py index ec2f4f9a8..6ea6df49d 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -19,19 +19,6 @@ import jax, jaxlib -jax_ver_req = "0.3.0" -jaxlib_ver_req = "0.3.0" -if jax.__version__ < jax_ver_req: - raise RuntimeError( - f"SCICO {__version__} requires jax>={jax_ver_req}; got {jax.__version__}; " - "please upgrade jax." - ) -if jaxlib.__version__ < jaxlib_ver_req: - raise RuntimeError( - f"SCICO {__version__} requires jaxlib>={jaxlib_ver_req}; got {jaxlib.__version__}; " - "please upgrade jaxlib." - ) - from jax import custom_jvp, custom_vjp, jacfwd, jvp, linearize, vjp, hessian from . import numpy diff --git a/setup.py b/setup.py index d1434ddf4..dafa23d72 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,6 @@ import importlib.util import os import os.path -import re import site import sys @@ -17,7 +16,7 @@ module = importlib.util.module_from_spec(spec) sys.modules["_version"] = module spec.loader.exec_module(module) -from _version import init_variable_assign_value, package_version +from _version import package_version name = "scico" version = package_version() @@ -32,34 +31,6 @@ lines = f.readlines() install_requires = [line.strip() for line in lines] -# Check that jaxlib version requirements in __init__.py and requirements.txt match -jaxlib_ver = init_variable_assign_value("jaxlib_ver_req") -jaxlib_req_str = list(filter(lambda s: s.startswith("jaxlib"), install_requires))[0] -m = re.match("jaxlib[=<>]+([\d\.]+)", jaxlib_req_str) -if not m: - raise ValueError(f"Could not extract jaxlib version number from specification {jaxlib_req_str}") -req_jaxlib_ver = m[1] -if jaxlib_ver != req_jaxlib_ver: - raise ValueError( - f"Version requirements for jaxlib in __init__.py ({jaxlib_ver}) and " - f"requirements.txt ({req_jaxlib_ver}) do not match" - ) - -# Check that jax version requirements in __init__.py and requirements.txt match -jax_ver = init_variable_assign_value("jax_ver_req") -jax_req_str = list( - filter(lambda s: s.startswith("jax") and not s.startswith("jaxlib"), install_requires) -)[0] -m = re.match("jax[=<>]+([\d\.]+)", jax_req_str) -if not m: - raise ValueError(f"Could not extract jax version number from specification {jax_req_str}") -req_jax_ver = m[1] -if jax_ver != req_jax_ver: - raise ValueError( - f"Version requirements for jax in __init__.py ({jax_ver}) and " - f"requirements.txt ({req_jax_ver}) do not match" - ) - python_requires = ">=3.7" tests_require = ["pytest", "pytest-runner"]