Skip to content

Commit

Permalink
pybamm-team#976 added error for extrapolation with CasADI
Browse files Browse the repository at this point in the history
  • Loading branch information
brosaplanella committed Dec 30, 2020
1 parent c5588d5 commit bbcfc22
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 3 deletions.
1 change: 1 addition & 0 deletions pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class EventType(Enum):

TERMINATION = 0
DISCONTINUITY = 1
INTERPOLANT_EXTRAPOLATION = 2


class Event:
Expand Down
29 changes: 29 additions & 0 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, values=None, chemistry=None):

# Initialise empty _processed_symbols dict (for caching)
self._processed_symbols = {}
self.parameter_events = []

def __getitem__(self, key):
return self._dict_items[key]
Expand Down Expand Up @@ -410,6 +411,17 @@ def process_model(self, unprocessed_model, inplace=True):
event.name, self.process_symbol(event.expression), event.event_type
)
)

for event in self.parameter_events:
pybamm.logger.debug(
"Processing parameters for event'{}''".format(event.name)
)
new_events.append(
pybamm.Event(
event.name, self.process_symbol(event.expression), event.event_type
)
)

model.events = new_events

# Set external variables
Expand Down Expand Up @@ -545,6 +557,23 @@ def _process_symbol(self, symbol):
# to create an Interpolant
name, data = function_name
function = pybamm.Interpolant(data, *new_children, name=name)
# Define event to catch extrapolation. In these events the sign is
# important: it should be positive inside of the range and negative
# outside of it
self.parameter_events.append(
pybamm.Event(
"Interpolant {} lower bound".format(name),
new_children[0] - min(function.x),
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
)
)
self.parameter_events.append(
pybamm.Event(
"Interpolant {} upper bound".format(name),
max(function.x) - new_children[0],
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
)
)
elif isinstance(function_name, numbers.Number):
# If the "function" is provided is actually a scalar, return a Scalar
# object instead of throwing an error.
Expand Down
7 changes: 7 additions & 0 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def report(string):
if event.event_type == pybamm.EventType.TERMINATION
]

interpolant_extrapolation_events_eval = [
process(event.expression, "event", use_jacobian=False)[1]
for event in model.events
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION
]

# discontinuity events are evaluated before the solver is called, so don't need
# to process them
discontinuity_events_eval = [
Expand All @@ -376,6 +382,7 @@ def report(string):
model.jac_algebraic_eval = jac_algebraic
model.terminate_events_eval = terminate_events_eval
model.discontinuity_events_eval = discontinuity_events_eval
model.interpolant_extrapolation_events_eval = interpolant_extrapolation_events_eval

# Calculate initial conditions
model.y0 = init_eval(inputs)
Expand Down
56 changes: 56 additions & 0 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
root_tol=1e-6,
max_step_decrease_count=5,
dt_max=None,
extrap_tol=1e-3,
extra_options_setup=None,
extra_options_call=None,
):
Expand All @@ -88,6 +89,7 @@ def __init__(

self.extra_options_setup = extra_options_setup or {}
self.extra_options_call = extra_options_call or {}
self.extrap_tol = extrap_tol

self.name = "CasADi solver with '{}' mode".format(mode)

Expand Down Expand Up @@ -143,6 +145,32 @@ def _integrate(self, model, t_eval, inputs=None):
[event(t, y0, inputs) for event in model.terminate_events_eval]
)
)

extrap_event = [
event(t, y0, inputs)
for event in model.interpolant_extrapolation_events_eval
]

if extrap_event:
if (np.concatenate(extrap_event) < self.extrap_tol).any():
extrap_event_names = []
for event in model.events:
if (
event.event_type
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
and (
event.expression.evaluate(t, y0, inputs=inputs,)
< self.extrap_tol
).any()
):
extrap_event_names.append(event.name[12:])

raise pybamm.SolverError(
"CasADI solver failed because the following interpolation bounds were exceeded: {}".format(
extrap_event_names
)
)

pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))

if self.mode == "safe without grid":
Expand Down Expand Up @@ -217,6 +245,34 @@ def _integrate(self, model, t_eval, inputs=None):
]
)
)

extrap_event = [
event(t, current_step_sol.y[:, -1], inputs=inputs)
for event in model.interpolant_extrapolation_events_eval
]

if extrap_event:
if (np.concatenate(extrap_event) < self.extrap_tol).any():
extrap_event_names = []
for event in model.events:
if (
event.event_type
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
and (
event.expression.evaluate(
t, current_step_sol.y[:, -1], inputs=inputs
)
< self.extrap_tol
).any()
):
extrap_event_names.append(event.name[12:])

raise pybamm.SolverError(
"CasADI solver failed because the following interpolation bounds were exceeded: {}".format(
extrap_event_names
)
)

# Exit loop if the sign of an event changes
# Locate the event time using a root finding algorithm and
# event state using interpolation. The solution is then truncated
Expand Down
22 changes: 19 additions & 3 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,19 +426,35 @@ def test_interpolant_extrapolate(self):
model = pybamm.lithium_ion.DFN()
param = pybamm.ParameterValues(chemistry=pybamm.parameter_sets.Chen2020)
experiment = pybamm.Experiment(
["Discharge at 1C until 2.5 V", "Rest for 2 hours",], period="5 seconds"
["Charge at 1C until 4.6 V"], period="10 seconds"
)

param["Upper voltage cut-off [V]"] = 4.8

sim = pybamm.Simulation(
model,
parameter_values=param,
experiment=experiment,
solver=pybamm.CasadiSolver(mode="safe", dt_max=0.001, extra_options_setup={"max_num_steps": 500}),
)
with self.assertRaisesRegex(
pybamm.SolverError, "interpolation bounds"
):
sim.solve()

ci = param["Initial concentration in positive electrode [mol.m-3]"]
param["Initial concentration in positive electrode [mol.m-3]"] = 0.8 * ci

sim = pybamm.Simulation(
model,
parameter_values=param,
experiment=experiment,
solver=pybamm.CasadiSolver(mode="safe"),
solver=pybamm.CasadiSolver(mode="safe", dt_max=0.05),
)
sim.solve()
with self.assertRaisesRegex(
pybamm.SolverError, "interpolation bounds"
):
sim.solve()


class TestCasadiSolverSensitivity(unittest.TestCase):
Expand Down

0 comments on commit bbcfc22

Please sign in to comment.