Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance refactor for Jax BDF Solver, fixes #4455 #4456

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

## Optimizations

- Performance refactor of JAX BDF Solver with default Jax method set to `"BDF"`. ([#4456](https://github.com/pybamm-team/PyBaMM/pull/4456))
- Improved performance of initialization and reinitialization of ODEs in the (`IDAKLUSolver`). ([#4453](https://github.com/pybamm-team/PyBaMM/pull/4453))
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))

Expand Down
57 changes: 57 additions & 0 deletions examples/scripts/multiprocess_jax_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pybamm
import time
import numpy as np


# This script provides an example for massively vectorised
# model solves using the JAX BDF solver. First,
# we set up the model and process parameters
model = pybamm.lithium_ion.SPM()
model.convert_to_format = "jax"
model.events = [] # remove events (not supported in jax)
geometry = model.default_geometry
param = pybamm.ParameterValues("Chen2020")
param.update({"Current function [A]": "[input]"})
param.process_geometry(geometry)
param.process_model(model)

# Discretise and setup solver
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF")

# Set number of vectorised solves
values = np.linspace(0.01, 1.0, 1000)
inputs = [{"Current function [A]": value} for value in values]

# Run solve for all inputs, with a just-in-time compilation
# occurring on the first solve. All sequential solves will
# use the compiled code, with a large performance improvement.
start_time = time.time()
sol = solver.solve(model, t_eval, inputs=inputs)
print(f"Time taken: {time.time() - start_time}") # 1.3s

# Rerun the vectorised solve, showing performance improvement
start_time = time.time()
compiled_sol = solver.solve(model, t_eval, inputs=inputs)
print(f"Compiled time taken: {time.time() - start_time}") # 0.42s

# Plot one of the solves
plot = pybamm.QuickPlot(
sol[5],
[
"Negative particle concentration [mol.m-3]",
"Electrolyte concentration [mol.m-3]",
"Positive particle concentration [mol.m-3]",
"Current [A]",
"Negative electrode potential [V]",
"Electrolyte potential [V]",
"Positive electrode potential [V]",
"Voltage [V]",
],
time_unit="seconds",
spatial_unit="um",
)
plot.dynamic_plot()
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def run_scripts(session):
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
session.install("setuptools", silent=False)
session.install("-e", ".[all,dev]", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("python", "-m", "pytest", "-m", "scripts")


Expand Down
12 changes: 6 additions & 6 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def root_method(self):
def supports_parallel_solve(self):
return False

@property
def requires_explicit_sensitivities(self):
return True

@root_method.setter
def root_method(self, method):
if method == "casadi":
Expand Down Expand Up @@ -141,7 +145,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):

# see if we need to form the explicit sensitivity equations
calculate_sensitivities_explicit = (
model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver)
model.calculate_sensitivities and self.requires_explicit_sensitivities
)

self._set_up_model_sensitivities_inplace(
Expand Down Expand Up @@ -494,11 +498,7 @@ def _set_up_model_sensitivities_inplace(
# if we have a mass matrix, we need to extend it
def extend_mass_matrix(M):
M_extend = [M.entries] * (num_parameters + 1)
M_extend_pybamm = pybamm.Matrix(block_diag(M_extend, format="csr"))
return M_extend_pybamm

model.mass_matrix = extend_mass_matrix(model.mass_matrix)
model.mass_matrix = extend_mass_matrix(model.mass_matrix)
return pybamm.Matrix(block_diag(M_extend, format="csr"))

model.mass_matrix = extend_mass_matrix(model.mass_matrix)

Expand Down
4 changes: 4 additions & 0 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,10 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
def supports_parallel_solve(self):
return True

@property
def requires_explicit_sensitivities(self):
return False

def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
"""
Solve a DAE model defined by residuals with initial conditions y0.
Expand Down
Loading
Loading