From 0947e2f6a9f9c1d8a3df90498a0ecd9f4fbecba9 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 12 Nov 2019 14:09:54 -0500 Subject: [PATCH] #699 coverage --- pybamm/expression_tree/binary_operators.py | 20 +++++++++++++------ pybamm/expression_tree/functions.py | 8 ++++++-- pybamm/parameters/parameter_values.py | 3 --- .../test_binary_operators.py | 8 +++++--- .../test_operations/test_evaluate.py | 14 +++++++++++++ .../test_operations/test_jac.py | 10 ++++++++++ .../unit/test_expression_tree/test_symbol.py | 2 +- 7 files changed, 50 insertions(+), 15 deletions(-) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 32045b9109..9a78bfd912 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -504,7 +504,9 @@ def _binary_evaluate(self, left, right): return csr_matrix(left.multiply(1 / right)) else: if isinstance(right, numbers.Number) and right == 0: - return left * np.inf + # don't raise RuntimeWarning for NaNs + with np.errstate(invalid="ignore"): + return left * np.inf else: return left / right @@ -715,7 +717,11 @@ def __init__(self, left, right, equal): """ See :meth:`pybamm.BinaryOperator.__init__()`. """ # 'equal' determines whether to return 1 or 0 when left = right self.equal = equal - super().__init__("heaviside", left, right) + if equal is True: + name = "<=" + else: + name = "<" + super().__init__(name, left, right) def __str__(self): """ See :meth:`pybamm.Symbol.__str__()`. """ @@ -738,10 +744,12 @@ def _binary_jac(self, left_jac, right_jac): def _binary_evaluate(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """ - if self.equal is True: - return left <= right - else: - return left < right + # don't raise RuntimeWarning for NaNs + with np.errstate(invalid="ignore"): + if self.equal is True: + return left <= right + else: + return left < right def _binary_new_copy(self, left, right): """ See :meth:`pybamm.BinaryOperator._binary_new_copy()`. """ diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index c7ca525286..8ce840815f 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -318,7 +318,9 @@ def __init__(self, child): super().__init__(np.log, child) def _function_evaluate(self, evaluated_children): - return np.log(*evaluated_children) + # don't raise RuntimeWarning for NaNs + with np.errstate(invalid="ignore"): + return np.log(*evaluated_children) def _function_diff(self, children, idx): """ See :meth:`pybamm.Function._function_diff()`. """ @@ -392,7 +394,9 @@ def __init__(self, child): super().__init__(np.sqrt, child) def _function_evaluate(self, evaluated_children): - return np.sqrt(*evaluated_children) + # don't raise RuntimeWarning for NaNs + with np.errstate(invalid="ignore"): + return np.sqrt(*evaluated_children) def _function_diff(self, children, idx): """ See :meth:`pybamm.Function._function_diff()`. """ diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index ef369ecec4..13e3f8b861 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -444,9 +444,6 @@ def _process_symbol(self, symbol): else: # otherwise evaluate the function to create a new PyBaMM object function = function_name(*new_children) - # this might return a scalar, in which case convert to a pybamm scalar - if isinstance(function, numbers.Number): - function = pybamm.Scalar(function, name=symbol.name) # Differentiate if necessary if symbol.diff_variable is None: function_out = function diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index af125736b9..2f80f31c6e 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -332,12 +332,14 @@ def test_heaviside(self): self.assertEqual(heav.evaluate(y=np.array([2])), 1) self.assertEqual(heav.evaluate(y=np.array([1])), 0) self.assertEqual(heav.evaluate(y=np.array([0])), 0) + self.assertEqual(str(heav), "1.0 < y[0:1]") - heav = a <= b + heav = a >= b self.assertTrue(heav.equal) - self.assertEqual(heav.evaluate(y=np.array([2])), 1) + self.assertEqual(heav.evaluate(y=np.array([2])), 0) self.assertEqual(heav.evaluate(y=np.array([1])), 1) - self.assertEqual(heav.evaluate(y=np.array([0])), 0) + self.assertEqual(heav.evaluate(y=np.array([0])), 1) + self.assertEqual(str(heav), "y[0:1] <= 1.0") class TestIsZero(unittest.TestCase): diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate.py b/tests/unit/test_expression_tree/test_operations/test_evaluate.py index edb9ee3d14..52addcb634 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate.py @@ -342,6 +342,20 @@ def test_evaluator_python(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) + # test something with a heaviside + a = pybamm.Vector(np.array([1, 2])) + expr = a <= pybamm.StateVector(slice(0, 2)) + evaluator = pybamm.EvaluatorPython(expr) + for t, y in zip(t_tests, y_tests): + result = evaluator.evaluate(t=t, y=y) + np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) + + expr = a > pybamm.StateVector(slice(0, 2)) + evaluator = pybamm.EvaluatorPython(expr) + for t, y in zip(t_tests, y_tests): + result = evaluator.evaluate(t=t, y=y) + np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) + # test something with an index expr = pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0) evaluator = pybamm.EvaluatorPython(expr) diff --git a/tests/unit/test_expression_tree/test_operations/test_jac.py b/tests/unit/test_expression_tree/test_operations/test_jac.py index 2d00692c54..526f313903 100644 --- a/tests/unit/test_expression_tree/test_operations/test_jac.py +++ b/tests/unit/test_expression_tree/test_operations/test_jac.py @@ -264,6 +264,16 @@ def test_jac_of_inner(self): jac = pybamm.inner(a * vec, b * vec).jac(vec).evaluate(y=np.ones(2)).toarray() np.testing.assert_array_equal(jac, 4 * np.eye(2)) + def test_jac_of_heaviside(self): + a = pybamm.Scalar(1) + y = pybamm.StateVector(slice(0, 5)) + np.testing.assert_array_equal( + ((a < y) * y ** 2).jac(y).evaluate(y=5 * np.ones(5)), 10 * np.eye(5) + ) + np.testing.assert_array_equal( + ((a < y) * y ** 2).jac(y).evaluate(y=-5 * np.ones(5)), 0 + ) + def test_jac_of_domain_concatenation(self): # create mesh mesh = get_mesh_for_testing() diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index dbde7d4614..840185b107 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -198,7 +198,7 @@ def test_symbol_evaluates_to_number(self): def test_symbol_repr(self): """ test that __repr___ returns the string - `__class__(id, name, parent expression)` + `__class__(id, name, children, domain, auxiliary_domains)` """ a = pybamm.Symbol("a") b = pybamm.Symbol("b")