Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I2643 banded #2677

Merged
merged 12 commits into from
Feb 18, 2023
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ PYBIND11_MODULE(idaklu, m)
py::arg("number_of_parameters"), py::arg("rhs_alg"),
py::arg("jac_times_cjmass"), py::arg("jac_times_cjmass_colptrs"),
py::arg("jac_times_cjmass_rowvals"), py::arg("jac_times_cjmass_nnz"),
py::arg("jac_bandwidth_lower"), py::arg("jac_bandwidth_upper"),
py::arg("jac_action"), py::arg("mass_action"), py::arg("sens"),
py::arg("events"), py::arg("number_of_events"), py::arg("rhs_alg_id"),
py::arg("atol"), py::arg("rtol"), py::arg("inputs"), py::arg("options"),
Expand Down
11 changes: 8 additions & 3 deletions pybamm/solvers/c_solvers/idaklu/casadi_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@ void CasadiFunction::operator()()
CasadiFunctions::CasadiFunctions(
const Function &rhs_alg, const Function &jac_times_cjmass,
const int jac_times_cjmass_nnz,
const int jac_bandwidth_lower, const int jac_bandwidth_upper,
const np_array_int &jac_times_cjmass_rowvals_arg,
const np_array_int &jac_times_cjmass_colptrs_arg,
const int inputs_length, const Function &jac_action,
const Function &mass_action, const Function &sens, const Function &events,
const int n_s, int n_e, const int n_p, const Options& options)
: number_of_states(n_s), number_of_events(n_e), number_of_parameters(n_p),
number_of_nnz(jac_times_cjmass_nnz), rhs_alg(rhs_alg),
number_of_nnz(jac_times_cjmass_nnz),
jac_bandwidth_lower(jac_bandwidth_lower), jac_bandwidth_upper(jac_bandwidth_upper),
rhs_alg(rhs_alg),
jac_times_cjmass(jac_times_cjmass), jac_action(jac_action),
mass_action(mass_action), sens(sens), events(events),
tmp(number_of_states),
tmp_state_vector(number_of_states),
tmp_sparse_jacobian_data(jac_times_cjmass_nnz),
options(options)
{

Expand All @@ -66,4 +70,5 @@ CasadiFunctions::CasadiFunctions(

}

realtype *CasadiFunctions::get_tmp() { return tmp.data(); }
realtype *CasadiFunctions::get_tmp_state_vector() { return tmp_state_vector.data(); }
realtype *CasadiFunctions::get_tmp_sparse_jacobian_data() { return tmp_sparse_jacobian_data.data(); }
9 changes: 7 additions & 2 deletions pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class CasadiFunctions
int number_of_parameters;
int number_of_events;
int number_of_nnz;
int jac_bandwidth_lower;
int jac_bandwidth_upper;
CasadiFunction rhs_alg;
CasadiFunction sens;
CasadiFunction jac_times_cjmass;
Expand All @@ -44,17 +46,20 @@ class CasadiFunctions

CasadiFunctions(const Function &rhs_alg, const Function &jac_times_cjmass,
const int jac_times_cjmass_nnz,
const int jac_bandwidth_lower, const int jac_bandwidth_upper,
const np_array_int &jac_times_cjmass_rowvals,
const np_array_int &jac_times_cjmass_colptrs,
const int inputs_length, const Function &jac_action,
const Function &mass_action, const Function &sens,
const Function &events, const int n_s, int n_e,
const int n_p, const Options& options);

realtype *get_tmp();
realtype *get_tmp_state_vector();
realtype *get_tmp_sparse_jacobian_data();

private:
std::vector<realtype> tmp;
std::vector<realtype> tmp_state_vector;
std::vector<realtype> tmp_sparse_jacobian_data;
};

