Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly integrate variables which appear nowhere else on the rhs #2348

Merged
merged 43 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
661d10f
identify independent and constant variables
aabills Sep 26, 2022
cda6301
wip
aabills Sep 27, 2022
a8d268f
working on 2240
aabills Sep 27, 2022
c3f763a
wrap up 2240
aabills Sep 27, 2022
387c58e
cleanup
aabills Sep 27, 2022
93a75f6
cleanup
aabills Sep 27, 2022
3d2331b
more cleanup
aabills Sep 27, 2022
8105483
style
aabills Sep 28, 2022
9b1ef35
make experiments work
aabills Sep 28, 2022
872c6c2
new method
aabills Sep 29, 2022
ac5a9a7
found and fixed bug
aabills Sep 29, 2022
6ef84af
fixing tests
aabills Sep 29, 2022
94ad578
add catch to prevent empty models
aabills Sep 30, 2022
db28dc0
done
aabills Oct 5, 2022
22cd381
add test
aabills Oct 6, 2022
eecc39c
Update CHANGELOG.md
aabills Oct 6, 2022
fa4df63
flake8 on test
aabills Oct 6, 2022
ce37860
Merge branch 'issue-2240-explicit' of github.com:abillscmu/PyBaMM int…
aabills Oct 6, 2022
8fc63a1
Merge branch 'develop' into issue-2240-explicit
valentinsulzer Oct 6, 2022
c8a1508
style: pre-commit fixes
pre-commit-ci[bot] Oct 6, 2022
6af4e50
update solution
aabills Oct 6, 2022
2fcb289
merge develop and fix style
aabills Oct 6, 2022
2b66a93
style: pre-commit fixes
pre-commit-ci[bot] Oct 6, 2022
b20695c
fix new test
aabills Oct 6, 2022
55359ed
Merge branch 'develop' into issue-2240-explicit
aabills Oct 10, 2022
b32358c
style: pre-commit fixes
pre-commit-ci[bot] Oct 10, 2022
170bbe4
add option to bypass;fixes idaklutest
aabills Oct 10, 2022
2075777
style: pre-commit fixes
pre-commit-ci[bot] Oct 10, 2022
35eca5a
only check for variables
aabills Oct 10, 2022
94d54a0
rename tree to name
aabills Oct 10, 2022
b5695df
fix typo in comment
aabills Oct 10, 2022
9e52a86
Merge branch 'issue-2240-explicit' of github.com:abillscmu/PyBaMM int…
aabills Oct 10, 2022
5ea698d
lower tolerance for failing test
aabills Oct 11, 2022
75cc615
style: pre-commit fixes
pre-commit-ci[bot] Oct 11, 2022
9578898
style
aabills Oct 11, 2022
ca431c6
Merge branch 'issue-2240-explicit' of github.com:abillscmu/PyBaMM int…
aabills Oct 11, 2022
4a873ff
add comment to force workflow
aabills Oct 11, 2022
34e5fdf
style: pre-commit fixes
pre-commit-ci[bot] Oct 11, 2022
a38a572
add two tests :/
aabills Oct 12, 2022
cf3e239
Merge branch 'develop' into issue-2240-explicit
aabills Oct 12, 2022
90b72c0
style: pre-commit fixes
pre-commit-ci[bot] Oct 12, 2022
c59e228
typo
aabills Oct 12, 2022
87c0943
style: pre-commit fixes
pre-commit-ci[bot] Oct 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

