Skip to content

Commit

Permalink
Merge pull request #358 from lanl/brendt/jax-ver-test
Browse files Browse the repository at this point in the history
Remove `jaxlib`/`jax` version tests in `scico/__init__.py`
  • Loading branch information
crstngc authored Oct 25, 2022
2 parents 7498b95 + 46a548a commit 9314695
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 43 deletions.
13 changes: 0 additions & 13 deletions scico/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,6 @@

import jax, jaxlib

jax_ver_req = "0.3.0"
jaxlib_ver_req = "0.3.0"
if jax.__version__ < jax_ver_req:
raise RuntimeError(
f"SCICO {__version__} requires jax>={jax_ver_req}; got {jax.__version__}; "
"please upgrade jax."
)
if jaxlib.__version__ < jaxlib_ver_req:
raise RuntimeError(
f"SCICO {__version__} requires jaxlib>={jaxlib_ver_req}; got {jaxlib.__version__}; "
"please upgrade jaxlib."
)

from jax import custom_jvp, custom_vjp, jacfwd, jvp, linearize, vjp, hessian

from . import numpy
Expand Down
31 changes: 1 addition & 30 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import importlib.util
import os
import os.path
import re
import site
import sys

Expand All @@ -17,7 +16,7 @@
module = importlib.util.module_from_spec(spec)
sys.modules["_version"] = module
spec.loader.exec_module(module)
from _version import init_variable_assign_value, package_version
from _version import package_version

name = "scico"
version = package_version()
Expand All @@ -32,34 +31,6 @@
lines = f.readlines()
install_requires = [line.strip() for line in lines]

# Check that jaxlib version requirements in __init__.py and requirements.txt match
jaxlib_ver = init_variable_assign_value("jaxlib_ver_req")
jaxlib_req_str = list(filter(lambda s: s.startswith("jaxlib"), install_requires))[0]
m = re.match("jaxlib[=<>]+([\d\.]+)", jaxlib_req_str)
if not m:
raise ValueError(f"Could not extract jaxlib version number from specification {jaxlib_req_str}")
req_jaxlib_ver = m[1]
if jaxlib_ver != req_jaxlib_ver:
raise ValueError(
f"Version requirements for jaxlib in __init__.py ({jaxlib_ver}) and "
f"requirements.txt ({req_jaxlib_ver}) do not match"
)

# Check that jax version requirements in __init__.py and requirements.txt match
jax_ver = init_variable_assign_value("jax_ver_req")
jax_req_str = list(
filter(lambda s: s.startswith("jax") and not s.startswith("jaxlib"), install_requires)
)[0]
m = re.match("jax[=<>]+([\d\.]+)", jax_req_str)
if not m:
raise ValueError(f"Could not extract jax version number from specification {jax_req_str}")
req_jax_ver = m[1]
if jax_ver != req_jax_ver:
raise ValueError(
f"Version requirements for jax in __init__.py ({jax_ver}) and "
f"requirements.txt ({req_jax_ver}) do not match"
)

python_requires = ">=3.7"
tests_require = ["pytest", "pytest-runner"]

Expand Down

0 comments on commit 9314695

Please sign in to comment.