Skip to content

Commit

Permalink
feat: support sensitivities for pybamm.Simulation and pybamm.Experime…
Browse files Browse the repository at this point in the history
…nt (#4415)

* main changes relate to updating the `BaseSolver.step` function to support this
* `BaseSolver.step` now can use the input Solution to initialise the sensitivities for the new step

---------

Co-authored-by: Eric G. Kratz <[email protected]>
  • Loading branch information
martinjrobins and kratman authored Sep 13, 2024
1 parent ba2aa67 commit 35bcb78
Show file tree
Hide file tree
Showing 14 changed files with 553 additions and 133 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))

## Optimizations
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

Expand Down
50 changes: 46 additions & 4 deletions src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,54 @@ def _set_random_seed(self):
% (2**32)
)

def set_up_and_parameterise_experiment(self):
def set_up_and_parameterise_experiment(self, solve_kwargs=None):
"""
Create and parameterise the models for each step in the experiment.
This increases set-up time since several models to be processed, but
reduces simulation time since the model formulation is efficient.
"""
parameter_values = self._parameter_values.copy()

# some parameters are used to control the experiment, and should not be
# input parameters
restrict_list = {"Initial temperature [K]", "Ambient temperature [K]"}
for step in self.experiment.steps:
if issubclass(step.__class__, pybamm.experiment.step.BaseStepImplicit):
restrict_list.update(step.get_parameter_values([]).keys())
elif issubclass(step.__class__, pybamm.experiment.step.BaseStepExplicit):
restrict_list.update(["Current function [A]"])
for key in restrict_list:
if key in parameter_values.keys() and isinstance(
parameter_values[key], pybamm.InputParameter
):
raise pybamm.ModelError(
f"Cannot use '{key}' as an input parameter in this experiment. "
f"This experiment is controlled via the following parameters: {restrict_list}. "
f"None of these parameters are able to be input parameters."
)

if (
solve_kwargs is not None
and "calculate_sensitivities" in solve_kwargs
and solve_kwargs["calculate_sensitivities"]
):
for step in self.experiment.steps:
if any(
[
isinstance(
term,
pybamm.experiment.step.step_termination.BaseTermination,
)
for term in step.termination
]
):
pybamm.logger.warning(
f"Step '{step}' has a termination condition based on an event. Sensitivity calculation will be inaccurate "
"if the time of each step event changes rapidly with respect to the parameters. "
)
break

# Set the initial temperature to be the temperature of the first step
# We can set this globally for all steps since any subsequent steps will either
# start at the temperature at the end of the previous step (if non-isothermal
Expand Down Expand Up @@ -303,7 +343,7 @@ def build(self, initial_soc=None, inputs=None):
# rebuilt model so clear solver setup
self._solver._model_set_up = {}

def build_for_experiment(self, initial_soc=None, inputs=None):
def build_for_experiment(self, initial_soc=None, inputs=None, solve_kwargs=None):
"""
Similar to :meth:`Simulation.build`, but for the case of simulating an
experiment, where there may be several models and solvers to build.
Expand All @@ -314,7 +354,7 @@ def build_for_experiment(self, initial_soc=None, inputs=None):
if self.steps_to_built_models:
return
else:
self.set_up_and_parameterise_experiment()
self.set_up_and_parameterise_experiment(solve_kwargs)

# Can process geometry with default parameter values (only electrical
# parameters change between parameter values)
Expand Down Expand Up @@ -497,7 +537,9 @@ def solve(

elif self.operating_mode == "with experiment":
callbacks.on_experiment_start(logs)
self.build_for_experiment(initial_soc=initial_soc, inputs=inputs)
self.build_for_experiment(
initial_soc=initial_soc, inputs=inputs, solve_kwargs=kwargs
)
if t_eval is not None:
pybamm.logger.warning(
"Ignoring t_eval as solution times are specified by the experiment"
Expand Down
158 changes: 138 additions & 20 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,33 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
y0 = root_sol.all_ys[0]
return y0

def _solve_process_calculate_sensitivities_arg(
inputs, model, calculate_sensitivities
):
# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs.keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

calculate_sensitivities_list.sort()
if not hasattr(model, "calculate_sensitivities"):
model.calculate_sensitivities = []

# Check that calculate_sensitivites have not been updated
sensitivities_have_changed = (
calculate_sensitivities_list != model.calculate_sensitivities
)

# save sensitivity parameters so we can identify them later on
# (FYI: this is used in the Solution class)
model.calculate_sensitivities = calculate_sensitivities_list

return calculate_sensitivities_list, sensitivities_have_changed

def solve(
self,
model,
Expand Down Expand Up @@ -700,7 +727,11 @@ def solve(
calculate_sensitivities : list of str or bool, optional
Whether the solver calculates sensitivities of all input parameters. Defaults to False.
If only a subset of sensitivities are required, can also pass a
list of input parameter names
list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances
so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you
require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities
do not take into account the movement of step-transitions wrt input parameters, so do not use this feature
if the timings of your experimental protocol change rapidly with respect to your input parameters.
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
Expand All @@ -722,15 +753,6 @@ def solve(
"""
pybamm.logger.info(f"Start solving {model.name} with {self.name}")

