From de4846ed260beb9b66819da0d1476b42d39315a4 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Mon, 13 Jul 2020 11:41:33 +0100 Subject: [PATCH] #1104 tidy up --- pybamm/solvers/jax_bdf_solver.py | 5 ++--- tests/unit/test_solvers/test_jax_bdf_solver.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 9c7885386a..4ebd951d81 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -13,7 +13,6 @@ from jax.interpreters import partial_eval as pe from jax import linear_util as lu from jax.config import config -from jax.lib import pytree config.update("jax_enable_x64", True) @@ -688,8 +687,8 @@ def block_fun(i, j, Ai, Aj): ) blocks = [ - [ block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] - for i, Ai in enumerate(lst) + [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] + for i, Ai in enumerate(lst) ] return jnp.block(blocks) diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index e52f2843f5..f99b2efab2 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -92,7 +92,6 @@ def fun(y, t): np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-7, atol=1e-7) - def test_solver_sensitivities(self): # Create model model = pybamm.BaseModel()