Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast Hermite interpolation and observables #4464

Merged
merged 22 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449))
- Added phase-dependent particle options to LAM
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ pybind11_add_module(idaklu
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp
src/pybamm/solvers/c_solvers/idaklu/observe.hpp
src/pybamm/solvers/c_solvers/idaklu/observe.cpp
# IDAKLU expressions - concrete implementations
${IDAKLU_EXPR_CASADI_SOURCE_FILES}
${IDAKLU_EXPR_IREE_SOURCE_FILES}
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def compile_KLU():
"src/pybamm/solvers/c_solvers/idaklu/Solution.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Options.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Options.cpp",
"src/pybamm/solvers/c_solvers/idaklu/observe.hpp",
"src/pybamm/solvers/c_solvers/idaklu/observe.cpp",
"src/pybamm/solvers/c_solvers/idaklu.cpp",
],
)
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@

# Solver classes
from .solvers.solution import Solution, EmptySolution, make_cycle_solution
from .solvers.processed_variable import ProcessedVariable
from .solvers.processed_variable import ProcessedVariable, process_variable
from .solvers.processed_variable_computed import ProcessedVariableComputed
from .solvers.base_solver import BaseSolver
from .solvers.dummy_solver import DummySolver
Expand Down
17 changes: 8 additions & 9 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ def reset_axis(self):
spatial_vars = self.spatial_variable_dict[key]
var_min = np.min(
[
ax_min(var(self.ts_seconds[i], **spatial_vars, warn=False))
ax_min(var(self.ts_seconds[i], **spatial_vars))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
var_max = np.max(
[
ax_max(var(self.ts_seconds[i], **spatial_vars, warn=False))
ax_max(var(self.ts_seconds[i], **spatial_vars))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
Expand Down Expand Up @@ -512,7 +512,7 @@ def plot(self, t, dynamic=False):
full_t = self.ts_seconds[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
variable(full_t, warn=False),
variable(full_t),
color=self.colors[i],
linestyle=linestyle,
)
Expand Down Expand Up @@ -548,7 +548,7 @@ def plot(self, t, dynamic=False):
linestyle = self.linestyles[j]
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars, warn=False),
variable(t_in_seconds, **spatial_vars),
color=self.colors[i],
linestyle=linestyle,
zorder=10,
Expand All @@ -570,13 +570,13 @@ def plot(self, t, dynamic=False):
y_name = next(iter(spatial_vars.keys()))[0]
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars, warn=False)
var = variable(t_in_seconds, **spatial_vars)
else:
x_name = next(iter(spatial_vars.keys()))[0]
y_name = list(spatial_vars.keys())[1][0]
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars, warn=False).T
var = variable(t_in_seconds, **spatial_vars).T
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
vmin, vmax = self.variable_limits[key]
Expand Down Expand Up @@ -710,7 +710,6 @@ def slider_update(self, t):
var = variable(
time_in_seconds,
**self.spatial_variable_dict[key],
warn=False,
)
plot[i][j].set_ydata(var)
var_min = min(var_min, ax_min(var))
Expand All @@ -729,11 +728,11 @@ def slider_update(self, t):
if self.x_first_and_y_second[key] is False:
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(time_in_seconds, **spatial_vars, warn=False)
var = variable(time_in_seconds, **spatial_vars)
else:
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(time_in_seconds, **spatial_vars, warn=False).T
var = variable(time_in_seconds, **spatial_vars).T
# store the plot and the var data (for testing) as cant access
# z data from QuadMesh or QuadContourSet object
if self.is_y_z[key] is True:
Expand Down
26 changes: 26 additions & 0 deletions src/pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <pybind11/stl_bind.h>

#include "idaklu/idaklu_solver.hpp"
#include "idaklu/observe.hpp"
#include "idaklu/IDAKLUSolverGroup.hpp"
#include "idaklu/IdakluJax.hpp"
#include "idaklu/common.hpp"
Expand All @@ -27,13 +28,15 @@ casadi::Function generate_casadi_function(const std::string &data)
namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
PYBIND11_MAKE_OPAQUE(std::vector<np_array_realtype>);
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);

PYBIND11_MODULE(idaklu, m)
{
m.doc() = "sundials solvers"; // optional module docstring

py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
py::bind_vector<std::vector<np_array_realtype>>(m, "VectorRealtypeNdArray");
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");

py::class_<IDAKLUSolverGroup>(m, "IDAKLUSolverGroup")
Expand Down Expand Up @@ -72,6 +75,27 @@ PYBIND11_MODULE(idaklu, m)
py::arg("options"),
py::return_value_policy::take_ownership);

