Skip to content

Commit

Permalink
pybamm-team#976 added extrapolation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
brosaplanella committed Jan 3, 2021
1 parent 31b8a05 commit 07f7916
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import sys
import itertools
import warnings


class BaseSolver(object):
Expand All @@ -30,6 +31,8 @@ class BaseSolver(object):
specified by 'root_method' (e.g. "lm", "hybr", ...)
root_tol : float, optional
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
"""

def __init__(
Expand All @@ -39,13 +42,15 @@ def __init__(
atol=1e-6,
root_method=None,
root_tol=1e-6,
extrap_tol=0,
max_steps="deprecated",
):
self._method = method
self._rtol = rtol
self._atol = atol
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol
if max_steps != "deprecated":
raise ValueError(
"max_steps has been deprecated, and should be set using the "
Expand Down Expand Up @@ -706,6 +711,16 @@ def solve(
solution.timescale_eval = model.timescale_eval
solution.length_scales_eval = model.length_scales_eval

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)

Expand Down Expand Up @@ -861,6 +876,16 @@ def step(
solution.timescale_eval = temp_timescale_eval
solution.length_scales_eval = temp_length_scales_eval

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)

Expand Down Expand Up @@ -930,6 +955,48 @@ def get_termination_reason(self, solution, events):

return "the termination event '{}' occurred".format(termination_event)

def check_extrapolation(self, solution, events):
"""
Check if extrapolation occurred for any of the interpolants. Note that with the
current approach (evaluating all the events at the solution times) some
extrapolations might not be found if they only occurred for a small period of
time.
Parameters
----------
solution : :class:`pybamm.Solution`
The solution object
events : dict
Dictionary of events
"""
extrap_events = {}

for event in events:
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
extrap_events[event.name] = False

try:
y_full = solution.y.full()
except AttributeError:
y_full = solution.y

for event in events:
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
if (
event.expression.evaluate(
solution.t,
y_full,
inputs={k: v for k, v in solution.inputs.items()},
)
< self.extrap_tol
).any():
extrap_events[event.name] = True

# Add the event dictionaryto the solution object
solution.extrap_events = extrap_events

return [k for k, v in extrap_events.items() if v]

def _set_up_ext_and_inputs(self, model, external_variables, inputs):
"Set up external variables and input parameters"
inputs = inputs or {}
Expand Down

0 comments on commit 07f7916

Please sign in to comment.