# get a list-only version of calculate_sensitivities
if isinstance(calculate_sensitivities, bool):
if calculate_sensitivities:
calculate_sensitivities_list = [p for p in inputs.keys()]
else:
calculate_sensitivities_list = []
else:
calculate_sensitivities_list = calculate_sensitivities

# Make sure model isn't empty
self._check_empty_model(model)

Expand Down Expand Up @@ -772,6 +794,12 @@ def solve(
self._set_up_model_inputs(model, inputs) for inputs in inputs_list
]

calculate_sensitivities_list, sensitivities_have_changed = (
BaseSolver._solve_process_calculate_sensitivities_arg(
model_inputs_list[0], model, calculate_sensitivities
)
)

# (Re-)calculate consistent initialization
# Assuming initial conditions do not depend on input parameters
# when len(inputs_list) > 1, only `model_inputs_list[0]`
Expand All @@ -792,13 +820,8 @@ def solve(
"for initial conditions."
)

# Check that calculate_sensitivites have not been updated
calculate_sensitivities_list.sort()
if hasattr(model, "calculate_sensitivities"):
model.calculate_sensitivities.sort()
else:
model.calculate_sensitivities = []
if calculate_sensitivities_list != model.calculate_sensitivities:
# if any setup configuration has changed, we need to re-set up
if sensitivities_have_changed:
self._model_set_up.pop(model, None)
# CasadiSolver caches its integrators using model, so delete this too
if isinstance(self, pybamm.CasadiSolver):
Expand Down Expand Up @@ -1066,6 +1089,58 @@ def _check_events_with_initialization(t_eval, model, inputs_dict):
f"Events {event_names} are non-positive at initial conditions"
)

def _set_sens_initial_conditions_from(
self, solution: pybamm.Solution, model: pybamm.BaseModel
) -> tuple:
"""
A restricted version of BaseModel.set_initial_conditions_from that only extracts the
sensitivities from a solution object, and only for a model that has been descretised.
This is used when setting the initial conditions for a sensitivity model.
Parameters
----------
solution : :class:`pybamm.Solution`
The solution to use to initialize the model
model: :class:`pybamm.BaseModel`
The model whose sensitivities to set
Returns
-------
initial_conditions : tuple of ndarray
The initial conditions for the sensitivities, each element of the tuple
corresponds to an input parameter
"""

ninputs = len(model.calculate_sensitivities)
initial_conditions = tuple([] for _ in range(ninputs))
solution = solution.last_state
for var in model.initial_conditions:
final_state = solution[var.name]
final_state = final_state.sensitivities
final_state_eval = tuple(
final_state[key] for key in model.calculate_sensitivities
)

scale, reference = var.scale.value, var.reference.value
for i in range(ninputs):
scaled_final_state_eval = (final_state_eval[i] - reference) / scale
initial_conditions[i].append(scaled_final_state_eval)

# Also update the concatenated initial conditions if the model is already
# discretised
# Unpack slices for sorting
y_slices = {var: slce for var, slce in model.y_slices.items()}
slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()]

# sort equations according to slices
concatenated_initial_conditions = [
casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))])
for init in initial_conditions
]
return concatenated_initial_conditions

