Skip to content

Commit

Permalink
add tests for coverage; valentin comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aabills committed Oct 31, 2024
1 parent 0d4f12d commit c259bb0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
11 changes: 5 additions & 6 deletions src/pybamm/expression_tree/coupled_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CoupledVariable(pybamm.Symbol):
name : str
name of the node
domain : iterable of str
list of domains that this variable is valid over
list of domains that this coupled variable is valid over
"""

def __init__(
Expand All @@ -31,11 +31,9 @@ def _evaluate_for_shape(self):
return pybamm.evaluate_for_shape_using_domain(self.domains)

def create_copy(self):
"""See :meth:`pybamm.Symbol.new_copy()`."""
new_input_parameter = CoupledVariable(
self.name, self.domain, expected_size=self._expected_size
)
return new_input_parameter
"""Creates a new copy of the coupled variable."""
new_coupled_variable = CoupledVariable(self.name, self.domain)
return new_coupled_variable

@property
def children(self):
Expand All @@ -46,6 +44,7 @@ def children(self, expr):
self._children = expr

def set_coupled_variable(self, symbol, expr):
"""Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step."""
if self == symbol:
symbol.children = [
expr,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def coupled_variables(self, coupled_variables):
self._coupled_variables = coupled_variables

def list_coupled_variables(self):
list(self._coupled_variables.keys())
return list(self._coupled_variables.keys())

@property
def variables(self):
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_expression_tree/test_coupled_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import pybamm

import pytest


def combine_models(list_of_models):
model = pybamm.BaseModel()
Expand Down Expand Up @@ -72,3 +74,21 @@ def test_coupled_variable(self):
np.testing.assert_almost_equal(
solution["a"].entries, solution["b"].entries, decimal=10
)

assert set(model.list_coupled_variables()) == set(["a", "b"])

def test_create_copy(self):
a = pybamm.CoupledVariable("a")
b = a.create_copy()
assert a == b

def test_setter(self):
model = pybamm.BaseModel()
a = pybamm.CoupledVariable("a")
coupled_variables = {"a": a}
model.coupled_variables = coupled_variables
assert model.coupled_variables == coupled_variables

with pytest.raises(ValueError, match="Coupled variable with name"):
coupled_variables = {"b": a}
model.coupled_variables = coupled_variables

0 comments on commit c259bb0

Please sign in to comment.