Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 976 casadi extrapolate warning #1315

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

## Bug fixes

- When an `Interpolant` is extrapolated an error is raised for `CasadiSolver` (and a warning is raised for the other solvers) ([#1315](https://github.com/pybamm-team/PyBaMM/pull/1315))
- Fixed `Simulation` and `model.new_copy` to fix a bug where changes to the model were overwritten ([#1278](https://github.com/pybamm-team/PyBaMM/pull/1278))

## Breaking changes
Expand Down
1 change: 1 addition & 0 deletions pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class EventType(Enum):

TERMINATION = 0
DISCONTINUITY = 1
INTERPOLANT_EXTRAPOLATION = 2


class Event:
Expand Down
31 changes: 30 additions & 1 deletion pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, values=None, chemistry=None):

# Initialise empty _processed_symbols dict (for caching)
self._processed_symbols = {}
self.parameter_events = []

def __getitem__(self, key):
return self._dict_items[key]
Expand Down Expand Up @@ -403,13 +404,24 @@ def process_model(self, unprocessed_model, inplace=True):
new_events = []
for event in unprocessed_model.events:
pybamm.logger.debug(
"Processing parameters for event'{}''".format(event.name)
"Processing parameters for event '{}''".format(event.name)
)
new_events.append(
pybamm.Event(
event.name, self.process_symbol(event.expression), event.event_type
)
)

for event in self.parameter_events:
pybamm.logger.debug(
"Processing parameters for event '{}''".format(event.name)
)
new_events.append(
pybamm.Event(
event.name, self.process_symbol(event.expression), event.event_type
)
)

model.events = new_events

# Set external variables
Expand Down Expand Up @@ -547,6 +559,23 @@ def _process_symbol(self, symbol):
function = pybamm.Interpolant(
data[:, 0], data[:, 1], *new_children, name=name
)
# Define event to catch extrapolation. In these events the sign is
# important: it should be positive inside of the range and negative
# outside of it
self.parameter_events.append(
pybamm.Event(
"Interpolant {} lower bound".format(name),
new_children[0] - min(data[:, 0]),
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
)
)
self.parameter_events.append(
pybamm.Event(
"Interpolant {} upper bound".format(name),
max(data[:, 0]) - new_children[0],
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
)
)
elif isinstance(function_name, numbers.Number):
# If the "function" is provided is actually a scalar, return a Scalar
# object instead of throwing an error.
Expand Down
76 changes: 76 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import sys
import itertools
import warnings


class BaseSolver(object):
Expand All @@ -30,6 +31,8 @@ class BaseSolver(object):
specified by 'root_method' (e.g. "lm", "hybr", ...)
root_tol : float, optional
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
"""

def __init__(
Expand All @@ -39,13 +42,15 @@ def __init__(
atol=1e-6,
root_method=None,
root_tol=1e-6,
extrap_tol=0,
max_steps="deprecated",
):
self._method = method
self._rtol = rtol
self._atol = atol
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol
if max_steps != "deprecated":
raise ValueError(
"max_steps has been deprecated, and should be set using the "
Expand Down Expand Up @@ -361,6 +366,12 @@ def report(string):
if event.event_type == pybamm.EventType.TERMINATION
]

interpolant_extrapolation_events_eval = [
process(event.expression, "event", use_jacobian=False)[1]
for event in model.events
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION
]

# discontinuity events are evaluated before the solver is called, so don't need
# to process them
discontinuity_events_eval = [
Expand All @@ -376,6 +387,9 @@ def report(string):
model.jac_algebraic_eval = jac_algebraic
model.terminate_events_eval = terminate_events_eval
model.discontinuity_events_eval = discontinuity_events_eval
model.interpolant_extrapolation_events_eval = (
interpolant_extrapolation_events_eval
)

# Calculate initial conditions
model.y0 = init_eval(inputs)
Expand Down Expand Up @@ -697,6 +711,16 @@ def solve(
solution.timescale_eval = model.timescale_eval
solution.length_scales_eval = model.length_scales_eval

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)

Expand Down Expand Up @@ -852,6 +876,16 @@ def step(
solution.timescale_eval = temp_timescale_eval
solution.length_scales_eval = temp_length_scales_eval

# Check if extrapolation occurred
extrapolation = self.check_extrapolation(solution, model.events)
if extrapolation:
warnings.warn(
"While solving {} extrapolation occurred for {}".format(
model.name, extrapolation
),
pybamm.SolverWarning,
)

# Identify the event that caused termination
termination = self.get_termination_reason(solution, model.events)

Expand Down Expand Up @@ -921,6 +955,48 @@ def get_termination_reason(self, solution, events):

return "the termination event '{}' occurred".format(termination_event)

def check_extrapolation(self, solution, events):
"""
Check if extrapolation occurred for any of the interpolants. Note that with the
current approach (evaluating all the events at the solution times) some
extrapolations might not be found if they only occurred for a small period of
time.

Parameters
----------
solution : :class:`pybamm.Solution`
The solution object
events : dict
Dictionary of events
"""
extrap_events = {}

for event in events:
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
extrap_events[event.name] = False

try:
y_full = solution.y.full()
except AttributeError:
y_full = solution.y

for event in events:
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
if (
event.expression.evaluate(
solution.t,
y_full,
inputs={k: v for k, v in solution.inputs.items()},
)
< self.extrap_tol
).any():
extrap_events[event.name] = True

# Add the event dictionaryto the solution object
solution.extrap_events = extrap_events

return [k for k, v in extrap_events.items() if v]

def _set_up_ext_and_inputs(self, model, external_variables, inputs):
"Set up external variables and input parameters"
inputs = inputs or {}
Expand Down
66 changes: 65 additions & 1 deletion pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class CasadiSolver(pybamm.BaseSolver):
The maximum global step size (in seconds) used in "safe" mode. If None
the default value corresponds to a non-dimensional time of 0.01
(i.e. ``0.01 * model.timescale_eval``).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
extra_options_setup : dict, optional
Any options to pass to the CasADi integrator when creating the integrator.
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
Expand All @@ -71,10 +73,13 @@ def __init__(
root_tol=1e-6,
max_step_decrease_count=5,
dt_max=None,
extrap_tol=0,
extra_options_setup=None,
extra_options_call=None,
):
super().__init__("problem dependent", rtol, atol, root_method, root_tol)
super().__init__(
"problem dependent", rtol, atol, root_method, root_tol, extrap_tol
)
if mode in ["safe", "fast", "safe without grid"]:
self.mode = mode
else:
Expand All @@ -88,6 +93,7 @@ def __init__(

self.extra_options_setup = extra_options_setup or {}
self.extra_options_call = extra_options_call or {}
self.extrap_tol = extrap_tol

self.name = "CasADi solver with '{}' mode".format(mode)

Expand Down Expand Up @@ -141,6 +147,33 @@ def _integrate(self, model, t_eval, inputs=None):
[event(t, y0, inputs) for event in model.terminate_events_eval]
)
)

extrap_event = [
brosaplanella marked this conversation as resolved.
Show resolved Hide resolved
event(t, y0, inputs)
for event in model.interpolant_extrapolation_events_eval
]

if extrap_event:
if (np.concatenate(extrap_event) < self.extrap_tol).any():
extrap_event_names = []
for event in model.events:
if (
event.event_type
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
and (
event.expression.evaluate(t, y0.full(), inputs=inputs,)
< self.extrap_tol
).any()
):
extrap_event_names.append(event.name[12:])

raise pybamm.SolverError(
"CasADI solver failed because the following interpolation "
"bounds were exceeded at the initial conditions: {}. "
"You may need to provide additional interpolation points "
"outside these bounds.".format(extrap_event_names)
)

pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

if self.mode == "safe without grid":
Expand Down Expand Up @@ -215,6 +248,37 @@ def _integrate(self, model, t_eval, inputs=None):
]
)
)

extrap_event = [
event(t, current_step_sol.y[:, -1], inputs=inputs)
for event in model.interpolant_extrapolation_events_eval
]

if extrap_event:
if (np.concatenate(extrap_event) < self.extrap_tol).any():
extrap_event_names = []
for event in model.events:
if (
event.event_type
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
and (
event.expression.evaluate(
t,
current_step_sol.y[:, -1].full(),
inputs=inputs,
)
< self.extrap_tol
).any()
):
extrap_event_names.append(event.name[12:])

raise pybamm.SolverError(
"CasADI solver failed because the following interpolation "
"bounds were exceeded: {}. You may need to provide "
"additional interpolation points outside these "
"bounds.".format(extrap_event_names)
)

# Exit loop if the sign of an event changes
# Locate the event time using a root finding algorithm and
# event state using interpolation. The solution is then truncated
Expand Down
7 changes: 6 additions & 1 deletion pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class IDAKLUSolver(pybamm.BaseSolver):
specified by 'root_method' (e.g. "lm", "hybr", ...)
root_tol : float, optional
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
"""

def __init__(
Expand All @@ -44,13 +46,16 @@ def __init__(
atol=1e-6,
root_method="casadi",
root_tol=1e-6,
extrap_tol=0,
max_steps="deprecated",
):

if idaklu_spec is None:
raise ImportError("KLU is not installed")

super().__init__("ida", rtol, atol, root_method, root_tol, max_steps)
super().__init__(
"ida", rtol, atol, root_method, root_tol, extrap_tol, max_steps
)
self.name = "IDA KLU solver"

pybamm.citations.register("hindmarsh2000pvode")
Expand Down
5 changes: 1 addition & 4 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,7 @@ def block_fun(i, j, Ai, Aj):
return Ai
else:
return onp.zeros(
(
Ai.shape[0] if Ai.ndim > 1 else 1,
Aj.shape[1] if Aj.ndim > 1 else 1,
),
(Ai.shape[0] if Ai.ndim > 1 else 1, Aj.shape[1] if Aj.ndim > 1 else 1,),
dtype=Ai.dtype,
)

Expand Down
14 changes: 12 additions & 2 deletions pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class JaxSolver(pybamm.BaseSolver):
The relative tolerance for the solver (default is 1e-6).
atol : float, optional
The absolute tolerance for the solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not (default is 0).
extra_options : dict, optional
Any options to pass to the solver.
Please consult `JAX documentation
Expand All @@ -46,11 +48,19 @@ class JaxSolver(pybamm.BaseSolver):
"""

def __init__(
self, method="RK45", root_method=None, rtol=1e-6, atol=1e-6, extra_options=None
self,
method="RK45",
root_method=None,
rtol=1e-6,
atol=1e-6,
extrap_tol=0,
extra_options=None,
):
# note: bdf solver itself calculates consistent initial conditions so can set
# root_method to none, allow user to override this behavior
super().__init__(method, rtol, atol, root_method=root_method)
super().__init__(
method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol
)
method_options = ["RK45", "BDF"]
if method not in method_options:
raise ValueError("method must be one of {}".format(method_options))
Expand Down
Loading