m.def("observe", &observe,
"Observe variables",
py::arg("ts"),
py::arg("ys"),
py::arg("inputs"),
py::arg("funcs"),
py::arg("is_f_contiguous"),
py::arg("shape"),
py::return_value_policy::take_ownership);

m.def("observe_hermite_interp", &observe_hermite_interp,
"Observe and Hermite interpolate variables",
py::arg("t_interp"),
py::arg("ts"),
py::arg("ys"),
py::arg("yps"),
py::arg("inputs"),
py::arg("funcs"),
py::arg("shape"),
py::return_value_policy::take_ownership);

#ifdef IREE_ENABLE
m.def("create_iree_solver_group", &create_idaklu_solver_group<IREEFunctions>,
"Create a group of iree idaklu solver objects",
Expand Down Expand Up @@ -167,7 +191,9 @@ PYBIND11_MODULE(idaklu, m)
py::class_<Solution>(m, "solution")
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yp", &Solution::yp)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("ypS", &Solution::ypS)
.def_readwrite("y_term", &Solution::y_term)
.def_readwrite("flag", &Solution::flag);
}
43 changes: 40 additions & 3 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
int const number_of_events; // cppcheck-suppress unusedStructMember
int number_of_timesteps;
int precon_type; // cppcheck-suppress unusedStructMember
N_Vector yy, yp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance
N_Vector yy, yyp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance
N_Vector *yyS; // cppcheck-suppress unusedStructMember
N_Vector *ypS; // cppcheck-suppress unusedStructMember
N_Vector *yypS; // cppcheck-suppress unusedStructMember
N_Vector id; // rhs_alg_id
realtype rtol;
int const jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember
Expand All @@ -70,11 +70,14 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
vector<realtype> res_dvar_dp;
bool const sensitivity; // cppcheck-suppress unusedStructMember
bool const save_outputs_only; // cppcheck-suppress unusedStructMember
bool save_hermite; // cppcheck-suppress unusedStructMember
bool is_ODE; // cppcheck-suppress unusedStructMember
int length_of_return_vector; // cppcheck-suppress unusedStructMember
vector<realtype> t; // cppcheck-suppress unusedStructMember
vector<vector<realtype>> y; // cppcheck-suppress unusedStructMember
vector<vector<realtype>> yp; // cppcheck-suppress unusedStructMember
vector<vector<vector<realtype>>> yS; // cppcheck-suppress unusedStructMember
vector<vector<vector<realtype>>> ypS; // cppcheck-suppress unusedStructMember
SetupOptions const setup_opts;
SolverOptions const solver_opts;

Expand Down Expand Up @@ -144,6 +147,11 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
*/
void InitializeStorage(int const N);

/**
* @brief Initialize the storage for Hermite interpolation
*/
void InitializeHermiteStorage(int const N);

/**
* @brief Apply user-configurable IDA options
*/
Expand Down Expand Up @@ -190,13 +198,20 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
*/
void ExtendAdaptiveArrays();

/**
* @brief Extend the Hermite interpolation info by 1
*/
void ExtendHermiteArrays();

/**
* @brief Set the step values
*/
void SetStep(
realtype &t_val,
realtype &tval,
realtype *y_val,
realtype *yp_val,
vector<realtype *> const &yS_val,
vector<realtype *> const &ypS_val,
int &i_save
);

Expand All @@ -211,7 +226,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
realtype &t_prev,
realtype const &t_next,
realtype *y_val,
realtype *yp_val,
vector<realtype *> const &yS_val,
vector<realtype *> const &ypS_val,
int &i_save
);

Expand Down Expand Up @@ -255,6 +272,26 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
int &i_save
);

/**
* @brief Save the output function results at the requested time
*/
void SetStepHermite(
realtype &t_val,
realtype *yp_val,
const vector<realtype*> &ypS_val,
int &i_save
);

/**
* @brief Save the output function sensitivities at the requested time
*/
void SetStepHermiteSensitivities(
realtype &t_val,
realtype *yp_val,
const vector<realtype*> &ypS_val,
int &i_save
);

};

#include "IDAKLUSolverOpenMP.inl"
Expand Down
Loading
Loading