#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP
28 changes: 24 additions & 4 deletions pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,31 @@ create_casadi_solver(int number_of_states, int number_of_parameters,
const Function &rhs_alg, const Function &jac_times_cjmass,
const np_array_int &jac_times_cjmass_colptrs,
const np_array_int &jac_times_cjmass_rowvals,
const int jac_times_cjmass_nnz, const Function &jac_action,
const int jac_times_cjmass_nnz,
const int jac_bandwidth_lower, const int jac_bandwidth_upper,
const Function &jac_action,
const Function &mass_action, const Function &sens,
const Function &events, const int number_of_events,
np_array rhs_alg_id, np_array atol_np, double rel_tol,
int inputs_length, py::dict options)
{
auto options_cpp = Options(options);
auto functions = std::make_unique<CasadiFunctions>(
rhs_alg, jac_times_cjmass, jac_times_cjmass_nnz, jac_times_cjmass_rowvals,
rhs_alg, jac_times_cjmass, jac_times_cjmass_nnz, jac_bandwidth_lower, jac_bandwidth_upper, jac_times_cjmass_rowvals,
jac_times_cjmass_colptrs, inputs_length, jac_action, mass_action, sens,
events, number_of_states, number_of_events, number_of_parameters,
options_cpp);

return new CasadiSolver(atol_np, rel_tol, rhs_alg_id, number_of_parameters,
number_of_events, jac_times_cjmass_nnz,
number_of_events, jac_times_cjmass_nnz,
jac_bandwidth_lower, jac_bandwidth_upper,
std::move(functions), options_cpp);
}