def process_t_interp(self, t_interp):
# set a variable for this
no_interp = (not self.supports_interp) and (
Expand All @@ -1092,6 +1167,7 @@ def step(
npts=None,
inputs=None,
save=True,
calculate_sensitivities=False,
t_interp=None,
):
"""
Expand All @@ -1117,6 +1193,14 @@ def step(
Any input parameters to pass to the model when solving
save : bool, optional
Save solution with all previous timesteps. Defaults to True.
calculate_sensitivities : list of str or bool, optional
Whether the solver calculates sensitivities of all input parameters. Defaults to False.
If only a subset of sensitivities are required, can also pass a
list of input parameter names. **Limitations**: sensitivities are not calculated up to numerical tolerances
so are not guarenteed to be within the tolerances set by the solver, please raise an issue if you
require this functionality. Also, when using this feature with `pybamm.Experiment`, the sensitivities
do not take into account the movement of step-transitions wrt input parameters, so do not use this feature
if the timings of your experimental protocol change rapidly with respect to your input parameters.
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
Expand Down Expand Up @@ -1188,8 +1272,15 @@ def step(
# Set up inputs
model_inputs = self._set_up_model_inputs(model, inputs)

# process calculate_sensitivities argument
calculate_sensitivities_list, sensitivities_have_changed = (
BaseSolver._solve_process_calculate_sensitivities_arg(
model_inputs, model, calculate_sensitivities
)
)

first_step_this_model = model not in self._model_set_up
if first_step_this_model:
if first_step_this_model or sensitivities_have_changed:
if len(self._model_set_up) > 0:
existing_model = next(iter(self._model_set_up))
raise RuntimeError(
Expand All @@ -1208,18 +1299,45 @@ def step(
):
pybamm.logger.verbose(f"Start stepping {model.name} with {self.name}")

using_sensitivities = len(model.calculate_sensitivities) > 0

if isinstance(old_solution, pybamm.EmptySolution):
if not first_step_this_model:
# reset y0 to original initial conditions
self.set_up(model, model_inputs, ics_only=True)
elif old_solution.all_models[-1] == model:
# initialize with old solution
model.y0 = old_solution.all_ys[-1][:, -1]
last_state = old_solution.last_state
model.y0 = last_state.all_ys[0]
if using_sensitivities and isinstance(last_state._all_sensitivities, dict):
full_sens = last_state._all_sensitivities["all"][0]
model.y0S = tuple(full_sens[:, i] for i in range(full_sens.shape[1]))

else:
_, concatenated_initial_conditions = model.set_initial_conditions_from(
old_solution, return_type="ics"
)
model.y0 = concatenated_initial_conditions.evaluate(0, inputs=model_inputs)
if using_sensitivities:
model.y0S = self._set_sens_initial_conditions_from(old_solution, model)

# hopefully we'll get rid of explicit sensitivities soon so we can remove this
explicit_sensitivities = model.len_rhs_sens > 0 or model.len_alg_sens > 0
if (
explicit_sensitivities
and using_sensitivities
and not isinstance(old_solution, pybamm.EmptySolution)
and not old_solution.all_models[-1] == model
):
y0_list = []
if model.len_rhs > 0:
y0_list.append(model.y0[: model.len_rhs])
for s in model.y0S:
y0_list.append(s[: model.len_rhs])
if model.len_alg > 0:
y0_list.append(model.y0[model.len_rhs :])
for s in model.y0S:
y0_list.append(s[model.len_rhs :])
model.y0 = casadi.vertcat(*y0_list)

set_up_time = timer.time()

Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
}

if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}

// Store Consistent initialization
Expand Down Expand Up @@ -478,7 +478,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
bool hit_adaptive = save_adaptive_steps && retval == IDA_SUCCESS;

if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}

if (hit_tinterp) {
Expand All @@ -499,7 +499,7 @@ Solution IDAKLUSolverOpenMP<ExprSet>::solve(
// Reset the states and sensitivities at t = t_val
CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy));
if (sensitivity) {
CheckErrors(IDAGetSens(ida_mem, &t_val, yyS));
CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/casadi_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None):
model,
inputs_dict,
termination="final time",
sensitivities=explicit_sensitivities,
all_sensitivities=explicit_sensitivities,
)
sol.integration_time = integration_time
return sol
Loading

0 comments on commit 35bcb78

Please sign in to comment.