diff --git a/CHANGELOG.md b/CHANGELOG.md index ee21541239..82be87f95f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ - Fixed memory issue that caused failure when `output variables` were specified with (`IDAKLUSolver`). ([#4379](https://github.com/pybamm-team/PyBaMM/issues/4379)) - Fixed bug where IDAKLU solver failed when `output variables` were specified and an event triggered. ([#4300](https://github.com/pybamm-team/PyBaMM/pull/4300)) +## Breaking changes + +- Removed legacy python-IDAKLU solver. ([#4326](https://github.com/pybamm-team/PyBaMM/pull/4326)) + # [v24.5](https://github.com/pybamm-team/PyBaMM/tree/v24.5) - 2024-07-26 ## Features diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b3a2adfe5..ad56ac34ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,8 +92,6 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp src/pybamm/solvers/c_solvers/idaklu/common.hpp src/pybamm/solvers/c_solvers/idaklu/common.cpp - src/pybamm/solvers/c_solvers/idaklu/python.hpp - src/pybamm/solvers/c_solvers/idaklu/python.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.hpp src/pybamm/solvers/c_solvers/idaklu/Options.hpp diff --git a/setup.py b/setup.py index 6ceb049b31..74de1baca4 100644 --- a/setup.py +++ b/setup.py @@ -323,8 +323,6 @@ def compile_KLU(): "src/pybamm/solvers/c_solvers/idaklu/IdakluJax.hpp", "src/pybamm/solvers/c_solvers/idaklu/common.hpp", "src/pybamm/solvers/c_solvers/idaklu/common.cpp", - "src/pybamm/solvers/c_solvers/idaklu/python.hpp", - "src/pybamm/solvers/c_solvers/idaklu/python.cpp", "src/pybamm/solvers/c_solvers/idaklu/Solution.cpp", "src/pybamm/solvers/c_solvers/idaklu/Solution.hpp", "src/pybamm/solvers/c_solvers/idaklu/Options.hpp", diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index e90f974676..0d4638e178 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -39,6 +39,7 @@ class BaseModel: calling `evaluate(t, y)` on the given expression treeself. - "casadi": convert into CasADi expression tree, which then uses CasADi's \ algorithm to calculate the Jacobian. + - "jax": convert into JAX expression tree Default is "casadi". """ diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 9027bd51c4..efef7e9357 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -1514,26 +1514,13 @@ def report(string): elif model.convert_to_format != "casadi": y = vars_for_processing["y"] jacobian = vars_for_processing["jacobian"] - # Process with pybamm functions, converting - # to python evaluator + if model.calculate_sensitivities: - report( - f"Calculating sensitivities for {name} with respect " - f"to parameters {model.calculate_sensitivities}" + raise pybamm.SolverError( # pragma: no cover + "Sensitivies are no longer supported for the python " + "evaluator. Please use `convert_to_format = 'casadi'`, or `jax` " + "to calculate sensitivities." ) - jacp_dict = { - p: symbol.diff(pybamm.InputParameter(p)) - for p in model.calculate_sensitivities - } - - report(f"Converting sensitivities for {name} to python") - jacp_dict = { - p: pybamm.EvaluatorPython(jacp) for p, jacp in jacp_dict.items() - } - - # jacp should be a function that returns a dict of sensitivities - def jacp(*args, **kwargs): - return {k: v(*args, **kwargs) for k, v in jacp_dict.items()} else: jacp = None diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index bb9466d40b..3ef0194403 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -11,7 +11,6 @@ #include "idaklu/idaklu_solver.hpp" #include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" -#include "idaklu/python.hpp" #include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" #ifdef IREE_ENABLE @@ -34,28 +33,6 @@ PYBIND11_MODULE(idaklu, m) py::bind_vector>(m, "VectorNdArray"); - m.def("solve_python", &solve_python, - "The solve function for python evaluators", - py::arg("t"), - py::arg("y0"), - py::arg("yp0"), - py::arg("res"), - py::arg("jac"), - py::arg("sens"), - py::arg("get_jac_data"), - py::arg("get_jac_row_vals"), - py::arg("get_jac_col_ptr"), - py::arg("nnz"), - py::arg("events"), - py::arg("number_of_events"), - py::arg("use_jacobian"), - py::arg("rhs_alg_id"), - py::arg("atol"), - py::arg("rtol"), - py::arg("inputs"), - py::arg("number_of_sensitivity_parameters"), - py::return_value_policy::take_ownership); - py::class_(m, "IDAKLUSolver") .def("solve", &IDAKLUSolver::solve, "perform a solve", diff --git a/src/pybamm/solvers/c_solvers/idaklu/python.cpp b/src/pybamm/solvers/c_solvers/idaklu/python.cpp deleted file mode 100644 index 015f504086..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/python.cpp +++ /dev/null @@ -1,486 +0,0 @@ -#include "common.hpp" -#include "python.hpp" -#include - -class PybammFunctions -{ -public: - int number_of_states; - int number_of_parameters; - int number_of_events; - - PybammFunctions(const residual_type &res, const jacobian_type &jac, - const sensitivities_type &sens, - const jac_get_type &get_jac_data_in, - const jac_get_type &get_jac_row_vals_in, - const jac_get_type &get_jac_col_ptrs_in, - const event_type &event, - const int n_s, int n_e, const int n_p, - const np_array &inputs) - : number_of_states(n_s), number_of_events(n_e), - number_of_parameters(n_p), - py_res(res), py_jac(jac), - py_sens(sens), - py_event(event), py_get_jac_data(get_jac_data_in), - py_get_jac_row_vals(get_jac_row_vals_in), - py_get_jac_col_ptrs(get_jac_col_ptrs_in), - inputs(inputs) - { - } - - np_array operator()(double t, np_array y, np_array yp) - { - return py_res(t, y, inputs, yp); - } - - np_array res(double t, np_array y, np_array yp) - { - return py_res(t, y, inputs, yp); - } - - void jac(double t, np_array y, double cj) - { - // this function evaluates the jacobian and sets it to be the attribute - // of a python class which can then be called by get_jac_data, - // get_jac_col_ptr, etc - py_jac(t, y, inputs, cj); - } - - void sensitivities( - std::vector& resvalS, - const double t, const np_array& y, const np_array& yp, - const std::vector& yS, const std::vector& ypS) - { - // this function evaluates the sensitivity equations required by IDAS, - // returning them in resvalS, which is preallocated as a numpy array - // of size (np, n), where n is the number of states and np is the number - // of parameters - // - // yS and ypS are also shape (np, n), y and yp are shape (n) - // - // dF/dy * s_i + dF/dyd * sd + dFdp_i for i in range(np) - py_sens(resvalS, t, y, inputs, yp, yS, ypS); - } - - np_array get_jac_data() { - return py_get_jac_data(); - } - - np_array get_jac_row_vals() { - return py_get_jac_row_vals(); - } - - np_array get_jac_col_ptrs() { - return py_get_jac_col_ptrs(); - } - - np_array events(double t, np_array y) { - return py_event(t, y, inputs); - } - -private: - residual_type py_res; - sensitivities_type py_sens; - jacobian_type py_jac; - event_type py_event; - jac_get_type py_get_jac_data; - jac_get_type py_get_jac_row_vals; - jac_get_type py_get_jac_col_ptrs; - const np_array &inputs; -}; - -int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, - void *user_data) -{ - PybammFunctions *python_functions_ptr = - static_cast(user_data); - PybammFunctions python_functions = *python_functions_ptr; - - realtype *yval, *ypval, *rval; - yval = N_VGetArrayPointer(yy); - ypval = N_VGetArrayPointer(yp); - rval = N_VGetArrayPointer(rr); - - int n = python_functions.number_of_states; - py::array_t y_np = py::array_t(n, yval); - py::array_t yp_np = py::array_t(n, ypval); - - py::array_t r_np; - - r_np = python_functions.res(tres, y_np, yp_np); - - auto r_np_ptr = r_np.unchecked<1>(); - - // just copying data - int i; - for (i = 0; i < n; i++) - { - rval[i] = r_np_ptr[i]; - } - return 0; -} - -int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp, - N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1, - N_Vector tempv2, N_Vector tempv3) -{ - realtype *yval; - yval = N_VGetArrayPointer(yy); - - PybammFunctions *python_functions_ptr = - static_cast(user_data); - PybammFunctions python_functions = *python_functions_ptr; - - int n = python_functions.number_of_states; - py::array_t y_np = py::array_t(n, yval); - - // create pointer to jac data, column pointers, and row values - sunindextype *jac_colptrs = SUNSparseMatrix_IndexPointers(JJ); - sunindextype *jac_rowvals = SUNSparseMatrix_IndexValues(JJ); - realtype *jac_data = SUNSparseMatrix_Data(JJ); - - py::array_t jac_np_array; - - python_functions.jac(tt, y_np, cj); - - np_array jac_np_data = python_functions.get_jac_data(); - int n_data = jac_np_data.request().size; - auto jac_np_data_ptr = jac_np_data.unchecked<1>(); - - // just copy across data - int i; - for (i = 0; i < n_data; i++) - { - jac_data[i] = jac_np_data_ptr[i]; - } - - np_array jac_np_row_vals = python_functions.get_jac_row_vals(); - int n_row_vals = jac_np_row_vals.request().size; - - auto jac_np_row_vals_ptr = jac_np_row_vals.unchecked<1>(); - // just copy across row vals (this might be unneeded) - for (i = 0; i < n_row_vals; i++) - { - jac_rowvals[i] = jac_np_row_vals_ptr[i]; - } - - np_array jac_np_col_ptrs = python_functions.get_jac_col_ptrs(); - int n_col_ptrs = jac_np_col_ptrs.request().size; - auto jac_np_col_ptrs_ptr = jac_np_col_ptrs.unchecked<1>(); - - // just copy across col ptrs (this might be unneeded) - for (i = 0; i < n_col_ptrs; i++) - { - jac_colptrs[i] = jac_np_col_ptrs_ptr[i]; - } - - return (0); -} - -int events(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, - void *user_data) -{ - realtype *yval; - yval = N_VGetArrayPointer(yy); - - PybammFunctions *python_functions_ptr = - static_cast(user_data); - PybammFunctions python_functions = *python_functions_ptr; - - int number_of_events = python_functions.number_of_events; - int number_of_states = python_functions.number_of_states; - py::array_t y_np = py::array_t(number_of_states, yval); - - py::array_t events_np_array; - - events_np_array = python_functions.events(t, y_np); - - auto events_np_data_ptr = events_np_array.unchecked<1>(); - - // just copying data (figure out how to pass pointers later) - int i; - for (i = 0; i < number_of_events; i++) - { - events_ptr[i] = events_np_data_ptr[i]; - } - - return (0); -} - -int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp, - N_Vector resval, N_Vector *yS, N_Vector *ypS, N_Vector *resvalS, - void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) { -// This function computes the sensitivity residual for all sensitivity -// equations. It must compute the vectors -// (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) and store them in resvalS[i]. -// Ns is the number of sensitivities. -// t is the current value of the independent variable. -// yy is the current value of the state vector, y(t). -// yp is the current value of ẏ(t). -// resval contains the current value F of the original DAE residual. -// yS contains the current values of the sensitivities s i . -// ypS contains the current values of the sensitivity derivatives ṡ i . -// resvalS contains the output sensitivity residual vectors. -// Memory allocation for resvalS is handled within idas. -// user data is a pointer to user data. -// tmp1, tmp2, tmp3 are N Vectors of length N which can be used as -// temporary storage. -// -// Return value An IDASensResFn should return 0 if successful, -// a positive value if a recoverable error -// occurred (in which case idas will attempt to correct), -// or a negative value if it failed unrecoverably (in which case the integration is halted and IDA SRES FAIL is returned) -// - PybammFunctions *python_functions_ptr = - static_cast(user_data); - PybammFunctions python_functions = *python_functions_ptr; - - int n = python_functions.number_of_states; - int np = python_functions.number_of_parameters; - - // memory managed by sundials, so pass a destructor that does nothing - auto state_vector_shape = std::vector {n, 1}; - np_array y_np = np_array(state_vector_shape, N_VGetArrayPointer(yy), - py::capsule(&yy, [](void* p) {})); - np_array yp_np = np_array(state_vector_shape, N_VGetArrayPointer(yp), - py::capsule(&yp, [](void* p) {})); - - std::vector yS_np(np); - for (int i = 0; i < np; i++) { - auto capsule = py::capsule(yS + i, [](void* p) {}); - yS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(yS[i]), capsule); - } - - std::vector ypS_np(np); - for (int i = 0; i < np; i++) { - auto capsule = py::capsule(ypS + i, [](void* p) {}); - ypS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(ypS[i]), capsule); - } - - std::vector resvalS_np(np); - for (int i = 0; i < np; i++) { - auto capsule = py::capsule(resvalS + i, [](void* p) {}); - resvalS_np[i] = np_array(state_vector_shape, - N_VGetArrayPointer(resvalS[i]), capsule); - } - - realtype *ptr1 = static_cast(resvalS_np[0].request().ptr); - const realtype* resvalSval = N_VGetArrayPointer(resvalS[0]); - - python_functions.sensitivities(resvalS_np, t, y_np, yp_np, yS_np, ypS_np); - - return 0; -} - -/* main program */ -Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np, - residual_type res, jacobian_type jac, - sensitivities_type sens, - jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp, - int nnz, event_type event, - int number_of_events, int use_jacobian, np_array rhs_alg_id, - np_array atol_np, double rel_tol, np_array inputs, - int number_of_parameters) -{ - auto t = t_np.unchecked<1>(); - auto y0 = y0_np.unchecked<1>(); - auto yp0 = yp0_np.unchecked<1>(); - auto atol = atol_np.unchecked<1>(); - - int number_of_states = y0_np.request().size; - int number_of_timesteps = t_np.request().size; - void *ida_mem; // pointer to memory - N_Vector yy, yp, avtol; // y, y', and absolute tolerance - N_Vector *yyS, *ypS; // y, y' for sensitivities - N_Vector id; - realtype rtol, *yval, *ypval, *atval; - std::vector ySval(number_of_parameters); - int retval; - SUNMatrix J; - SUNLinearSolver LS; - -#if SUNDIALS_VERSION_MAJOR >= 6 - SUNContext sunctx; - SUNContext_Create(NULL, &sunctx); - - // allocate memory for solver - ida_mem = IDACreate(sunctx); - - // allocate vectors - yy = N_VNew_Serial(number_of_states, sunctx); - yp = N_VNew_Serial(number_of_states, sunctx); - avtol = N_VNew_Serial(number_of_states, sunctx); - id = N_VNew_Serial(number_of_states, sunctx); -#else - // allocate memory for solver - ida_mem = IDACreate(); - - // allocate vectors - yy = N_VNew_Serial(number_of_states); - yp = N_VNew_Serial(number_of_states); - avtol = N_VNew_Serial(number_of_states); - id = N_VNew_Serial(number_of_states); -#endif - - if (number_of_parameters > 0) { - yyS = N_VCloneVectorArray(number_of_parameters, yy); - ypS = N_VCloneVectorArray(number_of_parameters, yp); - } - - // set initial value - yval = N_VGetArrayPointer(yy); - ypval = N_VGetArrayPointer(yp); - atval = N_VGetArrayPointer(avtol); - int i; - for (i = 0; i < number_of_states; i++) - { - yval[i] = y0[i]; - ypval[i] = yp0[i]; - atval[i] = atol[i]; - } - - for (int is = 0 ; is < number_of_parameters; is++) { - ySval[is] = N_VGetArrayPointer(yyS[is]); - N_VConst(RCONST(0.0), yyS[is]); - N_VConst(RCONST(0.0), ypS[is]); - } - - // initialise solver - realtype t0 = RCONST(t(0)); - IDAInit(ida_mem, residual, t0, yy, yp); - - // set tolerances - rtol = RCONST(rel_tol); - - IDASVtolerances(ida_mem, rtol, avtol); - - // set events - IDARootInit(ida_mem, number_of_events, events); - - // set pybamm functions by passing pointer to it - PybammFunctions pybamm_functions(res, jac, sens, gjd, gjrv, gjcp, event, - number_of_states, number_of_events, - number_of_parameters, inputs); - void *user_data = &pybamm_functions; - IDASetUserData(ida_mem, user_data); - - // set linear solver -#if SUNDIALS_VERSION_MAJOR >= 6 - J = SUNSparseMatrix(number_of_states, number_of_states, nnz, CSR_MAT, sunctx); - LS = SUNLinSol_KLU(yy, J, sunctx); -#else - J = SUNSparseMatrix(number_of_states, number_of_states, nnz, CSR_MAT); - LS = SUNLinSol_KLU(yy, J); -#endif - - IDASetLinearSolver(ida_mem, LS, J); - - if (use_jacobian == 1) - { - IDASetJacFn(ida_mem, jacobian); - } - - if (number_of_parameters > 0) - { - IDASensInit(ida_mem, number_of_parameters, - IDA_SIMULTANEOUS, sensitivities, yyS, ypS); - IDASensEEtolerances(ida_mem); - } - - int t_i = 1; - realtype tret; - realtype t_next; - realtype t_final = t(number_of_timesteps - 1); - - // set return vectors - std::vector t_return(number_of_timesteps); - std::vector y_return(number_of_timesteps * number_of_states); - std::vector yS_return(number_of_parameters * number_of_timesteps * number_of_states); - - t_return[0] = t(0); - for (int j = 0; j < number_of_states; j++) - { - y_return[j] = yval[j]; - } - for (int j = 0; j < number_of_parameters; j++) { - const int base_index = j * number_of_timesteps * number_of_states; - for (int k = 0; k < number_of_states; k++) { - yS_return[base_index + k] = ySval[j][k]; - } - } - - // calculate consistent initial conditions - auto id_np_val = rhs_alg_id.unchecked<1>(); - realtype *id_val; - id_val = N_VGetArrayPointer(id); - - int ii; - for (ii = 0; ii < number_of_states; ii++) - { - id_val[ii] = id_np_val[ii]; - } - - IDASetId(ida_mem, id); - IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t(1)); - - while (true) - { - t_next = t(t_i); - IDASetStopTime(ida_mem, t_next); - retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL); - - if (retval == IDA_TSTOP_RETURN || retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) - { - if (number_of_parameters > 0) { - IDAGetSens(ida_mem, &tret, yyS); - } - - t_return[t_i] = tret; - for (int j = 0; j < number_of_states; j++) - { - y_return[t_i * number_of_states + j] = yval[j]; - } - for (int j = 0; j < number_of_parameters; j++) { - const int base_index = j * number_of_timesteps * number_of_states - + t_i * number_of_states; - for (int k = 0; k < number_of_states; k++) { - yS_return[base_index + k] = ySval[j][k]; - } - } - t_i += 1; - if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) { - break; - } - - } - } - - /* Free memory */ - if (number_of_parameters > 0) { - IDASensFree(ida_mem); - } - IDAFree(&ida_mem); - SUNLinSolFree(LS); - SUNMatDestroy(J); - N_VDestroy(avtol); - N_VDestroy(yp); - if (number_of_parameters > 0) { - N_VDestroyVectorArray(yyS, number_of_parameters); - N_VDestroyVectorArray(ypS, number_of_parameters); - } -#if SUNDIALS_VERSION_MAJOR >= 6 - SUNContext_Free(&sunctx); -#endif - - np_array t_ret = np_array(t_i, &t_return[0]); - np_array y_ret = np_array(t_i * number_of_states, &y_return[0]); - np_array yS_ret = np_array( - std::vector {number_of_parameters, number_of_timesteps, number_of_states}, - &yS_return[0] - ); - np_array yterm_ret = np_array(0); - - Solution sol(retval, t_ret, y_ret, yS_ret, yterm_ret); - - return sol; -} diff --git a/src/pybamm/solvers/c_solvers/idaklu/python.hpp b/src/pybamm/solvers/c_solvers/idaklu/python.hpp deleted file mode 100644 index 6231d13eb6..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/python.hpp +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef PYBAMM_IDAKLU_HPP -#define PYBAMM_IDAKLU_HPP - -#include "common.hpp" -#include "Solution.hpp" -#include - -using residual_type = std::function< - np_array(realtype, np_array, np_array, np_array) - >; -using sensitivities_type = std::function&, realtype, const np_array&, - const np_array&, - const np_array&, const std::vector&, - const std::vector& - )>; -using jacobian_type = std::function; - -using event_type = - std::function; - -using jac_get_type = std::function; - - -/** - * @brief Interface to the python solver - */ -Solution solve_python(np_array t_np, np_array y0_np, np_array yp0_np, - residual_type res, jacobian_type jac, - sensitivities_type sens, - jac_get_type gjd, jac_get_type gjrv, jac_get_type gjcp, - int nnz, event_type event, - int number_of_events, int use_jacobian, np_array rhs_alg_id, - np_array atol_np, double rel_tol, np_array inputs, - int number_of_parameters); - -#endif // PYBAMM_IDAKLU_HPP diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 85731f4e12..b92006d12d 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -289,7 +289,21 @@ def inputs_to_dict(inputs): if ics_only: return base_set_up_return + if model.convert_to_format not in ["casadi", "jax"]: + msg = ( + "The python-idaklu solver has been deprecated. " + "To use the IDAKLU solver set `convert_to_format = 'casadi'`, or `jax`" + " if using IREE." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + if model.convert_to_format == "jax": + if self._options["jax_evaluator"] != "iree": + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format=" + f"{model.convert_to_format} " + f"(jax_evaluator={self._options['jax_evaluator']})" + ) mass_matrix = model.mass_matrix.entries.toarray() elif model.convert_to_format == "casadi": if self._options["jacobian"] == "dense": @@ -297,19 +311,15 @@ def inputs_to_dict(inputs): else: mass_matrix = casadi.DM(model.mass_matrix.entries) else: - mass_matrix = model.mass_matrix.entries + raise pybamm.SolverError( + "Unsupported option for convert_to_format=" + f"{model.convert_to_format} " + ) # construct residuals function by binding inputs if model.convert_to_format == "casadi": # TODO: do we need densify here? rhs_algebraic = model.rhs_algebraic_eval - else: - - def resfn(t, y, inputs, ydot): - return ( - model.rhs_algebraic_eval(t, y, inputs_to_dict(inputs)).flatten() - - mass_matrix @ ydot - ) if not model.use_jacobian: raise pybamm.SolverError("KLU requires the Jacobian") @@ -384,52 +394,6 @@ def resfn(t, y, inputs, ydot): ) ) - elif self._options["jax_evaluator"] == "jax": - t0 = 0 if t_eval is None else t_eval[0] - jac_y0_t0 = model.jac_rhs_algebraic_eval(t0, y0, inputs_dict) - if sparse.issparse(jac_y0_t0): - - def jacfn(t, y, inputs, cj): - j = ( - model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)) - - cj * mass_matrix - ) - return j - - else: - - def jacfn(t, y, inputs, cj): - jac_eval = ( - model.jac_rhs_algebraic_eval(t, y, inputs_to_dict(inputs)) - - cj * mass_matrix - ) - return sparse.csr_matrix(jac_eval) - - class SundialsJacobian: - def __init__(self): - self.J = None - - random = np.random.random(size=y0.size) - J = jacfn(10, random, inputs, 20) - self.nnz = J.nnz # hoping nnz remains constant... - - def jac_res(self, t, y, inputs, cj): - # must be of form j_res = (dr/dy) - (cj) (dr/dy') - # cj is just the input parameter - # see p68 of the ida_guide.pdf for more details - self.J = jacfn(t, y, inputs, cj) - - def get_jac_data(self): - return self.J.data - - def get_jac_row_vals(self): - return self.J.indices - - def get_jac_col_ptrs(self): - return self.J.indptr - - jac_class = SundialsJacobian() - num_of_events = len(model.terminate_events_eval) # rootfn needs to return an array of length num_of_events @@ -446,15 +410,6 @@ def get_jac_col_ptrs(self): ) ], ) - elif self._options["jax_evaluator"] == "jax": - - def rootfn(t, y, inputs): - new_inputs = inputs_to_dict(inputs) - return_root = np.array( - [event(t, y, new_inputs) for event in model.terminate_events_eval] - ).reshape(-1) - - return return_root # get ids of rhs and algebraic variables if model.convert_to_format == "casadi": @@ -481,301 +436,242 @@ def rootfn(t, y, inputs): else: sensfn = model.jacp_rhs_algebraic_eval - else: - # for the python solver we give it the full sensitivity equations - # required by IDAS - def sensfn(resvalS, t, y, inputs, yp, yS, ypS): - """ - this function evaluates the sensitivity equations required by IDAS, - returning them in resvalS, which is preallocated as a numpy array of - size (np, n), where n is the number of states and np is the number of - parameters - - The equations returned are: - - dF/dy * s_i + dF/dyd * sd_i + dFdp_i for i in range(np) - - Parameters - ---------- - resvalS: ndarray of shape (np, n) - returns the sensitivity equations in this preallocated array - t: number - time value - y: ndarray of shape (n) - current state vector - yp: list (np) of ndarray of shape (n) - current time derivative of state vector - yS: list (np) of ndarray of shape (n) - current state vector of sensitivity equations - ypS: list (np) of ndarray of shape (n) - current time derivative of state vector of sensitivity equations - - """ - - new_inputs = inputs_to_dict(inputs) - dFdy = model.jac_rhs_algebraic_eval(t, y, new_inputs) - dFdyd = mass_matrix - dFdp = model.jacp_rhs_algebraic_eval(t, y, new_inputs) - - for i, dFdp_i in enumerate(dFdp.values()): - resvalS[i][:] = dFdy @ yS[i] - dFdyd @ ypS[i] + dFdp_i - atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, y0.size) rtol = self.rtol - if model.convert_to_format == "casadi" or ( + if model.convert_to_format == "casadi": + # Serialize casadi functions + idaklu_solver_fcn = idaklu.create_casadi_solver + rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) + jac_times_cjmass = idaklu.generate_function(jac_times_cjmass.serialize()) + jac_rhs_algebraic_action = idaklu.generate_function( + jac_rhs_algebraic_action.serialize() + ) + rootfn = idaklu.generate_function(rootfn.serialize()) + mass_action = idaklu.generate_function(mass_action.serialize()) + sensfn = idaklu.generate_function(sensfn.serialize()) + elif ( model.convert_to_format == "jax" and self._options["jax_evaluator"] == "iree" ): - if model.convert_to_format == "casadi": - # Serialize casadi functions - idaklu_solver_fcn = idaklu.create_casadi_solver - rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) - jac_times_cjmass = idaklu.generate_function( - jac_times_cjmass.serialize() + # Convert Jax functions to MLIR (also, demote to single precision) + idaklu_solver_fcn = idaklu.create_iree_solver + pybamm.demote_expressions_to_32bit = True + if pybamm.demote_expressions_to_32bit: + warnings.warn( + "Demoting expressions to 32-bit for MLIR conversion", + stacklevel=2, ) - jac_rhs_algebraic_action = idaklu.generate_function( - jac_rhs_algebraic_action.serialize() + jnpfloat = jnp.float32 + else: # pragma: no cover + jnpfloat = jnp.float64 + raise pybamm.SolverError( + "Demoting expressions to 32-bit is required for MLIR conversion" + " at this time" ) - rootfn = idaklu.generate_function(rootfn.serialize()) - mass_action = idaklu.generate_function(mass_action.serialize()) - sensfn = idaklu.generate_function(sensfn.serialize()) - elif ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - # Convert Jax functions to MLIR (also, demote to single precision) - idaklu_solver_fcn = idaklu.create_iree_solver - pybamm.demote_expressions_to_32bit = True - if pybamm.demote_expressions_to_32bit: - warnings.warn( - "Demoting expressions to 32-bit for MLIR conversion", - stacklevel=2, - ) - jnpfloat = jnp.float32 - else: # pragma: no cover - jnpfloat = jnp.float64 - raise pybamm.SolverError( - "Demoting expressions to 32-bit is required for MLIR conversion" - " at this time" - ) - # input arguments (used for lowering) - t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) - y0 = self._demote_64_to_32(model.y0) - inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) - cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array - v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) - mass_matrix = model.mass_matrix.entries.toarray() - mass_matrix_demoted = self._demote_64_to_32(mass_matrix) - - # rhs_algebraic - rhs_algebraic_demoted = model.rhs_algebraic_eval - rhs_algebraic_demoted._demote_constants() - - def fcn_rhs_algebraic(t, y, inputs): - # function wraps an expression tree (and names MLIR module) - return rhs_algebraic_demoted(t, y, inputs) - - rhs_algebraic = self._make_iree_function( - fcn_rhs_algebraic, t_eval, y0, inputs0 - ) + # input arguments (used for lowering) + t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) + y0 = self._demote_64_to_32(model.y0) + inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) + cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array + v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) + mass_matrix = model.mass_matrix.entries.toarray() + mass_matrix_demoted = self._demote_64_to_32(mass_matrix) - # jac_times_cjmass - jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() + # rhs_algebraic + rhs_algebraic_demoted = model.rhs_algebraic_eval + rhs_algebraic_demoted._demote_constants() - def fcn_jac_times_cjmass(t, y, p, cj): - return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted + def fcn_rhs_algebraic(t, y, inputs): + # function wraps an expression tree (and names MLIR module) + return rhs_algebraic_demoted(t, y, inputs) - sparse_eval = sparse.csc_matrix( - fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) - ) - jac_times_cjmass_nnz = sparse_eval.nnz - jac_times_cjmass_colptrs = sparse_eval.indptr - jac_times_cjmass_rowvals = sparse_eval.indices - jac_bw_lower, jac_bw_upper = bandwidth( - sparse_eval.todense() - ) # potentially slow - if jac_bw_upper <= 1: - jac_bw_upper = jac_bw_lower - 1 - if jac_bw_lower <= 1: - jac_bw_lower = jac_bw_upper + 1 - coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing - - def fcn_jac_times_cjmass_sparse(t, y, p, cj): - return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] - - jac_times_cjmass = self._make_iree_function( - fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj - ) + rhs_algebraic = self._make_iree_function( + fcn_rhs_algebraic, t_eval, y0, inputs0 + ) - # Mass action - def fcn_mass_action(v): - return mass_matrix_demoted @ v + # jac_times_cjmass + jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() - mass_action_demoted = self._demote_64_to_32(fcn_mass_action) - mass_action = self._make_iree_function(mass_action_demoted, v0) + def fcn_jac_times_cjmass(t, y, p, cj): + return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted - # rootfn - for ix, _ in enumerate(model.terminate_events_eval): - model.terminate_events_eval[ix]._demote_constants() + sparse_eval = sparse.csc_matrix( + fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) + ) + jac_times_cjmass_nnz = sparse_eval.nnz + jac_times_cjmass_colptrs = sparse_eval.indptr + jac_times_cjmass_rowvals = sparse_eval.indices + jac_bw_lower, jac_bw_upper = bandwidth( + sparse_eval.todense() + ) # potentially slow + if jac_bw_upper <= 1: + jac_bw_upper = jac_bw_lower - 1 + if jac_bw_lower <= 1: + jac_bw_lower = jac_bw_upper + 1 + coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing + + def fcn_jac_times_cjmass_sparse(t, y, p, cj): + return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] + + jac_times_cjmass = self._make_iree_function( + fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj + ) - def fcn_rootfn(t, y, inputs): - return jnp.array( - [event(t, y, inputs) for event in model.terminate_events_eval], - dtype=jnpfloat, - ).reshape(-1) + # Mass action + def fcn_mass_action(v): + return mass_matrix_demoted @ v - def fcn_rootfn_demoted(t, y, inputs): - return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) + mass_action_demoted = self._demote_64_to_32(fcn_mass_action) + mass_action = self._make_iree_function(mass_action_demoted, v0) - rootfn = self._make_iree_function( - fcn_rootfn_demoted, t_eval, y0, inputs0 - ) + # rootfn + for ix, _ in enumerate(model.terminate_events_eval): + model.terminate_events_eval[ix]._demote_constants() - # jac_rhs_algebraic_action - jac_rhs_algebraic_action_demoted = ( - rhs_algebraic_demoted.get_jacobian_action() - ) + def fcn_rootfn(t, y, inputs): + return jnp.array( + [event(t, y, inputs) for event in model.terminate_events_eval], + dtype=jnpfloat, + ).reshape(-1) - def fcn_jac_rhs_algebraic_action( - t, y, p, v - ): # sundials calls (t, y, inputs, v) - return jac_rhs_algebraic_action_demoted( - t, y, v, p - ) # jvp calls (t, y, v, inputs) + def fcn_rootfn_demoted(t, y, inputs): + return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) - jac_rhs_algebraic_action = self._make_iree_function( - fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 - ) + rootfn = self._make_iree_function(fcn_rootfn_demoted, t_eval, y0, inputs0) - # sensfn - if model.jacp_rhs_algebraic_eval is None: - sensfn = idaklu.IREEBaseFunctionType() # empty equation - else: - sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() + # jac_rhs_algebraic_action + jac_rhs_algebraic_action_demoted = ( + rhs_algebraic_demoted.get_jacobian_action() + ) - def fcn_sensfn(t, y, p): - return sensfn_demoted(t, y, p) + def fcn_jac_rhs_algebraic_action( + t, y, p, v + ): # sundials calls (t, y, inputs, v) + return jac_rhs_algebraic_action_demoted( + t, y, v, p + ) # jvp calls (t, y, v, inputs) - sensfn = self._make_iree_function( - fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 - ) + jac_rhs_algebraic_action = self._make_iree_function( + fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 + ) - # output_variables - self.var_idaklu_fcns = [] - self.dvar_dy_idaklu_fcns = [] - self.dvar_dp_idaklu_fcns = [] - for key in self.output_variables: - fcn = self.computed_var_fcns[key] - fcn._demote_constants() - self.var_idaklu_fcns.append( + # sensfn + if model.jacp_rhs_algebraic_eval is None: + sensfn = idaklu.IREEBaseFunctionType() # empty equation + else: + sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() + + def fcn_sensfn(t, y, p): + return sensfn_demoted(t, y, p) + + sensfn = self._make_iree_function( + fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 + ) + + # output_variables + self.var_idaklu_fcns = [] + self.dvar_dy_idaklu_fcns = [] + self.dvar_dp_idaklu_fcns = [] + for key in self.output_variables: + fcn = self.computed_var_fcns[key] + fcn._demote_constants() + self.var_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: fcn(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, + ) + ) + # Convert derivative functions for sensitivities + if (len(inputs) > 0) and (model.calculate_sensitivities): + dvar_dy = fcn.get_jacobian() + self.dvar_dy_idaklu_fcns.append( self._make_iree_function( - lambda t, y, p: fcn(t, y, p), # noqa: B023 + lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 t_eval, y0, inputs0, + sparse_index=True, ) ) - # Convert derivative functions for sensitivities - if (len(inputs) > 0) and (model.calculate_sensitivities): - dvar_dy = fcn.get_jacobian() - self.dvar_dy_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - sparse_index=True, - ) - ) - dvar_dp = fcn.get_sensitivities() - self.dvar_dp_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - ) + dvar_dp = fcn.get_sensitivities() + self.dvar_dp_idaklu_fcns.append( + self._make_iree_function( + lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 + t_eval, + y0, + inputs0, ) + ) - # Identify IREE library - iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") - os.environ["IREE_COMPILER_LIB"] = os.path.join( - iree_lib_path, - next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), - ) + # Identify IREE library + iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") + os.environ["IREE_COMPILER_LIB"] = os.path.join( + iree_lib_path, + next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), + ) - pybamm.demote_expressions_to_32bit = False - else: # pragma: no cover - raise pybamm.SolverError( - "Unsupported evaluation engine for convert_to_format='jax'" - ) + pybamm.demote_expressions_to_32bit = False + else: # pragma: no cover + raise pybamm.SolverError( + "Unsupported evaluation engine for convert_to_format='jax'" + ) - self._setup = { - "solver_function": idaklu_solver_fcn, # callable - "jac_bandwidth_upper": jac_bw_upper, # int - "jac_bandwidth_lower": jac_bw_lower, # int - "rhs_algebraic": rhs_algebraic, # function - "jac_times_cjmass": jac_times_cjmass, # function - "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, # array - "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, # array - "jac_times_cjmass_nnz": jac_times_cjmass_nnz, # int - "jac_rhs_algebraic_action": jac_rhs_algebraic_action, # function - "mass_action": mass_action, # function - "sensfn": sensfn, # function - "rootfn": rootfn, # function - "num_of_events": num_of_events, # int - "ids": ids, # array - "sensitivity_names": sensitivity_names, - "number_of_sensitivity_parameters": number_of_sensitivity_parameters, - "output_variables": self.output_variables, - "var_fcns": self.computed_var_fcns, - "var_idaklu_fcns": self.var_idaklu_fcns, - "dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns, - "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns, - } + self._setup = { + "solver_function": idaklu_solver_fcn, # callable + "jac_bandwidth_upper": jac_bw_upper, # int + "jac_bandwidth_lower": jac_bw_lower, # int + "rhs_algebraic": rhs_algebraic, # function + "jac_times_cjmass": jac_times_cjmass, # function + "jac_times_cjmass_colptrs": jac_times_cjmass_colptrs, # array + "jac_times_cjmass_rowvals": jac_times_cjmass_rowvals, # array + "jac_times_cjmass_nnz": jac_times_cjmass_nnz, # int + "jac_rhs_algebraic_action": jac_rhs_algebraic_action, # function + "mass_action": mass_action, # function + "sensfn": sensfn, # function + "rootfn": rootfn, # function + "num_of_events": num_of_events, # int + "ids": ids, # array + "sensitivity_names": sensitivity_names, + "number_of_sensitivity_parameters": number_of_sensitivity_parameters, + "output_variables": self.output_variables, + "var_fcns": self.computed_var_fcns, + "var_idaklu_fcns": self.var_idaklu_fcns, + "dvar_dy_idaklu_fcns": self.dvar_dy_idaklu_fcns, + "dvar_dp_idaklu_fcns": self.dvar_dp_idaklu_fcns, + } - solver = self._setup["solver_function"]( - number_of_states=len(y0), - number_of_parameters=self._setup["number_of_sensitivity_parameters"], - rhs_alg=self._setup["rhs_algebraic"], - jac_times_cjmass=self._setup["jac_times_cjmass"], - jac_times_cjmass_colptrs=self._setup["jac_times_cjmass_colptrs"], - jac_times_cjmass_rowvals=self._setup["jac_times_cjmass_rowvals"], - jac_times_cjmass_nnz=self._setup["jac_times_cjmass_nnz"], - jac_bandwidth_lower=jac_bw_lower, - jac_bandwidth_upper=jac_bw_upper, - jac_action=self._setup["jac_rhs_algebraic_action"], - mass_action=self._setup["mass_action"], - sens=self._setup["sensfn"], - events=self._setup["rootfn"], - number_of_events=self._setup["num_of_events"], - rhs_alg_id=self._setup["ids"], - atol=atol, - rtol=rtol, - inputs=len(inputs), - var_fcns=self._setup["var_idaklu_fcns"], - dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], - dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], - options=self._options, - ) + solver = self._setup["solver_function"]( + number_of_states=len(y0), + number_of_parameters=self._setup["number_of_sensitivity_parameters"], + rhs_alg=self._setup["rhs_algebraic"], + jac_times_cjmass=self._setup["jac_times_cjmass"], + jac_times_cjmass_colptrs=self._setup["jac_times_cjmass_colptrs"], + jac_times_cjmass_rowvals=self._setup["jac_times_cjmass_rowvals"], + jac_times_cjmass_nnz=self._setup["jac_times_cjmass_nnz"], + jac_bandwidth_lower=jac_bw_lower, + jac_bandwidth_upper=jac_bw_upper, + jac_action=self._setup["jac_rhs_algebraic_action"], + mass_action=self._setup["mass_action"], + sens=self._setup["sensfn"], + events=self._setup["rootfn"], + number_of_events=self._setup["num_of_events"], + rhs_alg_id=self._setup["ids"], + atol=atol, + rtol=rtol, + inputs=len(inputs), + var_fcns=self._setup["var_idaklu_fcns"], + dvar_dy_fcns=self._setup["dvar_dy_idaklu_fcns"], + dvar_dp_fcns=self._setup["dvar_dp_idaklu_fcns"], + options=self._options, + ) - self._setup["solver"] = solver - else: - self._setup = { - "resfn": resfn, - "jac_class": jac_class, - "sensfn": sensfn, - "rootfn": rootfn, - "num_of_events": num_of_events, - "use_jac": 1, - "ids": ids, - "sensitivity_names": sensitivity_names, - "number_of_sensitivity_parameters": number_of_sensitivity_parameters, - } + self._setup["solver"] = solver return base_set_up_return @@ -859,7 +755,6 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, y0full.size) - rtol = self.rtol timer = pybamm.Timer() if model.convert_to_format == "casadi" or ( @@ -873,27 +768,9 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): ydot0full, inputs, ) - else: - sol = idaklu.solve_python( - t_eval, - y0full, - ydot0full, - self._setup["resfn"], - self._setup["jac_class"].jac_res, - self._setup["sensfn"], - self._setup["jac_class"].get_jac_data, - self._setup["jac_class"].get_jac_row_vals, - self._setup["jac_class"].get_jac_col_ptrs, - self._setup["jac_class"].nnz, - self._setup["rootfn"], - self._setup["num_of_events"], - self._setup["use_jac"], - self._setup["ids"], - atol, - rtol, - inputs, - self._setup["number_of_sensitivity_parameters"], - ) + else: # pragma: no cover + # Shouldn't ever reach this point + raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") integration_time = timer.time() number_of_sensitivity_parameters = self._setup[ diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index a35b864a64..e86b0f702e 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -114,7 +114,7 @@ def test_block_symbolic_inputs(self): ): solver.solve(model, np.array([1, 2, 3])) - def testode_solver_fail_with_dae(self): + def test_ode_solver_fail_with_dae(self): model = pybamm.BaseModel() a = pybamm.Scalar(1) model.algebraic = {a: a} @@ -364,49 +364,41 @@ def exact_diff_a(y, a, b): def exact_diff_b(y, a, b): return np.array([[y[0]], [0]]) - for convert_to_format in ["", "python", "casadi", "jax"]: - model = pybamm.BaseModel() - v = pybamm.Variable("v") - u = pybamm.Variable("u") - a = pybamm.InputParameter("a") - b = pybamm.InputParameter("b") - model.rhs = {v: a * v**2 + b * v + a**2} - model.algebraic = {u: a * v - u} - model.initial_conditions = {v: 1, u: a * 1} - model.convert_to_format = convert_to_format - solver = pybamm.IDAKLUSolver(root_method="lm") - model.calculate_sensitivities = ["a", "b"] - solver.set_up(model, inputs={"a": 0, "b": 0}) - all_inputs = [] - for v_value in [0.1, -0.2, 1.5, 8.4]: - for u_value in [0.13, -0.23, 1.3, 13.4]: - for a_value in [0.12, 1.5]: - for b_value in [0.82, 1.9]: - y = np.array([v_value, u_value]) - t = 0 - inputs = {"a": a_value, "b": b_value} - all_inputs.append((t, y, inputs)) - for t, y, inputs in all_inputs: - if model.convert_to_format == "casadi": - use_inputs = casadi.vertcat(*[x for x in inputs.values()]) - else: - use_inputs = inputs - - sens = model.jacp_rhs_algebraic_eval(t, y, use_inputs) - - if convert_to_format == "casadi": - sens_a = sens[0] - sens_b = sens[1] - else: - sens_a = sens["a"] - sens_b = sens["b"] - - np.testing.assert_allclose( - sens_a, exact_diff_a(y, inputs["a"], inputs["b"]) - ) - np.testing.assert_allclose( - sens_b, exact_diff_b(y, inputs["a"], inputs["b"]) - ) + model = pybamm.BaseModel() + v = pybamm.Variable("v") + u = pybamm.Variable("u") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b") + model.rhs = {v: a * v**2 + b * v + a**2} + model.algebraic = {u: a * v - u} + model.initial_conditions = {v: 1, u: a * 1} + model.convert_to_format = "casadi" + solver = pybamm.IDAKLUSolver(root_method="lm") + model.calculate_sensitivities = ["a", "b"] + solver.set_up(model, inputs={"a": 0, "b": 0}) + all_inputs = [] + for v_value in [0.1, -0.2, 1.5, 8.4]: + for u_value in [0.13, -0.23, 1.3, 13.4]: + for a_value in [0.12, 1.5]: + for b_value in [0.82, 1.9]: + y = np.array([v_value, u_value]) + t = 0 + inputs = {"a": a_value, "b": b_value} + all_inputs.append((t, y, inputs)) + for t, y, inputs in all_inputs: + use_inputs = casadi.vertcat(*[x for x in inputs.values()]) + + sens = model.jacp_rhs_algebraic_eval(t, y, use_inputs) + + sens_a = sens[0] + sens_b = sens[1] + + np.testing.assert_allclose( + sens_a, exact_diff_a(y, inputs["a"], inputs["b"]) + ) + np.testing.assert_allclose( + sens_b, exact_diff_b(y, inputs["a"], inputs["b"]) + ) if __name__ == "__main__": diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 5bc845d66c..67e68e9c6a 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -19,7 +19,7 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -65,7 +65,7 @@ def test_ida_roberts_klu(self): np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) def test_model_events(self): - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -195,7 +195,7 @@ def test_model_events(self): def test_input_params(self): # test a mix of scalar and vector input params - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -306,7 +306,7 @@ def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -414,7 +414,7 @@ def test_ida_roberts_consistent_initialization(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -458,7 +458,7 @@ def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "python", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -619,7 +619,7 @@ def test_failures(self): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): - for form in ["python", "casadi", "jax", "iree"]: + for form in ["casadi", "iree"]: if (form == "jax" or form == "iree") and not pybamm.have_jax(): continue if (form == "iree") and not pybamm.have_iree(): @@ -1097,6 +1097,46 @@ def test_interpolate_time_step_start_offset(self): sol.sub_solutions[1].t[0], ) + def test_python_idaklu_deprecation_errors(self): + for form in ["python", "", "jax"]: + if form == "jax" and not pybamm.have_jax(): + continue + + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: 0.1 * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.events = [pybamm.Event("1", 0.2 - u), pybamm.Event("2", v)] + + disc = pybamm.Discretisation() + disc.process_model(model) + + t_eval = np.linspace(0, 3, 100) + + solver = pybamm.IDAKLUSolver( + root_method="lm", + ) + + if form == "python": + with self.assertRaisesRegex( + pybamm.SolverError, + "Unsupported option for convert_to_format=python", + ): + with self.assertWarnsRegex( + DeprecationWarning, + "The python-idaklu solver has been deprecated.", + ): + _ = solver.solve(model, t_eval) + elif form == "jax": + with self.assertRaisesRegex( + pybamm.SolverError, + "Unsupported evaluation engine for convert_to_format=jax", + ): + _ = solver.solve(model, t_eval) + if __name__ == "__main__": print("Add -v for more debug output")