Skip to content

Commit

Permalink
#1104 improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 14, 2020
1 parent 3e31c5f commit 4e3a78b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
13 changes: 0 additions & 13 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,19 +919,6 @@ def abstractify(x):
return core.raise_to_shaped(core.get_aval(x))


def ravel_2d_pytree(pytree):
leaves, treedef = tree_flatten(pytree)
flat, unravel_list = jax.api.vjp(ravel_2d_list, *leaves)

def unravel_pytree(flat):
return tree_unflatten(treedef, unravel_list(flat))
return flat, unravel_pytree


def ravel_2d_list(*lst):
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])


def ravel_first_arg(f, unravel):
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ def test_get_solve(self):
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

# test that another method string gives error
with self.assertRaises(ValueError):
solver = pybamm.JaxSolver(method='not_real')

# Solve
solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8)
t_eval = np.linspace(0, 5, 80)
Expand Down

0 comments on commit 4e3a78b

Please sign in to comment.