Skip to content

Commit

Permalink
Merge pull request #2348 from abillscmu/issue-2240-explicit
Browse files Browse the repository at this point in the history
Explicitly integrate variables which appear nowhere else on the rhs
  • Loading branch information
valentinsulzer authored Oct 12, 2022
2 parents b359850 + 87c0943 commit 7e35372
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 47 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211))

- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210))
- Added `ExplicitTimeIntegral` functionality to move variables which do not appear anywhere on the rhs to a new location, and to integrate those variables explicitly when `get` is called by the solution object. ([#2348](https://github.com/pybamm-team/PyBaMM/pull/2348))

## Breaking change

- Removed parameter cli tools (add/edit/remove parameters). Parameter sets can now more easily be added via python scripts. ([#2342](https://github.com/pybamm-team/PyBaMM/pull/2342))
Expand Down
10 changes: 9 additions & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@
# Utility classes and methods
#
from .util import Timer, TimerTime, FuzzyDict
from .util import root_dir, load_function, rmse, get_infinite_nested_dict, load
from .util import (
root_dir,
load_function,
rmse,
get_infinite_nested_dict,
load,
is_constant_and_can_evaluate,
tree_search,
)
from .util import (
get_parameters_filepath,
have_jax,
Expand Down
77 changes: 74 additions & 3 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ def bcs(self, value):
# reset discretised_symbols
self._discretised_symbols = {}

def process_model(self, model, inplace=True, check_model=True):
def process_model(
self,
model,
inplace=True,
check_model=True,
check_for_independent_variables=True,
):
"""Discretise a model.
Currently inplace, could be changed to return a new model.
Expand All @@ -109,9 +115,15 @@ def process_model(self, model, inplace=True, check_model=True):
check_model : bool, optional
If True, model checks are performed after discretisation. For large
systems these checks can be slow, so can be skipped by setting this
option to False. When developing, testing or debugging it is recommened
option to False. When developing, testing or debugging it is recommended
to leave this option as True as it may help to identify any errors.
Default is True.
check_for_independent_variables : bool, optional
If True, model checks to see whether any variables from the RHS are used
in any other equation. If a variable meets all of the following criteria
(not used anywhere in the model, len(rhs)>1), then the variable
is moved to be explicitly integrated when called by the solution object.
Default is True.
Returns
-------
Expand Down Expand Up @@ -148,7 +160,12 @@ def process_model(self, model, inplace=True, check_model=True):

# Prepare discretisation
# set variables (we require the full variable not just id)

# Search Equations for Independence
if check_for_independent_variables:
model = self.check_for_independent_variables(model)
variables = list(model.rhs.keys()) + list(model.algebraic.keys())
# Find those RHS's that are constant
if self.spatial_methods == {} and any(var.domain != [] for var in variables):
for var in variables:
if var.domain != []:
Expand Down Expand Up @@ -1005,7 +1022,6 @@ def _process_symbol(self, symbol):
out = pybamm.Index(ext, slice(start, end))
out.copy_domains(symbol)
return out

else:
# add a try except block for a more informative error if a variable
# can't be found. This should usually be caught earlier by
Expand Down Expand Up @@ -1197,3 +1213,58 @@ def check_variables(self, model):
var.shape, model.rhs[rhs_var].shape, var
)
)

def search_for_independent_var(self, model, var):
pybamm.logger.verbose("Removing independent blocks.")
boundary_variables = list(model.boundary_conditions.keys())
boundary_variable_keys = []
for condition in boundary_variables:
keys_for_condition = list(model.boundary_conditions[condition].keys())
boundary_variable_keys.append(keys_for_condition)
rhs_variables = list(model.rhs.keys())
algebraic_variables = list(model.algebraic.keys())
this_var_list = []
if not isinstance(var, pybamm.Variable):
return model, False
for tree in rhs_variables:
pybamm.tree_search(model.rhs[tree], var, this_var_list)
for tree in algebraic_variables:
pybamm.tree_search(model.algebraic[tree], var, this_var_list)
for (keys, tree) in zip(boundary_variable_keys, boundary_variables):
for key in keys:
pybamm.tree_search(
model.boundary_conditions[tree][key][0], var, this_var_list
)
for name in model.variables.keys():
for rhs_child in model.variables[name].children:
pybamm.tree_search(rhs_child, var, this_var_list)
this_var_is_independent = not any(this_var_list)
not_in_y_slices = not (var in list(self.y_slices.keys()))
not_in_discretised = not (var in list(self._discretised_symbols.keys()))
is_0D = len(var.domain) == 0
this_var_is_independent = (
this_var_is_independent and not_in_y_slices and not_in_discretised and is_0D
)
return model, this_var_is_independent

def check_for_independent_variables(self, model):
rhs_vars_to_search_over = list(model.rhs.keys())
for var in rhs_vars_to_search_over:
model, this_var_is_independent = self.search_for_independent_var(model, var)
if this_var_is_independent:
if len(model.rhs) != 1:
pybamm.logger.info("removing variable {} from rhs".format(var))
my_initial_condition = model.initial_conditions[var]
model.variables[var.name] = pybamm.ExplicitTimeIntegral(
model.rhs[var], my_initial_condition
)
# edge case where a variable appears
# in variables twice under different names
for key in model.variables:
if model.variables[key] == var:
model.variables[key] = model.variables[var.name]
del model.rhs[var]
del model.initial_conditions[var]
else:
break
return model
17 changes: 1 addition & 16 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import scipy.sparse
from collections import OrderedDict
from pybamm.util import is_constant_and_can_evaluate

import numbers

Expand All @@ -20,22 +21,6 @@ def id_to_julia_variable(symbol_id, prefix):
return var_format.format(symbol_id).replace("-", "m")


def is_constant_and_can_evaluate(symbol):
"""
Returns True if symbol is constant and evaluation does not raise any errors.
Returns False otherwise.
An example of a constant symbol that cannot be "evaluated" is PrimaryBroadcast(0).
"""
if symbol.is_constant():
try:
symbol.evaluate()
return True
except NotImplementedError:
return False
else:
return False


def find_symbols(
symbol,
constant_symbols,
Expand Down
10 changes: 9 additions & 1 deletion pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def create_copy(self):

def _unary_new_copy(self, child):
"""Make a new copy of the unary operator, with child `child`"""

return self.__class__(child)

def _unary_jac(self, child_jac):
Expand Down Expand Up @@ -958,6 +957,15 @@ def _sympy_operator(self, child):
return sympy.Symbol(latex_child)


class ExplicitTimeIntegral(UnaryOperator):
def __init__(self, children, initial_condition):
super().__init__("explicit time integral", children)
self.initial_condition = initial_condition

def _unary_new_copy(self, child):
return self.__class__(child, self.initial_condition)


class BoundaryGradient(BoundaryOperator):
"""
A node in the expression tree which gets the boundary flux of a variable.
Expand Down
27 changes: 25 additions & 2 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@ class ProcessedVariable(object):
Default is True.
"""

def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
def __init__(
self,
base_variables,
base_variables_casadi,
solution,
warn=True,
cumtrapz_ic=None,
):
self.base_variables = base_variables
self.base_variables_casadi = base_variables_casadi

Expand All @@ -46,6 +53,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
self.domain = base_variables[0].domain
self.domains = base_variables[0].domains
self.warn = warn
self.cumtrapz_ic = cumtrapz_ic

# Sensitivity starts off uninitialized, only set when called
self._sensitivities = None
Expand Down Expand Up @@ -106,15 +114,30 @@ def initialise_0D(self):
# initialise empty array of the correct size
entries = np.empty(len(self.t_pts))
idx = 0
last_t = 0
# Evaluate the base_variable index-by-index
for ts, ys, inputs, base_var_casadi in zip(
self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
if self.cumtrapz_ic is not None:
if idx == 0:
new_val = t * base_var_casadi(t, y, inputs).full()[0, 0]
entries[idx] = self.cumtrapz_ic + (
t * base_var_casadi(t, y, inputs).full()[0, 0]
)
else:
new_val = (t - last_t) * (
base_var_casadi(t, y, inputs).full()[0, 0]
)
entries[idx] = new_val + entries[idx - 1]
else:
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]

idx += 1
last_t = t

# set up interpolation
if len(self.t_pts) == 1:
Expand Down
46 changes: 26 additions & 20 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,42 +471,48 @@ def update(self, variables):
variables = [variables]
# Process
for key in variables:
cumtrapz_ic = None
pybamm.logger.debug("Post-processing {}".format(key))
vars_pybamm = [model.variables_and_events[key] for model in self.all_models]

# Iterate through all models, some may be in the list several times and
# therefore only get set up once
vars_casadi = []
for model, ys, inputs, var_pybamm in zip(
self.all_models, self.all_ys, self.all_inputs, vars_pybamm
for (i, (model, ys, inputs, var_pybamm)) in enumerate(
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
):
if key in model._variables_casadi:
if isinstance(var_pybamm, pybamm.ExplicitTimeIntegral):
cumtrapz_ic = var_pybamm.initial_condition
cumtrapz_ic = cumtrapz_ic.evaluate()
var_pybamm = var_pybamm.child
var_casadi = self.process_casadi_var(var_pybamm, inputs, ys)
model._variables_casadi[key] = var_casadi
vars_pybamm[i] = var_pybamm
elif key in model._variables_casadi:
var_casadi = model._variables_casadi[key]
else:
t_MX = casadi.MX.sym("t")
y_MX = casadi.MX.sym("y", ys.shape[0])
inputs_MX_dict = {
key: casadi.MX.sym("input", value.shape[0])
for key, value in inputs.items()
}
inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()])

# Convert variable to casadi
# Make all inputs symbolic first for converting to casadi
var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict)

var_casadi = casadi.Function(
"variable", [t_MX, y_MX, inputs_MX], [var_sym]
)
var_casadi = self.process_casadi_var(var_pybamm, inputs, ys)
model._variables_casadi[key] = var_casadi
vars_casadi.append(var_casadi)

var = pybamm.ProcessedVariable(vars_pybamm, vars_casadi, self)
var = pybamm.ProcessedVariable(
vars_pybamm, vars_casadi, self, cumtrapz_ic=cumtrapz_ic
)

# Save variable and data
self._variables[key] = var
self.data[key] = var.data

def process_casadi_var(self, var_pybamm, inputs, ys):
t_MX = casadi.MX.sym("t")
y_MX = casadi.MX.sym("y", ys.shape[0])
inputs_MX_dict = {
key: casadi.MX.sym("input", value.shape[0]) for key, value in inputs.items()
}
inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()])
var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict)
var_casadi = casadi.Function("variable", [t_MX, y_MX, inputs_MX], [var_sym])
return var_casadi

def __getitem__(self, key):
"""Read a variable from the solution. Variables are created 'just in time', i.e.
only when they are called.
Expand Down
27 changes: 27 additions & 0 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@
JAXLIB_VERSION = "0.1.70"


def tree_search(tree, item, solutions):
for child in tree.children:
tree_search(child, item, solutions)
if (child == item) or (child.name == item.name):
solutions.append(True)
else:
solutions.append(False)
solutions.append((tree == item) or (tree.name == item.name))
return None


def root_dir():
"""return the root directory of the PyBaMM install directory"""
return str(pathlib.Path(pybamm.__path__[0]).parent)
Expand Down Expand Up @@ -351,6 +362,22 @@ def is_jax_compatible():
)


def is_constant_and_can_evaluate(symbol):
"""
Returns True if symbol is constant and evaluation does not raise any errors.
Returns False otherwise.
An example of a constant symbol that cannot be "evaluated" is PrimaryBroadcast(0).
"""
if symbol.is_constant():
try:
symbol.evaluate()
return True
except NotImplementedError:
return False
else:
return False


def install_jax(arguments=None): # pragma: no cover
"""
Install compatible versions of jax, jaxlib.
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def test_solving(
self.solver.rtol = 1e-8
self.solver.atol = 1e-8

# Somehow removing an equation makes the solver fail at
# the low tolerances
if isinstance(self.model, pybamm.lithium_ion.NewmanTobias):
self.solver.rtol = 1e-7

Crate = abs(
self.parameter_values["Current function [A]"]
/ self.parameter_values["Nominal cell capacity [A.h]"]
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,21 @@ def test_length_scale_errors(self):
disc.process_model(model)
self.assertEqual(model.length_scales["negative electrode"], pybamm.Scalar(1))

def test_independent_rhs(self):
a = pybamm.Variable("a")
b = pybamm.Variable("b")
c = pybamm.Variable("c")
model = pybamm.BaseModel()
model.rhs = {a: b, b: c, c: -c}
model.initial_conditions = {
a: pybamm.Scalar(0),
b: pybamm.Scalar(1),
c: pybamm.Scalar(1),
}
disc = pybamm.Discretisation()
disc.process_model(model)
self.assertEqual(len(model.rhs), 2)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Loading

0 comments on commit 7e35372

Please sign in to comment.