From 12cca8157b479d5590e52e23e9a8c4b28de4499b Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Wed, 28 Jul 2021 14:09:50 -0400 Subject: [PATCH] skip jax on arm --- pybamm/__init__.py | 12 +++++-- .../operations/evaluate_python.py | 4 +-- setup.py | 4 +-- .../test_lithium_ion/test_dfn.py | 2 +- .../test_lithium_ion/test_spm.py | 8 +++-- .../test_lithium_ion/test_spme.py | 8 +++-- tests/unit/test_citations.py | 7 ++-- .../test_operations/test_evaluate_python.py | 32 +++++++++++++++---- .../unit/test_solvers/test_jax_bdf_solver.py | 9 ++++-- tests/unit/test_solvers/test_jax_solver.py | 9 ++++-- tests/unit/test_solvers/test_scipy_solver.py | 6 ++-- 11 files changed, 70 insertions(+), 31 deletions(-) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index e5f476195a..074332b9c2 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -7,7 +7,7 @@ # import sys import os -from platform import system +import platform # # Version info @@ -102,7 +102,10 @@ def version(formatted=False): EvaluatorPython, ) -if system() != "Windows": +if not ( + platform.system() == "Windows" + or (platform.system() == "Darwin" and "ARM64" in platform.version()) +): from .expression_tree.operations.evaluate_python import EvaluatorJax from .expression_tree.operations.evaluate_python import JaxCooMatrix @@ -223,7 +226,10 @@ def version(formatted=False): from .solvers.scipy_solver import ScipySolver # Jax not supported under windows -if system() != "Windows": +if not ( + platform.system() == "Windows" + or (platform.system() == "Darwin" and "ARM64" in platform.version()) +): from .solvers.jax_solver import JaxSolver from .solvers.jax_bdf_solver import jax_bdf_integrate diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index 638dc33418..c8005c5625 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -8,9 +8,9 @@ from collections import OrderedDict import numbers -from platform import system +from platform import system, version -if system() != "Windows": +if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): import jax from jax.config import config diff --git a/setup.py b/setup.py index e455058478..0a4c646c07 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import logging import subprocess from pathlib import Path -from platform import system +from platform import system, version import wheel.bdist_wheel as orig import site import shutil @@ -162,7 +162,7 @@ def compile_KLU(): ext_modules = [idaklu_ext] if compile_KLU() else [] jax_dependencies = [] -if system() != "Windows": +if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.65"] diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py index fef5bb35de..43a134aed4 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_dfn.py @@ -174,7 +174,7 @@ def test_well_posed_irreversible_plating_with_porosity(self): "lithium plating porosity change": "true", } model = pybamm.lithium_ion.DFN(options) - param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017) + param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020) modeltest = tests.StandardModelTest(model, parameter_values=param) modeltest.test_all() diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index 6439ac7bb1..bc45d5c2d2 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -5,7 +5,7 @@ import tests import numpy as np import unittest -from platform import system +from platform import system, version class TestSPM(unittest.TestCase): @@ -61,7 +61,9 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if system() != "Windows": + if not ( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) + ): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) @@ -164,7 +166,7 @@ def test_well_posed_irreversible_plating_with_porosity(self): "lithium plating porosity change": "true", } model = pybamm.lithium_ion.SPM(options) - param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017) + param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020) modeltest = tests.StandardModelTest(model, parameter_values=param) modeltest.test_all() diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index b9cadb31a4..e7f8137dc1 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -6,7 +6,7 @@ import numpy as np import unittest -from platform import system +from platform import system, version class TestSPMe(unittest.TestCase): @@ -68,7 +68,9 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if system() != "Windows": + if not ( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) + ): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) @@ -162,7 +164,7 @@ def test_well_posed_irreversible_plating_with_porosity(self): "lithium plating porosity change": "true", } model = pybamm.lithium_ion.SPMe(options) - param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Yang2017) + param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Mohtat2020) modeltest = tests.StandardModelTest(model, parameter_values=param) modeltest.test_all() diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index 22c2e92a38..095789f1bb 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -3,7 +3,7 @@ # import pybamm import unittest -from platform import system +from platform import system, version class TestCitations(unittest.TestCase): @@ -237,7 +237,10 @@ def test_solver_citations(self): pybamm.IDAKLUSolver() self.assertIn("Hindmarsh2005", citations._papers_to_cite) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_jax_citations(self): citations = pybamm.citations citations._reset() diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 3907470234..3cca9eef2d 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -8,7 +8,7 @@ import numpy as np import scipy.sparse from collections import OrderedDict -from platform import system +from platform import system, version def test_function(arg): @@ -457,7 +457,10 @@ def test_evaluator_python(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_find_symbols_jax(self): # test sparse conversion constant_symbols = OrderedDict() @@ -470,7 +473,10 @@ def test_find_symbols_jax(self): list(constant_symbols.values())[0].toarray(), A.entries.toarray() ) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_evaluator_jax(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.StateVector(slice(1, 2)) @@ -632,7 +638,10 @@ def test_evaluator_jax(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_evaluator_jax_jacobian(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -647,7 +656,10 @@ def test_evaluator_jax_jacobian(self): result_true = evaluator_jac.evaluate(t=None, y=y) np.testing.assert_allclose(result_test, result_true) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_evaluator_jax_debug(self): a = pybamm.StateVector(slice(0, 1)) expr = a ** 2 @@ -655,7 +667,10 @@ def test_evaluator_jax_debug(self): evaluator = pybamm.EvaluatorJax(expr) evaluator.debug(y=y_test) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_evaluator_jax_inputs(self): a = pybamm.InputParameter("a") expr = a ** 2 @@ -663,7 +678,10 @@ def test_evaluator_jax_inputs(self): result = evaluator.evaluate(inputs={"a": 2}) self.assertEqual(result, 4) - @unittest.skipIf(system() == "Windows", "JAX not supported on windows") + @unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", + ) def test_jax_coo_matrix(self): import jax diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 70ae7dc26e..772bc937d0 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -4,13 +4,16 @@ import sys import time import numpy as np -from platform import system +from platform import system, version -if system() != "Windows": +if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): import jax -@unittest.skipIf(system() == "Windows", "JAX not supported on windows") +@unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", +) class TestJaxBDFSolver(unittest.TestCase): def test_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 3c6f727583..74dccdaf99 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -4,13 +4,16 @@ import sys import time import numpy as np -from platform import system +from platform import system, version -if system() != "Windows": +if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): import jax -@unittest.skipIf(system() == "Windows", "JAX not supported on windows") +@unittest.skipIf( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), + "JAX not supported on windows or Mac M1", +) class TestJaxSolver(unittest.TestCase): def test_model_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index 62d51ab437..9c586372ef 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -7,13 +7,15 @@ from tests import get_mesh_for_testing import warnings import sys -from platform import system +from platform import system, version class TestScipySolver(unittest.TestCase): def test_model_solver_python_and_jax(self): - if system() != "Windows": + if not ( + system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) + ): formats = ["python", "jax"] else: formats = ["python"]