- Added more rules for simplifying expressions ([#2211](https://github.com/pybamm-team/PyBaMM/pull/2211))

- Sped up calculations of Electrode SOH variables for summary variables ([#2210](https://github.com/pybamm-team/PyBaMM/pull/2210))
- Added `ExplicitTimeIntegral` functionality to move variables which do not appear anywhere on the rhs to a new location, and to integrate those variables explicitly when `get` is called by the solution object. ([#2348](https://github.com/pybamm-team/PyBaMM/pull/2348))

## Breaking change

- Removed parameter cli tools (add/edit/remove parameters). Parameter sets can now more easily be added via python scripts. ([#2342](https://github.com/pybamm-team/PyBaMM/pull/2342))
Expand Down
10 changes: 9 additions & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@
# Utility classes and methods
#
from .util import Timer, TimerTime, FuzzyDict
from .util import root_dir, load_function, rmse, get_infinite_nested_dict, load
from .util import (
root_dir,
load_function,
rmse,
get_infinite_nested_dict,
load,
is_constant_and_can_evaluate,
tree_search,
)
from .util import (
get_parameters_filepath,
have_jax,
Expand Down
77 changes: 74 additions & 3 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ def bcs(self, value):
# reset discretised_symbols
self._discretised_symbols = {}

def process_model(self, model, inplace=True, check_model=True):
def process_model(
self,
model,
inplace=True,
check_model=True,
check_for_independent_variables=True,
):
"""Discretise a model.
Currently inplace, could be changed to return a new model.

Expand All @@ -109,9 +115,15 @@ def process_model(self, model, inplace=True, check_model=True):
check_model : bool, optional
If True, model checks are performed after discretisation. For large
systems these checks can be slow, so can be skipped by setting this
option to False. When developing, testing or debugging it is recommened
option to False. When developing, testing or debugging it is recommended
to leave this option as True as it may help to identify any errors.
Default is True.
check_for_independent_variables : bool, optional
If True, model checks to see whether any variables from the RHS are used
in any other equation. If a variable meets all of the following criteria
(not used anywhere in the model, len(rhs)>1), then the variable
is moved to be explicitly integrated when called by the solution object.
Default is True.

Returns
-------
Expand Down Expand Up @@ -148,7 +160,12 @@ def process_model(self, model, inplace=True, check_model=True):

# Prepare discretisation
# set variables (we require the full variable not just id)

# Search Equations for Independence
aabills marked this conversation as resolved.
Show resolved Hide resolved
if check_for_independent_variables:
model = self.check_for_independent_variables(model)
variables = list(model.rhs.keys()) + list(model.algebraic.keys())
# Find those RHS's that are constant
if self.spatial_methods == {} and any(var.domain != [] for var in variables):
for var in variables:
if var.domain != []:
Expand Down Expand Up @@ -1005,7 +1022,6 @@ def _process_symbol(self, symbol):
out = pybamm.Index(ext, slice(start, end))
out.copy_domains(symbol)
return out

else:
# add a try except block for a more informative error if a variable
# can't be found. This should usually be caught earlier by
Expand Down Expand Up @@ -1197,3 +1213,58 @@ def check_variables(self, model):
var.shape, model.rhs[rhs_var].shape, var
)
)

def search_for_independent_var(self, model, var):
pybamm.logger.verbose("Removing independent blocks.")
boundary_variables = list(model.boundary_conditions.keys())
boundary_variable_keys = []
for condition in boundary_variables:
keys_for_condition = list(model.boundary_conditions[condition].keys())
boundary_variable_keys.append(keys_for_condition)
rhs_variables = list(model.rhs.keys())
algebraic_variables = list(model.algebraic.keys())
this_var_list = []
if not isinstance(var, pybamm.Variable):
return model, False
for tree in rhs_variables:
pybamm.tree_search(model.rhs[tree], var, this_var_list)
for tree in algebraic_variables:
pybamm.tree_search(model.algebraic[tree], var, this_var_list)
for (keys, tree) in zip(boundary_variable_keys, boundary_variables):
for key in keys:
pybamm.tree_search(
model.boundary_conditions[tree][key][0], var, this_var_list
)
for name in model.variables.keys():
for rhs_child in model.variables[name].children:
pybamm.tree_search(rhs_child, var, this_var_list)
this_var_is_independent = not any(this_var_list)
not_in_y_slices = not (var in list(self.y_slices.keys()))
not_in_discretised = not (var in list(self._discretised_symbols.keys()))
is_0D = len(var.domain) == 0
this_var_is_independent = (
this_var_is_independent and not_in_y_slices and not_in_discretised and is_0D
)
return model, this_var_is_independent

def check_for_independent_variables(self, model):
rhs_vars_to_search_over = list(model.rhs.keys())
for var in rhs_vars_to_search_over:
model, this_var_is_independent = self.search_for_independent_var(model, var)
if this_var_is_independent:
if len(model.rhs) != 1:
pybamm.logger.info("removing variable {} from rhs".format(var))
my_initial_condition = model.initial_conditions[var]
model.variables[var.name] = pybamm.ExplicitTimeIntegral(
model.rhs[var], my_initial_condition
)
# edge case where a variable appears
# in variables twice under different names
for key in model.variables:
if model.variables[key] == var:
model.variables[key] = model.variables[var.name]
del model.rhs[var]
del model.initial_conditions[var]
else:
break
return model
17 changes: 1 addition & 16 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import scipy.sparse
from collections import OrderedDict
from pybamm.util import is_constant_and_can_evaluate

import numbers

Expand All @@ -20,22 +21,6 @@ def id_to_julia_variable(symbol_id, prefix):
return var_format.format(symbol_id).replace("-", "m")


def is_constant_and_can_evaluate(symbol):
"""
Returns True if symbol is constant and evaluation does not raise any errors.
Returns False otherwise.
An example of a constant symbol that cannot be "evaluated" is PrimaryBroadcast(0).
"""
if symbol.is_constant():
try:
symbol.evaluate()
return True
except NotImplementedError:
return False
else:
return False


def find_symbols(
symbol,
constant_symbols,
Expand Down
10 changes: 9 additions & 1 deletion pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def create_copy(self):

def _unary_new_copy(self, child):
"""Make a new copy of the unary operator, with child `child`"""

return self.__class__(child)

def _unary_jac(self, child_jac):
Expand Down Expand Up @@ -958,6 +957,15 @@ def _sympy_operator(self, child):
return sympy.Symbol(latex_child)


class ExplicitTimeIntegral(UnaryOperator):
def __init__(self, children, initial_condition):
super().__init__("explicit time integral", children)
self.initial_condition = initial_condition

def _unary_new_copy(self, child):
return self.__class__(child, self.initial_condition)


class BoundaryGradient(BoundaryOperator):
"""
A node in the expression tree which gets the boundary flux of a variable.
Expand Down
27 changes: 25 additions & 2 deletions pybamm/solvers/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@ class ProcessedVariable(object):
Default is True.
"""

def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
def __init__(
self,
base_variables,
base_variables_casadi,
solution,
warn=True,
cumtrapz_ic=None,
):
self.base_variables = base_variables
self.base_variables_casadi = base_variables_casadi

Expand All @@ -46,6 +53,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
self.domain = base_variables[0].domain
self.domains = base_variables[0].domains
self.warn = warn
self.cumtrapz_ic = cumtrapz_ic

# Sensitivity starts off uninitialized, only set when called
self._sensitivities = None
Expand Down Expand Up @@ -106,15 +114,30 @@ def initialise_0D(self):
# initialise empty array of the correct size
entries = np.empty(len(self.t_pts))
idx = 0
last_t = 0
# Evaluate the base_variable index-by-index
for ts, ys, inputs, base_var_casadi in zip(
self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
):
for inner_idx, t in enumerate(ts):
t = ts[inner_idx]
y = ys[:, inner_idx]
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
if self.cumtrapz_ic is not None:
if idx == 0:
new_val = t * base_var_casadi(t, y, inputs).full()[0, 0]
entries[idx] = self.cumtrapz_ic + (
t * base_var_casadi(t, y, inputs).full()[0, 0]
)
else:
new_val = (t - last_t) * (
base_var_casadi(t, y, inputs).full()[0, 0]
)
entries[idx] = new_val + entries[idx - 1]
else:
entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]

idx += 1
last_t = t

# set up interpolation
if len(self.t_pts) == 1:
Expand Down
46 changes: 26 additions & 20 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,42 +471,48 @@ def update(self, variables):
variables = [variables]
# Process
for key in variables:
cumtrapz_ic = None
pybamm.logger.debug("Post-processing {}".format(key))
vars_pybamm = [model.variables_and_events[key] for model in self.all_models]

# Iterate through all models, some may be in the list several times and
# therefore only get set up once
vars_casadi = []
for model, ys, inputs, var_pybamm in zip(
self.all_models, self.all_ys, self.all_inputs, vars_pybamm
for (i, (model, ys, inputs, var_pybamm)) in enumerate(
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
):
if key in model._variables_casadi:
if isinstance(var_pybamm, pybamm.ExplicitTimeIntegral):
cumtrapz_ic = var_pybamm.initial_condition
cumtrapz_ic = cumtrapz_ic.evaluate()
var_pybamm = var_pybamm.child
var_casadi = self.process_casadi_var(var_pybamm, inputs, ys)
model._variables_casadi[key] = var_casadi
vars_pybamm[i] = var_pybamm
elif key in model._variables_casadi:
aabills marked this conversation as resolved.
Show resolved Hide resolved
var_casadi = model._variables_casadi[key]
else:
t_MX = casadi.MX.sym("t")
y_MX = casadi.MX.sym("y", ys.shape[0])
inputs_MX_dict = {
key: casadi.MX.sym("input", value.shape[0])
for key, value in inputs.items()
}
inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()])

# Convert variable to casadi
# Make all inputs symbolic first for converting to casadi
var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict)

var_casadi = casadi.Function(
"variable", [t_MX, y_MX, inputs_MX], [var_sym]
)
var_casadi = self.process_casadi_var(var_pybamm, inputs, ys)
model._variables_casadi[key] = var_casadi
vars_casadi.append(var_casadi)

var = pybamm.ProcessedVariable(vars_pybamm, vars_casadi, self)
var = pybamm.ProcessedVariable(
vars_pybamm, vars_casadi, self, cumtrapz_ic=cumtrapz_ic
)

# Save variable and data
self._variables[key] = var
self.data[key] = var.data

def process_casadi_var(self, var_pybamm, inputs, ys):
t_MX = casadi.MX.sym("t")
y_MX = casadi.MX.sym("y", ys.shape[0])
inputs_MX_dict = {
key: casadi.MX.sym("input", value.shape[0]) for key, value in inputs.items()
}
inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()])
var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict)
var_casadi = casadi.Function("variable", [t_MX, y_MX, inputs_MX], [var_sym])
return var_casadi

def __getitem__(self, key):
"""Read a variable from the solution. Variables are created 'just in time', i.e.
only when they are called.
Expand Down
27 changes: 27 additions & 0 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@
JAXLIB_VERSION = "0.1.70"


def tree_search(tree, item, solutions):
for child in tree.children:
tree_search(child, item, solutions)
if (child == item) or (child.name == item.name):
solutions.append(True)
else:
solutions.append(False)
solutions.append((tree == item) or (tree.name == item.name))
return None


def root_dir():
"""return the root directory of the PyBaMM install directory"""
return str(pathlib.Path(pybamm.__path__[0]).parent)
Expand Down Expand Up @@ -351,6 +362,22 @@ def is_jax_compatible():
)


def is_constant_and_can_evaluate(symbol):
"""
Returns True if symbol is constant and evaluation does not raise any errors.
Returns False otherwise.
An example of a constant symbol that cannot be "evaluated" is PrimaryBroadcast(0).
"""
if symbol.is_constant():
try:
symbol.evaluate()
return True
except NotImplementedError:
return False
else:
return False


def install_jax(arguments=None): # pragma: no cover
"""
Install compatible versions of jax, jaxlib.
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_models/standard_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def test_solving(
self.solver.rtol = 1e-8
self.solver.atol = 1e-8

# Somehow removing an equation makes the solver fail at
# the low tolerances
if isinstance(self.model, pybamm.lithium_ion.NewmanTobias):
self.solver.rtol = 1e-7

Crate = abs(
self.parameter_values["Current function [A]"]
/ self.parameter_values["Nominal cell capacity [A.h]"]
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,21 @@ def test_length_scale_errors(self):
disc.process_model(model)
self.assertEqual(model.length_scales["negative electrode"], pybamm.Scalar(1))

def test_independent_rhs(self):
a = pybamm.Variable("a")
b = pybamm.Variable("b")
c = pybamm.Variable("c")
model = pybamm.BaseModel()
model.rhs = {a: b, b: c, c: -c}
model.initial_conditions = {
a: pybamm.Scalar(0),
b: pybamm.Scalar(1),
c: pybamm.Scalar(1),
}
disc = pybamm.Discretisation()
disc.process_model(model)
self.assertEqual(len(model.rhs), 2)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Loading