Skip to content

Commit

Permalink
#1104 add test in jax solver for semi-explicit dae case
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 13, 2020
1 parent 9b23870 commit cc46590
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 18 deletions.
26 changes: 18 additions & 8 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jax.interpreters import partial_eval as pe
from jax import linear_util as lu
from jax.config import config
from jax.lib import pytree

config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -165,11 +166,8 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
state['D'] = D
state['y0'] = None
state['scale_y0'] = None
state['M'] = mass
state = _predict(state)
if mass is None:
state['M'] = jnp.identity(len(y0), dtype=y0.dtype)
else:
state['M'] = mass

# kappa values for difference orders, taken from Table 1 of [1]
kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
Expand Down Expand Up @@ -681,7 +679,13 @@ def block_fun(i, j, Ai, Aj):
if i == j:
return Ai
else:
return jnp.zeros(Ai.shape[0], Aj.shape[1])
return jnp.zeros(
(
Ai.shape[0] if Ai.ndim > 1 else 1,
Aj.shape[1] if Aj.ndim > 1 else 1,
),
dtype=Ai.dtype
)

blocks = [
[ block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)]
Expand Down Expand Up @@ -815,10 +819,12 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover
return carry, onp.stack(ys)


#@partial(jax.jit, static_argnums=(0, 1, 2, 3))
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
if mass is not None:
if mass is None:
mass = jnp.identity(y0.shape[0], dtype=y0.dtype)
else:
mass = block_diag(tree_flatten(mass)[0])
func = ravel_first_arg(func, unravel)
out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
Expand All @@ -844,7 +850,11 @@ def aug_dynamics(augmented_state, t, *args):
y_bar = g[-1]
ts_bar = []
t0_bar = 0.
aug_mass = (mass, mass, jnp.ones((1, 1)), tree_map(jnp.ones_like, args))

def arg_to_identity(arg):
return jnp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)

aug_mass = (mass, mass, jnp.array(1.), tree_map(arg_to_identity, args))

def scan_fun(carry, i):
y_bar, t0_bar, args_bar = carry
Expand Down
32 changes: 22 additions & 10 deletions pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import jax
from jax.experimental.ode import odeint
import jax.numpy as np
import jax.numpy as jnp
import numpy as onp


Expand Down Expand Up @@ -42,13 +42,15 @@ class JaxSolver(pybamm.BaseSolver):
"""

def __init__(self, method='RK45', rtol=1e-6, atol=1e-6, extra_options=None):
super().__init__(method, rtol, atol)
self.ode_solver = True
super().__init__(method, rtol, atol, root_method='lm')
method_options = ['RK45', 'BDF']
if method not in method_options:
raise ValueError('method must be one of {}'.format(method_options))
self.ode_solver = False
if method == 'RK45':
self.ode_solver = True
self.extra_options = extra_options or {}
self.name = "JAX solver"
self.name = "JAX solver ({})".format(method)
self._cached_solves = dict()

def get_solve(self, model, t_eval):
Expand Down Expand Up @@ -111,33 +113,43 @@ def create_solve(self, model, t_eval):

# Initial conditions
y0 = model.y0
mass = None
if self.method == 'BDF':
mass = model.mass_matrix.entries.toarray()

def rhs_odeint(y, t, inputs):
return model.rhs_eval(t, y, inputs)
def rhs_ode(y, t, inputs):
return model.rhs_eval(t, y, inputs),

def rhs_dae(y, t, inputs):
return jnp.concatenate([
model.rhs_eval(t, y, inputs),
model.algebraic_eval(t, y, inputs),
])

def solve_model_rk45(inputs):
y = odeint(
rhs_odeint,
rhs_ode,
y0,
t_eval,
inputs,
rtol=self.rtol,
atol=self.atol,
**self.extra_options
)
return np.transpose(y)
return jnp.transpose(y)

def solve_model_bdf(inputs):
y = pybamm.jax_bdf_integrate(
rhs_odeint,
rhs_dae,
y0,
t_eval,
inputs,
rtol=self.rtol,
atol=self.atol,
mass=mass,
**self.extra_options
)
return np.transpose(y)
return jnp.transpose(y)

if self.method == 'RK45':
return jax.jit(solve_model_rk45)
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,47 @@ def solve_bdf(rate):

self.assertAlmostEqual(grad_bdf, grad_num, places=3)

@unittest.skip("sensitivities do not yet work with semi-explict dae")
def test_mass_matrix_with_sensitivities(self):
# Solve
t_eval = np.linspace(0.0, 1.0, 80)

def fun(y, t, inputs):
return jax.numpy.stack([
inputs['rate'] * y[0],
y[1] - 2.0 * y[0],
])

mass = jax.numpy.array([
[1.0, 0.0],
[0.0, 0.0],
])

y0 = jax.numpy.array([1.0, 2.0])

h = 0.0001
rate = 0.1

# create a dummy "model" where we calculate the sum of the time series
@jax.jit
def solve_bdf(rate):
return jax.numpy.sum(
pybamm.jax_bdf_integrate(fun, y0, t_eval,
{'rate': rate},
mass=mass,
rtol=1e-9, atol=1e-9)
)

# check answers with finite difference
eval_plus = solve_bdf(rate + h)
eval_neg = solve_bdf(rate - h)
grad_num = (eval_plus - eval_neg) / (2 * h)

grad_solve_bdf = jax.jit(jax.grad(solve_bdf))
grad_bdf = grad_solve_bdf(rate)

self.assertAlmostEqual(grad_bdf, grad_num, places=3)

def test_solver_with_inputs(self):
# Create model
model = pybamm.BaseModel()
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,52 @@ def test_model_solver(self):
self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_semi_explicit_model(self):
# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
domain = ["negative electrode", "separator", "positive electrode"]
var = pybamm.Variable("var", domain=domain)
var2 = pybamm.Variable("var2", domain=domain)
model.rhs = {var: 0.1 * var}
model.algebraic = {var2: var2 - 2.0 * var}
model.initial_conditions = {var: 1.0, var2: 1.0}
# No need to set parameters; can use base discretisation (no spatial operators)

# create discretisation
mesh = get_mesh_for_testing()
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
disc = pybamm.Discretisation(mesh, spatial_methods)
disc.process_model(model)

# Solve
solver = pybamm.JaxSolver(
method='BDF', rtol=1e-8, atol=1e-8
)
t_eval = np.linspace(0, 1, 80)
t0 = time.perf_counter()
solution = solver.solve(model, t_eval)
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
soln = np.exp(0.1 * solution.t)
np.testing.assert_allclose(solution.y[0], soln,
rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(solution.y[-1], 2 * soln,
rtol=1e-7, atol=1e-7)

# Test time
self.assertEqual(
solution.total_time, solution.solve_time + solution.set_up_time
)
self.assertEqual(solution.termination, "final time")

t0 = time.perf_counter()
second_solution = solver.solve(model, t_eval)
t_second_solve = time.perf_counter() - t0

self.assertLess(t_second_solve, t_first_solve)
np.testing.assert_array_equal(second_solution.y, solution.y)

def test_solver_sensitivities(self):
# Create model
model = pybamm.BaseModel()
Expand Down

0 comments on commit cc46590

Please sign in to comment.