Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax windows restrictions #3955

Merged
merged 10 commits into from
Apr 2, 2024
69 changes: 23 additions & 46 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,21 @@ def run_coverage(session):
"""Run the coverage tests and generate an XML report."""
set_environment_variables(PYBAMM_ENV, session=session)
session.install("coverage", silent=False)
if sys.platform != "win32":
session.install("-e", ".[all,dev,jax]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit")


@nox.session(name="integration")
def run_integration(session):
"""Run the integration tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,dev,jax]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("python", "run-tests.py", "--integration")


Expand All @@ -96,13 +90,10 @@ def run_doctests(session):
def run_unit(session):
"""Run the unit tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,dev,jax]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("python", "run-tests.py", "--unit")


Expand Down Expand Up @@ -138,50 +129,36 @@ def set_dev(session):
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
session.run(python, "-m", "pip", "install", "setuptools", external=True)
if sys.platform == "linux":
if sys.version_info < (3, 9):
session.run(
python,
"-m",
"pip",
"install",
"-e",
".[all,dev,jax]",
".[all,dev]",
external=True,
)
else:
if sys.version_info < (3, 9):
session.run(
python,
"-m",
"pip",
"install",
"-e",
".[all,dev]",
external=True,
)
else:
session.run(
python,
"-m",
"pip",
"install",
"-e",
".[all,dev,jax]",
external=True,
)
session.run(
python,
"-m",
"pip",
"install",
"-e",
".[all,dev,jax]",
external=True,
)


@nox.session(name="tests")
def run_tests(session):
"""Run the unit tests and integration tests sequentially."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,dev,jax]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
if sys.version_info < (3, 9):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("python", "run-tests.py", "--all")


Expand Down
8 changes: 3 additions & 5 deletions pybamm/meshes/one_dimensional_submeshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ class Exponential1DSubMesh(SubMesh1D):

.. math::
x_{k} = (b-a) +
\\frac{\mathrm{e}^{\\alpha k / N} - 1}{\mathrm{e}^{\\alpha} - 1} + a,
\\frac{\\mathrm{e}^{\\alpha k / N} - 1}{\\mathrm{e}^{\\alpha} - 1} + a,

for k = 1, ..., N, where N is the number of nodes.

Is side is "right", the gridpoints are given by

.. math::
x_{k} = (b-a) +
\\frac{\mathrm{e}^{-\\alpha k / N} - 1}{\mathrm{e}^{-\\alpha} - 1} + a,
\\frac{\\mathrm{e}^{-\\alpha k / N} - 1}{\\mathrm{e}^{-\\alpha} - 1} + a,

for k = 1, ..., N.

Expand All @@ -149,7 +149,7 @@ class Exponential1DSubMesh(SubMesh1D):

.. math::
x_{k} = (b/2-a) +
\\frac{\mathrm{e}^{\\alpha k / N} - 1}{\mathrm{e}^{\\alpha} - 1} + a,
\\frac{\\mathrm{e}^{\\alpha k / N} - 1}{\\mathrm{e}^{\\alpha} - 1} + a,

for k = 1, ..., N. The grid spacing is then reflected to contruct the grid
on the full interval [a,b].
Expand Down Expand Up @@ -289,14 +289,12 @@ class UserSupplied1DSubMesh(SubMesh1D):
"""

def __init__(self, lims, npts, edges=None):
# raise error if no edges passed
if edges is None:
raise pybamm.GeometryError("User mesh requires parameter 'edges'")

spatial_var, spatial_lims, tabs = self.read_lims(lims)
npts = npts[spatial_var.name]

# check that npts + 1 equals number of user-supplied edges
if (npts + 1) != len(edges):
raise pybamm.GeometryError(
f"""User-suppled edges has should have length (npts + 1) but has length
Expand Down
Loading