diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 270a179f4e..ad20445139 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -8,6 +8,7 @@ import numpy as np import sys import itertools +import warnings class BaseSolver(object): @@ -30,6 +31,8 @@ class BaseSolver(object): specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not. Default is 0. """ def __init__( @@ -39,6 +42,7 @@ def __init__( atol=1e-6, root_method=None, root_tol=1e-6, + extrap_tol=0, max_steps="deprecated", ): self._method = method @@ -46,6 +50,7 @@ def __init__( self._atol = atol self.root_tol = root_tol self.root_method = root_method + self.extrap_tol = extrap_tol if max_steps != "deprecated": raise ValueError( "max_steps has been deprecated, and should be set using the " @@ -706,6 +711,16 @@ def solve( solution.timescale_eval = model.timescale_eval solution.length_scales_eval = model.length_scales_eval + # Check if extrapolation occurred + extrapolation = self.check_extrapolation(solution, model.events) + if extrapolation: + warnings.warn( + "While solving {} extrapolation occurred for {}".format( + model.name, extrapolation + ), + pybamm.SolverWarning, + ) + # Identify the event that caused termination termination = self.get_termination_reason(solution, model.events) @@ -861,6 +876,16 @@ def step( solution.timescale_eval = temp_timescale_eval solution.length_scales_eval = temp_length_scales_eval + # Check if extrapolation occurred + extrapolation = self.check_extrapolation(solution, model.events) + if extrapolation: + warnings.warn( + "While solving {} extrapolation occurred for {}".format( + model.name, extrapolation + ), + pybamm.SolverWarning, + ) + # Identify the event that caused termination termination = self.get_termination_reason(solution, model.events) @@ -930,6 +955,48 @@ def get_termination_reason(self, solution, events): return "the termination event '{}' occurred".format(termination_event) + def check_extrapolation(self, solution, events): + """ + Check if extrapolation occurred for any of the interpolants. Note that with the + current approach (evaluating all the events at the solution times) some + extrapolations might not be found if they only occurred for a small period of + time. + + Parameters + ---------- + solution : :class:`pybamm.Solution` + The solution object + events : dict + Dictionary of events + """ + extrap_events = {} + + for event in events: + if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION: + extrap_events[event.name] = False + + try: + y_full = solution.y.full() + except AttributeError: + y_full = solution.y + + for event in events: + if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION: + if ( + event.expression.evaluate( + solution.t, + y_full, + inputs={k: v for k, v in solution.inputs.items()}, + ) + < self.extrap_tol + ).any(): + extrap_events[event.name] = True + + # Add the event dictionaryto the solution object + solution.extrap_events = extrap_events + + return [k for k, v in extrap_events.items() if v] + def _set_up_ext_and_inputs(self, model, external_variables, inputs): "Set up external variables and input parameters" inputs = inputs or {}