Skip to content

Commit

Permalink
#684 set up stepping
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 29, 2019
1 parent a0ad975 commit ec7e934
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 24 deletions.
4 changes: 2 additions & 2 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class AlgebraicSolver(object):
def __init__(self, method="lm", tol=1e-6):
self.method = method
self.tol = tol
self.name = "Algebraic solver ({})".format(method)

@property
def method(self):
Expand Down Expand Up @@ -51,7 +52,7 @@ def solve(self, model):
equations.
"""
pybamm.logger.info("Start solving {}".format(model.name))
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

# Set up
timer = pybamm.Timer()
Expand Down Expand Up @@ -87,7 +88,6 @@ def jacobian(y):

# Assign times
solution.solve_time = timer.time() - solve_start_time
solution.total_time = timer.time() - start_time
solution.set_up_time = set_up_time

pybamm.logger.info("Finish solving {}".format(model.name))
Expand Down
7 changes: 2 additions & 5 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def solve(self, model, t_eval):
If an empty model is passed (`model.rhs = {}` and `model.algebraic={}`)
"""
pybamm.logger.info("Start solving {}".format(model.name))
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

# Make sure model isn't empty
if len(model.rhs) == 0 and len(model.algebraic) == 0:
Expand All @@ -84,7 +84,6 @@ def solve(self, model, t_eval):

# Assign times
solution.solve_time = solve_time
solution.total_time = timer.time() - start_time
solution.set_up_time = set_up_time

pybamm.logger.info("Finish solving {} ({})".format(model.name, termination))
Expand Down Expand Up @@ -129,15 +128,14 @@ def step(self, model, dt, npts=2):

# Run set up on first step
if not hasattr(self, "y0"):
start_time = timer.time()
if model.convert_to_format == "casadi" or isinstance(
self, pybamm.CasadiSolver
):
self.set_up_casadi(model)
else:
self.set_up(model)
self.t = 0.0
set_up_time = timer.time() - start_time
set_up_time = timer.time()
else:
set_up_time = None

Expand All @@ -149,7 +147,6 @@ def step(self, model, dt, npts=2):
# Assign times
solution.solve_time = solve_time
if set_up_time:
solution.total_time = timer.time() - start_time
solution.set_up_time = set_up_time

# Set self.t and self.y0 to their values at the final step
Expand Down
64 changes: 63 additions & 1 deletion pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CasadiSolver(pybamm.DaeSolver):

def __init__(
self,
method="idas",
method="ida",
rtol=1e-6,
atol=1e-6,
root_method="lm",
Expand All @@ -31,6 +31,68 @@ def __init__(
):
super().__init__(method, rtol, atol, root_method, root_tol, max_steps)
self.extra_options = extra_options
self.name = "CasADi solver ({})".format(method)

def solve(self, model, t_eval, mode="safe"):
"""
Execute the solver setup and calculate the solution of the model at
specified times.
Parameters
----------
model : :class:`pybamm.BaseModel`
The model whose solution to calculate. Must have attributes rhs and
initial_conditions
t_eval : numeric type
The times at which to compute the solution
mode : str
How to solve the model (default is "safe"):
- "fast": perform direct integration, without accounting for events. \
Recommended when simulating a drive cycle or other simulation where \
no events should be triggered.
- "safe": perform step-and-check integration, checking whether events have \
been triggered. Recommended for simulations of a full charge or discharge.
Raises
------
:class:`pybamm.ValueError`
If an invalid mode is passed.
:class:`pybamm.ModelError`
If an empty model is passed (`model.rhs = {}` and `model.algebraic={}`)
"""
if mode == "fast":
# Solve model normally by calling the solve method from parent class
return super().solve(model, t_eval)
elif mode == "safe":
# Step-and-check
# old_event_signs = np.sign(
# np.concatenate([event(0, y0) for event in self.events])
# )
timer = pybamm.Timer()
self.set_up_casadi(model)
set_up_time = timer.time()
self.t = 0.0
solution = None
for dt in np.diff(t_eval):
current_step_sol = self.step(model, dt)
if not solution:
# create solution object on first step
solution = current_step_sol
solution.set_up_time = set_up_time
else:
# append solution from the current step to solution
solution.append(current_step_sol)
return solution
else:
raise ValueError(
"""
invalid mode '{}'. Must be either 'safe', for solving with events,
or 'fast', for solving quickly without events""".format(
mode
)
)

def compute_solution(self, model, t_eval):
"""Calculate the solution of the model at specified times. In this class, we
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
raise ImportError("KLU is not installed")

super().__init__("ida", rtol, atol, root_method, root_tol, max_steps)
self.name = "IDA KLU solver"

def integrate(self, residuals, y0, t_eval, events, mass_matrix, jacobian):
"""
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/scikits_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
raise ImportError("scikits.odes is not installed")

super().__init__(method, rtol, atol, root_method, root_tol, max_steps)
self.name = "Scikits DAE solver ({})".format(method)

def integrate(
self, residuals, y0, t_eval, events=None, mass_matrix=None, jacobian=None
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, method="cvode", rtol=1e-6, atol=1e-6, linsolver="dense"):

super().__init__(method, rtol, atol)
self.linsolver = linsolver
self.name = "Scikits ODE solver ({})".format(method)

def integrate(
self, derivs, y0, t_eval, events=None, mass_matrix=None, jacobian=None
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ScipySolver(pybamm.OdeSolver):

def __init__(self, method="BDF", rtol=1e-6, atol=1e-6):
super().__init__(method, rtol, atol)
self.name = "Scipy solver ({})".format(method)

def integrate(
self, derivs, y0, t_eval, events=None, mass_matrix=None, jacobian=None
Expand Down
5 changes: 5 additions & 0 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,8 @@ def append(self, solution):
"""
self.t = np.concatenate((self.t, solution.t[1:]))
self.y = np.concatenate((self.y, solution.y[:, 1:]), axis=1)
self.solve_time += solution.solve_time

@property
def total_time(self):
return self.set_up_time + self.solve_time
10 changes: 5 additions & 5 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def test_integrate_failure(self):
# Turn warnings back on
warnings.simplefilter("default")

def test_bad_mode(self):
solver = pybamm.CasadiSolver()
with self.assertRaisesRegex(ValueError, "invalid mode"):
solver.solve(None, None, "bad mode")

def test_model_solver(self):
# Create model
model = pybamm.BaseModel()
Expand All @@ -81,11 +86,6 @@ def test_model_solver(self):
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))

# Test time
self.assertGreater(
solution.total_time, solution.solve_time + solution.set_up_time
)

def test_model_step(self):
# Create model
model = pybamm.BaseModel()
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,6 @@ def test_model_solver_ode(self):
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))

# Test time
self.assertGreater(
solution.total_time, solution.solve_time + solution.set_up_time
)

def test_model_solver_ode_events(self):
# Create model
model = pybamm.BaseModel()
Expand Down Expand Up @@ -508,11 +503,6 @@ def test_model_solver_dae(self):
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))

# Test time
self.assertGreater(
solution.total_time, solution.solve_time + solution.set_up_time
)

def test_model_solver_dae_bad_ics(self):
# Create model
model = pybamm.BaseModel()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_model_solver(self):
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))

# Test time
self.assertGreater(
self.assertEqual(
solution.total_time, solution.solve_time + solution.set_up_time
)

Expand Down

0 comments on commit ec7e934

Please sign in to comment.