diff --git a/CHANGELOG.md b/CHANGELOG.md index 86dd9c9570..ca2eaafb59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ ## Bug fixes +- When an `Interpolant` is extrapolated an error is raised for `CasadiSolver` (and a warning is raised for the other solvers) ([#1315](https://github.com/pybamm-team/PyBaMM/pull/1315)) - Fixed `Simulation` and `model.new_copy` to fix a bug where changes to the model were overwritten ([#1278](https://github.com/pybamm-team/PyBaMM/pull/1278)) ## Breaking changes diff --git a/pybamm/models/event.py b/pybamm/models/event.py index 578710e952..1fa86b8793 100644 --- a/pybamm/models/event.py +++ b/pybamm/models/event.py @@ -16,6 +16,7 @@ class EventType(Enum): TERMINATION = 0 DISCONTINUITY = 1 + INTERPOLANT_EXTRAPOLATION = 2 class Event: diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index d12c29cbe7..569e46c72e 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -83,6 +83,7 @@ def __init__(self, values=None, chemistry=None): # Initialise empty _processed_symbols dict (for caching) self._processed_symbols = {} + self.parameter_events = [] def __getitem__(self, key): return self._dict_items[key] @@ -403,13 +404,24 @@ def process_model(self, unprocessed_model, inplace=True): new_events = [] for event in unprocessed_model.events: pybamm.logger.debug( - "Processing parameters for event'{}''".format(event.name) + "Processing parameters for event '{}''".format(event.name) ) new_events.append( pybamm.Event( event.name, self.process_symbol(event.expression), event.event_type ) ) + + for event in self.parameter_events: + pybamm.logger.debug( + "Processing parameters for event '{}''".format(event.name) + ) + new_events.append( + pybamm.Event( + event.name, self.process_symbol(event.expression), event.event_type + ) + ) + model.events = new_events # Set external variables @@ -547,6 +559,23 @@ def _process_symbol(self, symbol): function = pybamm.Interpolant( data[:, 0], data[:, 1], *new_children, name=name ) + # Define event to catch extrapolation. In these events the sign is + # important: it should be positive inside of the range and negative + # outside of it + self.parameter_events.append( + pybamm.Event( + "Interpolant {} lower bound".format(name), + new_children[0] - min(data[:, 0]), + pybamm.EventType.INTERPOLANT_EXTRAPOLATION, + ) + ) + self.parameter_events.append( + pybamm.Event( + "Interpolant {} upper bound".format(name), + max(data[:, 0]) - new_children[0], + pybamm.EventType.INTERPOLANT_EXTRAPOLATION, + ) + ) elif isinstance(function_name, numbers.Number): # If the "function" is provided is actually a scalar, return a Scalar # object instead of throwing an error. diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 497b295be7..ad20445139 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -8,6 +8,7 @@ import numpy as np import sys import itertools +import warnings class BaseSolver(object): @@ -30,6 +31,8 @@ class BaseSolver(object): specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not. Default is 0. """ def __init__( @@ -39,6 +42,7 @@ def __init__( atol=1e-6, root_method=None, root_tol=1e-6, + extrap_tol=0, max_steps="deprecated", ): self._method = method @@ -46,6 +50,7 @@ def __init__( self._atol = atol self.root_tol = root_tol self.root_method = root_method + self.extrap_tol = extrap_tol if max_steps != "deprecated": raise ValueError( "max_steps has been deprecated, and should be set using the " @@ -361,6 +366,12 @@ def report(string): if event.event_type == pybamm.EventType.TERMINATION ] + interpolant_extrapolation_events_eval = [ + process(event.expression, "event", use_jacobian=False)[1] + for event in model.events + if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION + ] + # discontinuity events are evaluated before the solver is called, so don't need # to process them discontinuity_events_eval = [ @@ -376,6 +387,9 @@ def report(string): model.jac_algebraic_eval = jac_algebraic model.terminate_events_eval = terminate_events_eval model.discontinuity_events_eval = discontinuity_events_eval + model.interpolant_extrapolation_events_eval = ( + interpolant_extrapolation_events_eval + ) # Calculate initial conditions model.y0 = init_eval(inputs) @@ -697,6 +711,16 @@ def solve( solution.timescale_eval = model.timescale_eval solution.length_scales_eval = model.length_scales_eval + # Check if extrapolation occurred + extrapolation = self.check_extrapolation(solution, model.events) + if extrapolation: + warnings.warn( + "While solving {} extrapolation occurred for {}".format( + model.name, extrapolation + ), + pybamm.SolverWarning, + ) + # Identify the event that caused termination termination = self.get_termination_reason(solution, model.events) @@ -852,6 +876,16 @@ def step( solution.timescale_eval = temp_timescale_eval solution.length_scales_eval = temp_length_scales_eval + # Check if extrapolation occurred + extrapolation = self.check_extrapolation(solution, model.events) + if extrapolation: + warnings.warn( + "While solving {} extrapolation occurred for {}".format( + model.name, extrapolation + ), + pybamm.SolverWarning, + ) + # Identify the event that caused termination termination = self.get_termination_reason(solution, model.events) @@ -921,6 +955,48 @@ def get_termination_reason(self, solution, events): return "the termination event '{}' occurred".format(termination_event) + def check_extrapolation(self, solution, events): + """ + Check if extrapolation occurred for any of the interpolants. Note that with the + current approach (evaluating all the events at the solution times) some + extrapolations might not be found if they only occurred for a small period of + time. + + Parameters + ---------- + solution : :class:`pybamm.Solution` + The solution object + events : dict + Dictionary of events + """ + extrap_events = {} + + for event in events: + if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION: + extrap_events[event.name] = False + + try: + y_full = solution.y.full() + except AttributeError: + y_full = solution.y + + for event in events: + if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION: + if ( + event.expression.evaluate( + solution.t, + y_full, + inputs={k: v for k, v in solution.inputs.items()}, + ) + < self.extrap_tol + ).any(): + extrap_events[event.name] = True + + # Add the event dictionaryto the solution object + solution.extrap_events = extrap_events + + return [k for k, v in extrap_events.items() if v] + def _set_up_ext_and_inputs(self, model, external_variables, inputs): "Set up external variables and input parameters" inputs = inputs or {} diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index 7d002a2c60..cb4973c03c 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -47,6 +47,8 @@ class CasadiSolver(pybamm.BaseSolver): The maximum global step size (in seconds) used in "safe" mode. If None the default value corresponds to a non-dimensional time of 0.01 (i.e. ``0.01 * model.timescale_eval``). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not. Default is 0. extra_options_setup : dict, optional Any options to pass to the CasADi integrator when creating the integrator. Please consult `CasADi documentation `_ for @@ -71,10 +73,13 @@ def __init__( root_tol=1e-6, max_step_decrease_count=5, dt_max=None, + extrap_tol=0, extra_options_setup=None, extra_options_call=None, ): - super().__init__("problem dependent", rtol, atol, root_method, root_tol) + super().__init__( + "problem dependent", rtol, atol, root_method, root_tol, extrap_tol + ) if mode in ["safe", "fast", "safe without grid"]: self.mode = mode else: @@ -88,6 +93,7 @@ def __init__( self.extra_options_setup = extra_options_setup or {} self.extra_options_call = extra_options_call or {} + self.extrap_tol = extrap_tol self.name = "CasADi solver with '{}' mode".format(mode) @@ -141,6 +147,33 @@ def _integrate(self, model, t_eval, inputs=None): [event(t, y0, inputs) for event in model.terminate_events_eval] ) ) + + extrap_event = [ + event(t, y0, inputs) + for event in model.interpolant_extrapolation_events_eval + ] + + if extrap_event: + if (np.concatenate(extrap_event) < self.extrap_tol).any(): + extrap_event_names = [] + for event in model.events: + if ( + event.event_type + == pybamm.EventType.INTERPOLANT_EXTRAPOLATION + and ( + event.expression.evaluate(t, y0.full(), inputs=inputs,) + < self.extrap_tol + ).any() + ): + extrap_event_names.append(event.name[12:]) + + raise pybamm.SolverError( + "CasADI solver failed because the following interpolation " + "bounds were exceeded at the initial conditions: {}. " + "You may need to provide additional interpolation points " + "outside these bounds.".format(extrap_event_names) + ) + pybamm.logger.info("Start solving {} with {}".format(model.name, self.name)) if self.mode == "safe without grid": @@ -215,6 +248,37 @@ def _integrate(self, model, t_eval, inputs=None): ] ) ) + + extrap_event = [ + event(t, current_step_sol.y[:, -1], inputs=inputs) + for event in model.interpolant_extrapolation_events_eval + ] + + if extrap_event: + if (np.concatenate(extrap_event) < self.extrap_tol).any(): + extrap_event_names = [] + for event in model.events: + if ( + event.event_type + == pybamm.EventType.INTERPOLANT_EXTRAPOLATION + and ( + event.expression.evaluate( + t, + current_step_sol.y[:, -1].full(), + inputs=inputs, + ) + < self.extrap_tol + ).any() + ): + extrap_event_names.append(event.name[12:]) + + raise pybamm.SolverError( + "CasADI solver failed because the following interpolation " + "bounds were exceeded: {}. You may need to provide " + "additional interpolation points outside these " + "bounds.".format(extrap_event_names) + ) + # Exit loop if the sign of an event changes # Locate the event time using a root finding algorithm and # event state using interpolation. The solution is then truncated diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 34476d7672..da4a5144f7 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -36,6 +36,8 @@ class IDAKLUSolver(pybamm.BaseSolver): specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not (default is 0). """ def __init__( @@ -44,13 +46,16 @@ def __init__( atol=1e-6, root_method="casadi", root_tol=1e-6, + extrap_tol=0, max_steps="deprecated", ): if idaklu_spec is None: raise ImportError("KLU is not installed") - super().__init__("ida", rtol, atol, root_method, root_tol, max_steps) + super().__init__( + "ida", rtol, atol, root_method, root_tol, extrap_tol, max_steps + ) self.name = "IDA KLU solver" pybamm.citations.register("hindmarsh2000pvode") diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index f4463f1320..baff2ad883 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -715,10 +715,7 @@ def block_fun(i, j, Ai, Aj): return Ai else: return onp.zeros( - ( - Ai.shape[0] if Ai.ndim > 1 else 1, - Aj.shape[1] if Aj.ndim > 1 else 1, - ), + (Ai.shape[0] if Ai.ndim > 1 else 1, Aj.shape[1] if Aj.ndim > 1 else 1,), dtype=Ai.dtype, ) diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index d81e973c53..889b87c606 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -38,6 +38,8 @@ class JaxSolver(pybamm.BaseSolver): The relative tolerance for the solver (default is 1e-6). atol : float, optional The absolute tolerance for the solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not (default is 0). extra_options : dict, optional Any options to pass to the solver. Please consult `JAX documentation @@ -46,11 +48,19 @@ class JaxSolver(pybamm.BaseSolver): """ def __init__( - self, method="RK45", root_method=None, rtol=1e-6, atol=1e-6, extra_options=None + self, + method="RK45", + root_method=None, + rtol=1e-6, + atol=1e-6, + extrap_tol=0, + extra_options=None, ): # note: bdf solver itself calculates consistent initial conditions so can set # root_method to none, allow user to override this behavior - super().__init__(method, rtol, atol, root_method=root_method) + super().__init__( + method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol + ) method_options = ["RK45", "BDF"] if method not in method_options: raise ValueError("method must be one of {}".format(method_options)) diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index df2272e14c..303d041b68 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -36,6 +36,8 @@ class ScikitsDaeSolver(pybamm.BaseSolver): specified by 'root_method' (e.g. "lm", "hybr", ...) root_tol : float, optional The tolerance for the initial-condition solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not (default is 0). extra_options : dict, optional Any options to pass to the solver. Please consult `scikits.odes documentation @@ -52,13 +54,16 @@ def __init__( atol=1e-6, root_method="casadi", root_tol=1e-6, + extrap_tol=0, extra_options=None, max_steps="deprecated", ): if scikits_odes_spec is None: raise ImportError("scikits.odes is not installed") - super().__init__(method, rtol, atol, root_method, root_tol, max_steps) + super().__init__( + method, rtol, atol, root_method, root_tol, extrap_tol, max_steps + ) self.name = "Scikits DAE solver ({})".format(method) self.extra_options = extra_options or {} diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index 0c4d6b9cd4..314b5d1ae5 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -31,6 +31,8 @@ class ScikitsOdeSolver(pybamm.BaseSolver): The relative tolerance for the solver (default is 1e-6). atol : float, optional The absolute tolerance for the solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not (default is 0). extra_options : dict, optional Any options to pass to the solver. Please consult `scikits.odes documentation @@ -46,13 +48,14 @@ def __init__( method="cvode", rtol=1e-6, atol=1e-6, + extrap_tol=0, linsolver="deprecated", extra_options=None, ): if scikits_odes_spec is None: raise ImportError("scikits.odes is not installed") - super().__init__(method, rtol, atol) + super().__init__(method, rtol, atol, extrap_tol=extrap_tol) self.extra_options = extra_options or {} if linsolver != "deprecated": raise ValueError( diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py index 41eb69838a..00c33da7b1 100644 --- a/pybamm/solvers/scipy_solver.py +++ b/pybamm/solvers/scipy_solver.py @@ -19,14 +19,18 @@ class ScipySolver(pybamm.BaseSolver): The relative tolerance for the solver (default is 1e-6). atol : float, optional The absolute tolerance for the solver (default is 1e-6). + extrap_tol : float, optional + The tolerance to assert whether extrapolation occurs or not (default is 0). extra_options : dict, optional Any options to pass to the solver. Please consult `SciPy documentation `_ for details. """ - def __init__(self, method="BDF", rtol=1e-6, atol=1e-6, extra_options=None): - super().__init__(method, rtol, atol) + def __init__( + self, method="BDF", rtol=1e-6, atol=1e-6, extrap_tol=0, extra_options=None + ): + super().__init__(method, rtol, atol, extrap_tol=extrap_tol) self.ode_solver = True self.extra_options = extra_options or {} self.name = "Scipy solver ({})".format(method) diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index d7055f14d3..ecdd23212d 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -287,6 +287,31 @@ def test_timescale_input_fail(self): with self.assertRaisesRegex(pybamm.SolverError, "The model timescale"): sol = solver.step(old_solution=sol, model=model, dt=1.0, inputs={"a": 20}) + def test_extrapolation_warnings(self): + # Make sure the extrapolation warnings work + model = pybamm.BaseModel() + v = pybamm.Variable("v") + model.rhs = {v: -1} + model.initial_conditions = {v: 1} + model.events.append( + pybamm.Event( + "Triggered event", v - 0.5, pybamm.EventType.INTERPOLANT_EXTRAPOLATION, + ) + ) + model.events.append( + pybamm.Event( + "Ignored event", v + 10, pybamm.EventType.INTERPOLANT_EXTRAPOLATION, + ) + ) + solver = pybamm.ScipySolver() + solver.set_up(model) + + with self.assertWarns(pybamm.SolverWarning): + solver.step(old_solution=None, model=model, dt=1.0) + + with self.assertWarns(pybamm.SolverWarning): + solver.solve(model, t_eval=[0, 1]) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_solvers/test_casadi_solver.py b/tests/unit/test_solvers/test_casadi_solver.py index 5548ac3f8e..6d097bb3a3 100644 --- a/tests/unit/test_solvers/test_casadi_solver.py +++ b/tests/unit/test_solvers/test_casadi_solver.py @@ -432,6 +432,41 @@ def test_dae_solver_algebraic_model(self): ): solver.solve(model, t_eval) + def test_interpolant_extrapolate(self): + model = pybamm.lithium_ion.DFN() + param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Chen2020) + experiment = pybamm.Experiment( + ["Charge at 1C until 4.6 V"], period="10 seconds" + ) + + param["Upper voltage cut-off [V]"] = 4.8 + + sim = pybamm.Simulation( + model, + parameter_values=param, + experiment=experiment, + solver=pybamm.CasadiSolver( + mode="safe", + dt_max=0.001, + extrap_tol=1e-3, + extra_options_setup={"max_num_steps": 500}, + ), + ) + with self.assertRaisesRegex(pybamm.SolverError, "interpolation bounds"): + sim.solve() + + ci = param["Initial concentration in positive electrode [mol.m-3]"] + param["Initial concentration in positive electrode [mol.m-3]"] = 0.8 * ci + + sim = pybamm.Simulation( + model, + parameter_values=param, + experiment=experiment, + solver=pybamm.CasadiSolver(mode="safe", dt_max=0.05), + ) + with self.assertRaisesRegex(pybamm.SolverError, "interpolation bounds"): + sim.solve() + class TestCasadiSolverSensitivity(unittest.TestCase): def test_solve_with_symbolic_input(self):