CasadiSolver::CasadiSolver(np_array atol_np, double rel_tol,
np_array rhs_alg_id, int number_of_parameters,
int number_of_events, int jac_times_cjmass_nnz,
int jac_bandwidth_lower, int jac_bandwidth_upper,
std::unique_ptr<CasadiFunctions> functions_arg,
const Options &options)
: number_of_states(atol_np.request().size),
Expand Down Expand Up @@ -107,7 +111,14 @@ CasadiSolver::CasadiSolver(np_array atol_np, double rel_tol,
jac_times_cjmass_nnz, CSC_MAT);
#endif
}
else if (options.jacobian == "dense" || options.jacobian == "none")
else if (options.jacobian == "banded") {
DEBUG("\tsetting banded matrix");
#if SUNDIALS_VERSION_MAJOR >= 6
J = SUNBandMatrix(number_of_states, jac_bandwidth_upper, jac_bandwidth_lower, sunctx);
#else
J = SUNBandMatrix(number_of_states, jac_bandwidth_upper, jac_bandwidth_lower);
#endif
} else if (options.jacobian == "dense" || options.jacobian == "none")
{
DEBUG("\tsetting dense matrix");
#if SUNDIALS_VERSION_MAJOR >= 6
Expand Down Expand Up @@ -151,6 +162,15 @@ CasadiSolver::CasadiSolver(np_array atol_np, double rel_tol,
LS = SUNLinSol_KLU(yy, J, sunctx);
#else
LS = SUNLinSol_KLU(yy, J);
#endif
}
else if (options.linear_solver == "SUNLinSol_Band")
{
DEBUG("\tsetting SUNLinSol_Band linear solver");
#if SUNDIALS_VERSION_MAJOR >= 6
LS = SUNLinSol_Band(yy, J, sunctx);
#else
LS = SUNLinSol_Band(yy, J);
#endif
}
else if (options.linear_solver == "SUNLinSol_SPBCGS")
Expand Down
6 changes: 4 additions & 2 deletions pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CasadiSolver
public:
CasadiSolver(np_array atol_np, double rel_tol, np_array rhs_alg_id,
int number_of_parameters, int number_of_events,
int jac_times_cjmass_nnz,
int jac_times_cjmass_nnz, int jac_bandwidth_lower, int jac_bandwidth_upper,
std::unique_ptr<CasadiFunctions> functions, const Options& options);
~CasadiSolver();

Expand Down Expand Up @@ -48,7 +48,9 @@ create_casadi_solver(int number_of_states, int number_of_parameters,
const Function &rhs_alg, const Function &jac_times_cjmass,
const np_array_int &jac_times_cjmass_colptrs,
const np_array_int &jac_times_cjmass_rowvals,
const int jac_times_cjmass_nnz, const Function &jac_action,
const int jac_times_cjmass_nnz,
const int jac_bandwidth_lower, const int jac_bandwidth_upper,
const Function &jac_action,
const Function &mass_action, const Function &sens,
const Function &event, const int number_of_events,
np_array rhs_alg_id, np_array atol_np,
Expand Down
36 changes: 28 additions & 8 deletions pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
p_python_functions->rhs_alg.m_res[0] = NV_DATA_S(rr);
p_python_functions->rhs_alg();

realtype *tmp = p_python_functions->get_tmp();
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(yp);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();
Expand Down Expand Up @@ -108,7 +108,7 @@ int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr,
p_python_functions->jac_action();

// tmp has -∂F/∂y˙ v
realtype *tmp = p_python_functions->get_tmp();
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(v);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();
Expand Down Expand Up @@ -148,15 +148,14 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp,
static_cast<CasadiFunctions *>(user_data);

// create pointer to jac data, column pointers, and row values
sunindextype *jac_colptrs;
sunindextype *jac_rowvals;
realtype *jac_data;
if (p_python_functions->options.using_sparse_matrix)
{
jac_colptrs = SUNSparseMatrix_IndexPointers(JJ);
jac_rowvals = SUNSparseMatrix_IndexValues(JJ);
jac_data = SUNSparseMatrix_Data(JJ);
}
else if (p_python_functions->options.using_banded_matrix) {
jac_data = p_python_functions->get_tmp_sparse_jacobian_data();
}
else
{
jac_data = SUNDenseMatrix_Data(JJ);
Expand All @@ -169,10 +168,31 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp,
p_python_functions->inputs.data();
p_python_functions->jac_times_cjmass.m_arg[3] = &cj;
p_python_functions->jac_times_cjmass.m_res[0] = jac_data;

p_python_functions->jac_times_cjmass();

if (p_python_functions->options.using_sparse_matrix)

if (p_python_functions->options.using_banded_matrix)
{
// copy data from temporary matrix to the banded matrix
auto jac_colptrs = p_python_functions->jac_times_cjmass_colptrs.data();
auto jac_rowvals = p_python_functions->jac_times_cjmass_rowvals.data();
int ncols = p_python_functions->number_of_states;
for (int col_ij = 0; col_ij < ncols; col_ij++) {
realtype *banded_col = SM_COLUMN_B(JJ, col_ij);
for (auto data_i = jac_colptrs[col_ij]; data_i < jac_colptrs[col_ij+1]; data_i++) {
auto row_ij = jac_rowvals[data_i];
const realtype value_ij = jac_data[data_i];
DEBUG("(" << row_ij << ", " << col_ij << ") = " << value_ij);
SM_COLUMN_ELEMENT_B(banded_col, row_ij, col_ij) = value_ij;
}
}
}
else if (p_python_functions->options.using_sparse_matrix)
{

sunindextype *jac_colptrs = SUNSparseMatrix_IndexPointers(JJ);
sunindextype *jac_rowvals = SUNSparseMatrix_IndexValues(JJ);
// row vals and col ptrs
const int n_row_vals = p_python_functions->jac_times_cjmass_rowvals.size();
auto p_jac_times_cjmass_rowvals =
Expand Down Expand Up @@ -262,7 +282,7 @@ int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp,
for (int i = 0; i < np; i++)
{
// put (∂F/∂y)s i (t) in tmp
realtype *tmp = p_python_functions->get_tmp();
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->jac_action.m_arg[0] = &t;
p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data();
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <sunlinsol/sunlinsol_klu.h> /* access to KLU linear solver */
#include <sunlinsol/sunlinsol_dense.h> /* access to dense linear solver */
#include <sunlinsol/sunlinsol_band.h> /* access to dense linear solver */
#include <sunlinsol/sunlinsol_spbcgs.h> /* access to spbcgs iterative linear solver */
#include <sunlinsol/sunlinsol_spfgmr.h>
#include <sunlinsol/sunlinsol_spgmr.h>
Expand Down
19 changes: 18 additions & 1 deletion pybamm/solvers/c_solvers/idaklu/options.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "options.hpp"
#include <iostream>
#include <stdexcept>


Expand All @@ -15,9 +16,14 @@ Options::Options(py::dict options)
{

using_sparse_matrix = true;
using_banded_matrix = false;
if (jacobian == "sparse")
{
}
else if (jacobian == "banded") {
using_banded_matrix = true;
using_sparse_matrix = false;
}
else if (jacobian == "dense" || jacobian == "none")
{
using_sparse_matrix = false;
Expand All @@ -29,7 +35,7 @@ Options::Options(py::dict options)
{
throw std::domain_error(
"Unknown jacobian type \""s + jacobian +
"\". Should be one of \"sparse\", \"dense\", \"matrix-free\" or \"none\"."s
"\". Should be one of \"sparse\", \"banded\", \"dense\", \"matrix-free\" or \"none\"."s
);
}

Expand All @@ -40,6 +46,17 @@ Options::Options(py::dict options)
else if (linear_solver == "SUNLinSol_KLU" && jacobian == "sparse")
{
}
else if (linear_solver == "SUNLinSol_Band" && jacobian == "banded")
{
}
else if (jacobian == "banded") {
throw std::domain_error(
"Unknown linear solver or incompatible options: "
"jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver +
"\". For a banded jacobian "
"please use the SUNLinSol_Band linear solver"
);
}
else if ((linear_solver == "SUNLinSol_SPBCGS" ||
linear_solver == "SUNLinSol_SPFGMR" ||
linear_solver == "SUNLinSol_SPGMR" ||
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
struct Options {
bool print_stats;
bool using_sparse_matrix;
bool using_banded_matrix;
bool using_iterative_solver;
std::string jacobian;
std::string linear_solver; // klu, lapack, spbcg
Expand Down
12 changes: 10 additions & 2 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ class IDAKLUSolver(pybamm.BaseSolver):
# print statistics of the solver after every solve
"print_stats": False,

# jacobian form, can be "none", "dense", "sparse", "matrix-free"
# jacobian form, can be "none", "dense",
# "banded", "sparse", "matrix-free"
"jacobian": "sparse",

# name of sundials linear solver to use options are: "SUNLinSol_KLU",
# "SUNLinSol_Dense", "SUNLinSol_SPBCGS",
# "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS",
# "SUNLinSol_SPFGMR", "SUNLinSol_SPGMR", "SUNLinSol_SPTFQMR",
"linear_solver": "SUNLinSol_KLU",

Expand Down Expand Up @@ -275,7 +276,10 @@ def resfn(t, y, inputs, ydot):
- cj_casadi * mass_matrix
],
)

jac_times_cjmass_sparsity = jac_times_cjmass.sparsity_out(0)
jac_bw_lower = jac_times_cjmass_sparsity.bw_lower()
jac_bw_upper = jac_times_cjmass_sparsity.bw_upper()
jac_times_cjmass_nnz = jac_times_cjmass_sparsity.nnz()
jac_times_cjmass_colptrs = np.array(
jac_times_cjmass_sparsity.colind(), dtype=np.int64
Expand Down Expand Up @@ -448,6 +452,8 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS):
sensfn = idaklu.generate_function(sensfn.serialize())

self._setup = {
"jac_bandwidth_upper": jac_bw_upper,
"jac_bandwidth_lower": jac_bw_lower,
"rhs_algebraic": rhs_algebraic,
"jac_times_cjmass": jac_times_cjmass,
"jac_times_cjmass_colptrs": jac_times_cjmass_colptrs,
Expand All @@ -471,6 +477,8 @@ def sensfn(resvalS, t, y, inputs, yp, yS, ypS):
self._setup["jac_times_cjmass_colptrs"],
self._setup["jac_times_cjmass_rowvals"],
self._setup["jac_times_cjmass_nnz"],
jac_bw_lower,
jac_bw_upper,
self._setup["jac_rhs_algebraic_action"],
self._setup["mass_action"],
self._setup["sensfn"],
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,30 @@ def test_dae_solver_algebraic_model(self):
solution = solver.solve(model, t_eval)
np.testing.assert_array_equal(solution.y, -1)

def test_banded(self):
model = pybamm.lithium_ion.SPM()
model.convert_to_format = "casadi"
param = model.default_parameter_values
param.process_model(model)
geometry = model.default_geometry
param.process_geometry(geometry)
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

t_eval = np.linspace(0, 3600, 100)
solver = pybamm.IDAKLUSolver()
soln = solver.solve(model, t_eval)

options = {
"jacobian": "banded",
"linear_solver": "SUNLinSol_Band",
}
solver_banded = pybamm.IDAKLUSolver(options=options)
soln_banded = solver_banded.solve(model, t_eval)

np.testing.assert_array_almost_equal(soln.y, soln_banded.y, 5)

def test_options(self):
model = pybamm.BaseModel()
u = pybamm.Variable("u")
Expand Down