Skip to content

Commit

Permalink
skip jax on arm
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jul 28, 2021
1 parent 211bd90 commit 12cca81
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 31 deletions.
12 changes: 9 additions & 3 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#
import sys
import os
from platform import system
import platform

#
# Version info
Expand Down Expand Up @@ -102,7 +102,10 @@ def version(formatted=False):
EvaluatorPython,
)

if system() != "Windows":
if not (
platform.system() == "Windows"
or (platform.system() == "Darwin" and "ARM64" in platform.version())
):
from .expression_tree.operations.evaluate_python import EvaluatorJax
from .expression_tree.operations.evaluate_python import JaxCooMatrix

Expand Down Expand Up @@ -223,7 +226,10 @@ def version(formatted=False):
from .solvers.scipy_solver import ScipySolver

# Jax not supported under windows
if system() != "Windows":
if not (
platform.system() == "Windows"
or (platform.system() == "Darwin" and "ARM64" in platform.version())
):
from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate

Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from collections import OrderedDict

import numbers
from platform import system
from platform import system, version

if system() != "Windows":
if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
import jax

from jax.config import config
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import subprocess
from pathlib import Path
from platform import system
from platform import system, version
import wheel.bdist_wheel as orig
import site
import shutil
Expand Down Expand Up @@ -162,7 +162,7 @@ def compile_KLU():
ext_modules = [idaklu_ext] if compile_KLU() else []

jax_dependencies = []
if system() != "Windows":
if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.65"]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_well_posed_irreversible_plating_with_porosity(self):
"lithium plating porosity change": "true",
}
model = pybamm.lithium_ion.DFN(options)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020)
modeltest = tests.StandardModelTest(model, parameter_values=param)
modeltest.test_all()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tests
import numpy as np
import unittest
from platform import system
from platform import system, version


class TestSPM(unittest.TestCase):
Expand Down Expand Up @@ -61,7 +61,9 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)

if system() != "Windows":
if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down Expand Up @@ -164,7 +166,7 @@ def test_well_posed_irreversible_plating_with_porosity(self):
"lithium plating porosity change": "true",
}
model = pybamm.lithium_ion.SPM(options)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020)
modeltest = tests.StandardModelTest(model, parameter_values=param)
modeltest.test_all()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import unittest
from platform import system
from platform import system, version


class TestSPMe(unittest.TestCase):
Expand Down Expand Up @@ -68,7 +68,9 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)

if system() != "Windows":
if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down Expand Up @@ -162,7 +164,7 @@ def test_well_posed_irreversible_plating_with_porosity(self):
"lithium plating porosity change": "true",
}
model = pybamm.lithium_ion.SPMe(options)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017)
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020)
modeltest = tests.StandardModelTest(model, parameter_values=param)
modeltest.test_all()

Expand Down
7 changes: 5 additions & 2 deletions tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm
import unittest
from platform import system
from platform import system, version


class TestCitations(unittest.TestCase):
Expand Down Expand Up @@ -237,7 +237,10 @@ def test_solver_citations(self):
pybamm.IDAKLUSolver()
self.assertIn("Hindmarsh2005", citations._papers_to_cite)

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import scipy.sparse
from collections import OrderedDict
from platform import system
from platform import system, version


def test_function(arg):
Expand Down Expand Up @@ -457,7 +457,10 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -470,7 +473,10 @@ def test_find_symbols_jax(self):
list(constant_symbols.values())[0].toarray(), A.entries.toarray()
)

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -632,7 +638,10 @@ def test_evaluator_jax(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -647,23 +656,32 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac.evaluate(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a ** 2
y_test = np.array([[2.0], [3.0]])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a ** 2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator.evaluate(inputs={"a": 2})
self.assertEqual(result, 4)

@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
def test_jax_coo_matrix(self):
import jax

Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import sys
import time
import numpy as np
from platform import system
from platform import system, version

if system() != "Windows":
if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
import jax


@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
class TestJaxBDFSolver(unittest.TestCase):
def test_solver(self):
# Create model
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import sys
import time
import numpy as np
from platform import system
from platform import system, version

if system() != "Windows":
if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
import jax


@unittest.skipIf(system() == "Windows", "JAX not supported on windows")
@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
class TestJaxSolver(unittest.TestCase):
def test_model_solver(self):
# Create model
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from tests import get_mesh_for_testing
import warnings
import sys
from platform import system
from platform import system, version


class TestScipySolver(unittest.TestCase):
def test_model_solver_python_and_jax(self):

if system() != "Windows":
if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
formats = ["python", "jax"]
else:
formats = ["python"]
Expand Down

0 comments on commit 12cca81

Please sign in to comment.