From b9b721fa2b6a35fe8c155e24b6f8a1318e44483c Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 07:50:25 -0700 Subject: [PATCH 1/6] add coupled variable to expression tree and discretisation --- src/pybamm/__init__.py | 1 + src/pybamm/discretisations/discretisation.py | 5 ++ .../expression_tree/coupled_variable.py | 54 +++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 src/pybamm/expression_tree/coupled_variable.py diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 68529156e3..3de52e5724 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -39,6 +39,7 @@ from .expression_tree.parameter import Parameter, FunctionParameter from .expression_tree.scalar import Scalar from .expression_tree.variable import * +from .expression_tree.coupled_variable import * from .expression_tree.independent_variable import * from .expression_tree.independent_variable import t from .expression_tree.vector import Vector diff --git a/src/pybamm/discretisations/discretisation.py b/src/pybamm/discretisations/discretisation.py index af4bd2edd6..7255e4923a 100644 --- a/src/pybamm/discretisations/discretisation.py +++ b/src/pybamm/discretisations/discretisation.py @@ -938,6 +938,11 @@ def _process_symbol(self, symbol): if symbol._expected_size is None: symbol._expected_size = expected_size return symbol.create_copy() + + elif isinstance(symbol, pybamm.CoupledVariable): + new_symbol = self.process_symbol(symbol.children[0]) + return new_symbol + else: # Backup option: return the object return symbol diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py new file mode 100644 index 0000000000..2b1cd61be5 --- /dev/null +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -0,0 +1,54 @@ +import pybamm + +from pybamm.type_definitions import DomainType + + +class CoupledVariable(pybamm.Symbol): + """ + A node in the expression tree representing a variable whose equation is set by a different model or submodel. + + + Parameters + ---------- + name : str + The variable's name. If the + """ + def __init__( + self, + name: str, + domain: DomainType = None, + ) -> None: + super().__init__(name, domain=domain) + + + def _evaluate_for_shape(self): + """ + Returns the scalar 'NaN' to represent the shape of a parameter. + See :meth:`pybamm.Symbol.evaluate_for_shape()` + """ + return pybamm.evaluate_for_shape_using_domain(self.domains) + + + def create_copy(self): + """See :meth:`pybamm.Symbol.new_copy()`.""" + new_input_parameter = CoupledVariable( + self.name, self.domain, expected_size=self._expected_size + ) + return new_input_parameter + + @property + def children(self): + return self._children + + @children.setter + def children(self, expr): + self._children = expr + + + def set_coupled_variable(self, symbol, expr): + if self == symbol: + symbol.children = [expr,] + else: + for child in symbol.children: + self.set_coupled_variable(child, expr) + symbol.set_id() \ No newline at end of file From f30749187ae2c777e8c21ff5b45031610e735c31 Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 11:13:46 -0700 Subject: [PATCH 2/6] add test; add coupledvariable dict to model --- src/pybamm/models/base_model.py | 24 ++++++ .../test_coupled_variable.py | 76 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 tests/unit/test_expression_tree/test_coupled_variable.py diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index f6f47acc55..dfe3e838a7 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -56,6 +56,7 @@ def __init__(self, name="Unnamed model"): self._boundary_conditions = {} self._variables_by_submodel = {} self._variables = pybamm.FuzzyDict({}) + self._coupled_variables = {} self._summary_variables = [] self._events = [] self._concatenated_rhs = None @@ -182,6 +183,29 @@ def boundary_conditions(self): def boundary_conditions(self, boundary_conditions): self._boundary_conditions = BoundaryConditionsDict(boundary_conditions) + @property + def coupled_variables(self): + """Returns a dictionary mapping strings to expressions representing variables needed by the model but whose equations were set by other models.""" + return self._coupled_variables + + @coupled_variables.setter + def coupled_variables(self, coupled_variables): + for name, var in coupled_variables.items(): + if ( + isinstance(var, pybamm.CoupledVariable) + and var.name != name + # Exception if the variable is also there under its own name + and not (var.name in coupled_variables and coupled_variables[var.name] == var) + ): + raise ValueError( + f"Coupled variable with name '{var.name}' is in coupled variables dictionary with " + f"name '{name}'. Names must match." + ) + self._coupled_variables = coupled_variables + + def list_coupled_variables(self): + list(self._coupled_variables.keys()) + @property def variables(self): """Returns a dictionary mapping strings to expressions representing the model's useful variables.""" diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py new file mode 100644 index 0000000000..48bf51c37c --- /dev/null +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -0,0 +1,76 @@ +# +# Tests for the CoupledVariable class +# + +import pytest + +import numpy as np + +import pybamm + +def combine_models(list_of_models): + model = pybamm.BaseModel() + + for submodel in list_of_models: + model.coupled_variables.update(submodel.coupled_variables) + model.variables.update(submodel.variables) + model.rhs.update(submodel.rhs) + model.algebraic.update(submodel.algebraic) + model.initial_conditions.update(submodel.initial_conditions) + model.boundary_conditions.update(submodel.boundary_conditions) + + for name, coupled_variable in model.coupled_variables.items(): + if name in model.variables: + for sym in model.rhs.values(): + coupled_variable.set_coupled_variable(sym, model.variables[name]) + for sym in model.algebraic.values(): + coupled_variable.set_coupled_variable(sym, model.variables[name]) + return model + + +class TestCoupledVariable: + def test_coupled_variable(self): + model_1 = pybamm.BaseModel() + model_1_var_1 = pybamm.CoupledVariable("a") + model_1_var_2 = pybamm.Variable("b") + model_1.rhs[model_1_var_2] = -0.2 * model_1_var_1 + model_1.variables["b"] = model_1_var_2 + model_1.coupled_variables["a"] = model_1_var_1 + model_1.initial_conditions[model_1_var_2] = 1.0 + + model_2 = pybamm.BaseModel() + model_2_var_1 = pybamm.Variable("a") + model_2_var_2 = pybamm.CoupledVariable("b") + model_2.rhs[model_2_var_1] = - 0.2 * model_2_var_2 + model_2.variables["a"] = model_2_var_1 + model_2.coupled_variables["b"] = model_2_var_2 + model_2.initial_conditions[model_2_var_1] = 1.0 + + model = combine_models([model_1, model_2]) + + params = pybamm.ParameterValues({}) + geometry = {} + + # Process parameters + params.process_model(model) + params.process_geometry(geometry) + + # mesh and discretise + submesh_types = {} + var_pts = {} + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + + spatial_methods = {} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + # solve + solver = pybamm.CasadiSolver() + t = np.linspace(0, 10, 1000) + solution = solver.solve(model, t) + + np.testing.assert_almost_equal(solution["a"].entries, solution["b"].entries, decimal=10) + + + From 65a45aff486dc5492fb9a684b90542ebe62ac640 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:18:40 +0000 Subject: [PATCH 3/6] style: pre-commit fixes --- src/pybamm/discretisations/discretisation.py | 2 +- src/pybamm/expression_tree/coupled_variable.py | 14 +++++++------- src/pybamm/models/base_model.py | 6 ++++-- .../test_coupled_variable.py | 16 +++++++--------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/pybamm/discretisations/discretisation.py b/src/pybamm/discretisations/discretisation.py index 7255e4923a..2ca8c87649 100644 --- a/src/pybamm/discretisations/discretisation.py +++ b/src/pybamm/discretisations/discretisation.py @@ -938,7 +938,7 @@ def _process_symbol(self, symbol): if symbol._expected_size is None: symbol._expected_size = expected_size return symbol.create_copy() - + elif isinstance(symbol, pybamm.CoupledVariable): new_symbol = self.process_symbol(symbol.children[0]) return new_symbol diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py index 2b1cd61be5..14fd8fbcdd 100644 --- a/src/pybamm/expression_tree/coupled_variable.py +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -7,12 +7,13 @@ class CoupledVariable(pybamm.Symbol): """ A node in the expression tree representing a variable whose equation is set by a different model or submodel. - + Parameters ---------- name : str - The variable's name. If the + The variable's name. If the """ + def __init__( self, name: str, @@ -20,7 +21,6 @@ def __init__( ) -> None: super().__init__(name, domain=domain) - def _evaluate_for_shape(self): """ Returns the scalar 'NaN' to represent the shape of a parameter. @@ -28,7 +28,6 @@ def _evaluate_for_shape(self): """ return pybamm.evaluate_for_shape_using_domain(self.domains) - def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" new_input_parameter = CoupledVariable( @@ -44,11 +43,12 @@ def children(self): def children(self, expr): self._children = expr - def set_coupled_variable(self, symbol, expr): if self == symbol: - symbol.children = [expr,] + symbol.children = [ + expr, + ] else: for child in symbol.children: self.set_coupled_variable(child, expr) - symbol.set_id() \ No newline at end of file + symbol.set_id() diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index dfe3e838a7..85111af7f7 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -187,7 +187,7 @@ def boundary_conditions(self, boundary_conditions): def coupled_variables(self): """Returns a dictionary mapping strings to expressions representing variables needed by the model but whose equations were set by other models.""" return self._coupled_variables - + @coupled_variables.setter def coupled_variables(self, coupled_variables): for name, var in coupled_variables.items(): @@ -195,7 +195,9 @@ def coupled_variables(self, coupled_variables): isinstance(var, pybamm.CoupledVariable) and var.name != name # Exception if the variable is also there under its own name - and not (var.name in coupled_variables and coupled_variables[var.name] == var) + and not ( + var.name in coupled_variables and coupled_variables[var.name] == var + ) ): raise ValueError( f"Coupled variable with name '{var.name}' is in coupled variables dictionary with " diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py index 48bf51c37c..53056e9b25 100644 --- a/tests/unit/test_expression_tree/test_coupled_variable.py +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -2,15 +2,15 @@ # Tests for the CoupledVariable class # -import pytest import numpy as np import pybamm + def combine_models(list_of_models): model = pybamm.BaseModel() - + for submodel in list_of_models: model.coupled_variables.update(submodel.coupled_variables) model.variables.update(submodel.variables) @@ -18,7 +18,7 @@ def combine_models(list_of_models): model.algebraic.update(submodel.algebraic) model.initial_conditions.update(submodel.initial_conditions) model.boundary_conditions.update(submodel.boundary_conditions) - + for name, coupled_variable in model.coupled_variables.items(): if name in model.variables: for sym in model.rhs.values(): @@ -41,7 +41,7 @@ def test_coupled_variable(self): model_2 = pybamm.BaseModel() model_2_var_1 = pybamm.Variable("a") model_2_var_2 = pybamm.CoupledVariable("b") - model_2.rhs[model_2_var_1] = - 0.2 * model_2_var_2 + model_2.rhs[model_2_var_1] = -0.2 * model_2_var_2 model_2.variables["a"] = model_2_var_1 model_2.coupled_variables["b"] = model_2_var_2 model_2.initial_conditions[model_2_var_1] = 1.0 @@ -60,7 +60,6 @@ def test_coupled_variable(self): var_pts = {} mesh = pybamm.Mesh(geometry, submesh_types, var_pts) - spatial_methods = {} disc = pybamm.Discretisation(mesh, spatial_methods) disc.process_model(model) @@ -70,7 +69,6 @@ def test_coupled_variable(self): t = np.linspace(0, 10, 1000) solution = solver.solve(model, t) - np.testing.assert_almost_equal(solution["a"].entries, solution["b"].entries, decimal=10) - - - + np.testing.assert_almost_equal( + solution["a"].entries, solution["b"].entries, decimal=10 + ) From 4d638c9ebb984065c71c653c2f10d41bddcf53ad Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 11:21:20 -0700 Subject: [PATCH 4/6] pre-commit merge --- src/pybamm/expression_tree/coupled_variable.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py index 14fd8fbcdd..b85589cbe0 100644 --- a/src/pybamm/expression_tree/coupled_variable.py +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -11,7 +11,9 @@ class CoupledVariable(pybamm.Symbol): Parameters ---------- name : str - The variable's name. If the + name of the node + domain : iterable of str + list of domains that this variable is valid over """ def __init__( From 0d4f12d4e45ff1feef68e413f59d413abf1c2e7d Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 11:40:22 -0700 Subject: [PATCH 5/6] Trigger CI From c259bb0e6046dc5fd591e39d487ba6cf340b2915 Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 12:36:51 -0700 Subject: [PATCH 6/6] add tests for coverage; valentin comments --- .../expression_tree/coupled_variable.py | 11 +++++----- src/pybamm/models/base_model.py | 2 +- .../test_coupled_variable.py | 20 +++++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py index b85589cbe0..04d03d2792 100644 --- a/src/pybamm/expression_tree/coupled_variable.py +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -13,7 +13,7 @@ class CoupledVariable(pybamm.Symbol): name : str name of the node domain : iterable of str - list of domains that this variable is valid over + list of domains that this coupled variable is valid over """ def __init__( @@ -31,11 +31,9 @@ def _evaluate_for_shape(self): return pybamm.evaluate_for_shape_using_domain(self.domains) def create_copy(self): - """See :meth:`pybamm.Symbol.new_copy()`.""" - new_input_parameter = CoupledVariable( - self.name, self.domain, expected_size=self._expected_size - ) - return new_input_parameter + """Creates a new copy of the coupled variable.""" + new_coupled_variable = CoupledVariable(self.name, self.domain) + return new_coupled_variable @property def children(self): @@ -46,6 +44,7 @@ def children(self, expr): self._children = expr def set_coupled_variable(self, symbol, expr): + """Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step.""" if self == symbol: symbol.children = [ expr, diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index 85111af7f7..f7e8f70f32 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -206,7 +206,7 @@ def coupled_variables(self, coupled_variables): self._coupled_variables = coupled_variables def list_coupled_variables(self): - list(self._coupled_variables.keys()) + return list(self._coupled_variables.keys()) @property def variables(self): diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py index 53056e9b25..3e60c412e5 100644 --- a/tests/unit/test_expression_tree/test_coupled_variable.py +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -7,6 +7,8 @@ import pybamm +import pytest + def combine_models(list_of_models): model = pybamm.BaseModel() @@ -72,3 +74,21 @@ def test_coupled_variable(self): np.testing.assert_almost_equal( solution["a"].entries, solution["b"].entries, decimal=10 ) + + assert set(model.list_coupled_variables()) == set(["a", "b"]) + + def test_create_copy(self): + a = pybamm.CoupledVariable("a") + b = a.create_copy() + assert a == b + + def test_setter(self): + model = pybamm.BaseModel() + a = pybamm.CoupledVariable("a") + coupled_variables = {"a": a} + model.coupled_variables = coupled_variables + assert model.coupled_variables == coupled_variables + + with pytest.raises(ValueError, match="Coupled variable with name"): + coupled_variables = {"b": a} + model.coupled_variables = coupled_variables