Skip to content

Commit

Permalink
Merge pull request #2223 from pybamm-team/increase-coverage
Browse files Browse the repository at this point in the history
Increase coverage
  • Loading branch information
valentinsulzer authored Sep 15, 2022
2 parents a957d17 + 577f325 commit 96e57ec
Show file tree
Hide file tree
Showing 20 changed files with 230 additions and 134 deletions.
2 changes: 1 addition & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _concatenation_new_copy(self, children):

def _concatenation_jac(self, children_jacs):
"""Calculate the jacobian of a concatenation."""
return NotImplementedError
raise NotImplementedError

def _evaluate_for_shape(self):
"""See :meth:`pybamm.Symbol.evaluate_for_shape`"""
Expand Down
6 changes: 3 additions & 3 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class JaxCooMatrix:
"""

def __init__(self, row, col, data, shape):
if not pybamm.have_jax():
if not pybamm.have_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)
Expand Down Expand Up @@ -413,7 +413,7 @@ def to_python(symbol, debug=False, output_jax=False):

line_format = "{} = {}"

if debug:
if debug: # pragma: no cover
variable_lines = [
"print('{}'); ".format(
line_format.format(id_to_python_variable(symbol_id, False), symbol_line)
Expand Down Expand Up @@ -540,7 +540,7 @@ class EvaluatorJax:
"""

def __init__(self, symbol):
if not pybamm.have_jax():
if not pybamm.have_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://pybamm.readthedocs.io/en/latest/install/GNU-linux.html#optional-jaxsolver" # noqa: E501
)
Expand Down
12 changes: 7 additions & 5 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,16 @@ def evaluate_ignoring_errors(self, t=0):
return None
elif error.args[0] == "StateVectorDot cannot evaluate input 'y_dot=None'":
return None
else:
else: # pragma: no cover
raise error
except ValueError as e:
except ValueError as error:
# return None if specific ValueError is raised
# (there is a e.g. Time in the tree)
if e.args[0] == "t must be provided":
if error.args[0] == "t must be provided":
return None
raise pybamm.ShapeError("Cannot find shape (original error: {})".format(e))
raise pybamm.ShapeError(
f"Cannot find shape (original error: {error})"
) # pragma: no cover
return result

def evaluates_to_number(self):
Expand Down Expand Up @@ -891,7 +893,7 @@ def create_copy(self):
"""
raise NotImplementedError(
"""method self.new_copy() not implemented
for symbol {!s} of type {}""".format(
for symbol {!s} of type {}""".format(
self, type(self)
)
)
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

13 changes: 8 additions & 5 deletions pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class EventType(Enum):
should return the time that the discontinuity occurs. The solver will integrate up
to the discontinuity and then restart just after the discontinuity.
INTERPOLANT_EXTRAPOLATION indicates that a pybamm.Interpolant object has been
evaluated outside of the range.
SWITCH indicates an event switch that is used in CasADI "fast with events" model.
"""

TERMINATION = 0
Expand All @@ -29,12 +33,11 @@ class Event:
----------
name: str
A string giving the name of the event
event_type: :class:`pybamm.EventType`
An enum defining the type of event
A string giving the name of the event.
expression: :class:`pybamm.Symbol`
An expression that defines when the event occurs
An expression that defines when the event occurs.
event_type: :class:`pybamm.EventType` (optional)
An enum defining the type of event. By default it is set to TERMINATION.
"""

Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
extrap_tol=0,
extra_options=None,
):
if scikits_odes_spec is None:
if scikits_odes_spec is None: # pragma: no cover
raise ImportError("scikits.odes is not installed")

super().__init__(method, rtol, atol, extrap_tol=extrap_tol)
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,58 @@ def test_process_model_dae(self):
with self.assertRaises(pybamm.ModelError):
disc.process_model(model)

def test_process_model_algebraic(self):
# TODO: implement this based on test_process_model_dae
# one rhs equation and one algebraic
whole_cell = ["negative electrode", "separator", "positive electrode"]
c = pybamm.Variable("c", domain=whole_cell)
N = pybamm.grad(c)
Q = pybamm.Scalar(1)
model = pybamm.BaseModel()
model.algebraic = {c: pybamm.div(N) - Q}
model.initial_conditions = {c: pybamm.Scalar(0)}

model.boundary_conditions = {
c: {"left": (0, "Dirichlet"), "right": (0, "Dirichlet")}
}
model.variables = {"c": c, "N": N}

# create discretisation
disc = get_discretisation_for_testing()
mesh = disc.mesh

disc.process_model(model)
combined_submesh = mesh.combine_submeshes(*whole_cell)

y0 = model.concatenated_initial_conditions.evaluate()
np.testing.assert_array_equal(
y0,
np.zeros_like(combined_submesh.nodes)[:, np.newaxis],
)

# grad and div are identity operators here
np.testing.assert_array_equal(
model.concatenated_rhs.evaluate(None, y0), np.ones([0, 1])
)

np.testing.assert_array_equal(
model.concatenated_algebraic.evaluate(None, y0),
-np.ones_like(combined_submesh.nodes[:, np.newaxis]),
)

# mass matrix is identity upper left, zeros elsewhere
mass = np.zeros(
(np.size(combined_submesh.nodes), np.size(combined_submesh.nodes))
)
np.testing.assert_array_equal(
mass, model.mass_matrix.entries.toarray()
)

# jacobian
y = pybamm.StateVector(slice(0, np.size(y0)))
jacobian = model.concatenated_algebraic.jac(y).evaluate(0, y0)
np.testing.assert_array_equal(np.eye(combined_submesh.npts), jacobian.toarray())

def test_process_model_concatenation(self):
# concatenation of variables as the key
cn = pybamm.Variable("c", domain=["negative electrode"])
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def test_binary_operator(self):
bin2 = pybamm.BinaryOperator("binary test", c, d)
with self.assertRaises(NotImplementedError):
bin2.evaluate()
with self.assertRaises(NotImplementedError):
bin2._binary_jac(a, b)

def test_binary_operator_domains(self):
# same domain
Expand Down Expand Up @@ -303,6 +305,7 @@ def test_equality(self):
self.assertEqual(equal.evaluate(y=np.array([1])), 1)
self.assertEqual(equal.evaluate(y=np.array([2])), 0)
self.assertEqual(str(equal), "1.0 == y[0:1]")
self.assertEqual(equal.diff(b), 0)

def test_sigmoid(self):
a = pybamm.Scalar(1)
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_expression_tree/test_concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def test_base_concatenation(self):
with self.assertRaisesRegex(TypeError, "ConcatenationVariable"):
pybamm.Concatenation(a, b)

# base concatenation jacobian
a = pybamm.Symbol("a", domain="test a")
b = pybamm.Symbol("b", domain="test b")
conc3 = pybamm.Concatenation(a, b)
with self.assertRaises(NotImplementedError):
conc3._concatenation_jac(None)

def test_concatenation_domains(self):
a = pybamm.Symbol("a", domain=["negative electrode"])
b = pybamm.Symbol("b", domain=["separator", "positive electrode"])
Expand Down Expand Up @@ -135,6 +142,9 @@ def test_numpy_concatenation_vectors(self):
conc = pybamm.NumpyConcatenation(a, b, c)
y = np.linspace(0, 1, 23)[:, np.newaxis]
np.testing.assert_array_equal(conc.evaluate(None, y), y)
# empty concatenation
conc = pybamm.NumpyConcatenation()
self.assertEqual(conc._concatenation_jac(None), 0)

def test_numpy_concatenation_vector_scalar(self):
# with entries
Expand Down Expand Up @@ -176,6 +186,10 @@ def test_domain_concatenation_domains(self):
],
)

conc.secondary_dimensions_npts = 2
with self.assertRaisesRegex(ValueError, "Concatenation and children must have"):
conc.create_slices(None)

def test_concatenation_orphans(self):
a = pybamm.Variable("a", domain=["negative electrode"])
b = pybamm.Variable("b", domain=["separator"])
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,22 @@ def test_jac_of_unary_operator(self):
with self.assertRaises(NotImplementedError):
b.jac(y)

def test_jac_of_binary_operator(self):
a = pybamm.Symbol("a")
b = pybamm.Symbol("b")

phi_s = pybamm.standard_variables.phi_s_n
i = pybamm.grad(phi_s)

inner = pybamm.inner(2, i)
self.assertEqual(inner._binary_jac(a, b), 2 * b)

inner = pybamm.inner(i, 2)
self.assertEqual(inner._binary_jac(a, b), 2 * a)

inner = pybamm.inner(i, i)
self.assertEqual(inner._binary_jac(a, b), i * a + i * b)

def test_jac_of_independent_variable(self):
a = pybamm.IndependentVariable("Variable")
y = pybamm.StateVector(slice(0, 1))
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/test_expression_tree/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def test_to_equation(self):
b.print_name = "test"
self.assertEqual(str(b.to_equation()), "test")

def test_copy(self):
a = pybamm.Scalar(5)
b = a.create_copy()
self.assertEqual(a, b)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_expression_tree/test_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def test_evaluate(self):
):
sv.evaluate(y_dot=y_dot2)

# Try evaluating with y_dot=None
with self.assertRaisesRegex(
TypeError,
"StateVectorDot cannot evaluate input 'y_dot=None'",
):
sv.evaluate(y_dot=None)

def test_name(self):
sv = pybamm.StateVectorDot(slice(0, 10))
self.assertEqual(sv.name, "y_dot[0:10]")
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def test_symbol_methods(self):
):
a + "two"

def test_symbol_create_copy(self):
a = pybamm.Symbol("a")
with self.assertRaisesRegex(NotImplementedError, "method self.new_copy()"):
a.create_copy()

def test_sigmoid(self):
# Test that smooth heaviside is used when the setting is changed
a = pybamm.Symbol("a")
Expand Down Expand Up @@ -244,6 +249,8 @@ def test_evaluate_ignoring_errors(self):
self.assertEqual(pybamm.t.evaluate_ignoring_errors(t=0), 0)
self.assertIsNone(pybamm.Parameter("a").evaluate_ignoring_errors())
self.assertIsNone(pybamm.StateVector(slice(0, 1)).evaluate_ignoring_errors())
self.assertIsNone(pybamm.StateVectorDot(slice(0, 1)).evaluate_ignoring_errors())

np.testing.assert_array_equal(
pybamm.InputParameter("a").evaluate_ignoring_errors(), np.nan
)
Expand Down
Loading

0 comments on commit 96e57ec

Please sign in to comment.