Skip to content

Commit

Permalink
Merge branch 'issue-699-remove-autograd' into issue-579-volume-fractions
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Nov 12, 2019
2 parents 3f2c3c2 + 0947e2f commit 6405d0a
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 15 deletions.
20 changes: 14 additions & 6 deletions pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,9 @@ def _binary_evaluate(self, left, right):
return csr_matrix(left.multiply(1 / right))
else:
if isinstance(right, numbers.Number) and right == 0:
return left * np.inf
# don't raise RuntimeWarning for NaNs
with np.errstate(invalid="ignore"):
return left * np.inf
else:
return left / right

Expand Down Expand Up @@ -715,7 +717,11 @@ def __init__(self, left, right, equal):
""" See :meth:`pybamm.BinaryOperator.__init__()`. """
# 'equal' determines whether to return 1 or 0 when left = right
self.equal = equal
super().__init__("heaviside", left, right)
if equal is True:
name = "<="
else:
name = "<"
super().__init__(name, left, right)

def __str__(self):
""" See :meth:`pybamm.Symbol.__str__()`. """
Expand All @@ -738,10 +744,12 @@ def _binary_jac(self, left_jac, right_jac):

def _binary_evaluate(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_evaluate()`. """
if self.equal is True:
return left <= right
else:
return left < right
# don't raise RuntimeWarning for NaNs
with np.errstate(invalid="ignore"):
if self.equal is True:
return left <= right
else:
return left < right

def _binary_new_copy(self, left, right):
""" See :meth:`pybamm.BinaryOperator._binary_new_copy()`. """
Expand Down
8 changes: 6 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ def __init__(self, child):
super().__init__(np.log, child)

def _function_evaluate(self, evaluated_children):
return np.log(*evaluated_children)
# don't raise RuntimeWarning for NaNs
with np.errstate(invalid="ignore"):
return np.log(*evaluated_children)

def _function_diff(self, children, idx):
""" See :meth:`pybamm.Function._function_diff()`. """
Expand Down Expand Up @@ -392,7 +394,9 @@ def __init__(self, child):
super().__init__(np.sqrt, child)

def _function_evaluate(self, evaluated_children):
return np.sqrt(*evaluated_children)
# don't raise RuntimeWarning for NaNs
with np.errstate(invalid="ignore"):
return np.sqrt(*evaluated_children)

def _function_diff(self, children, idx):
""" See :meth:`pybamm.Function._function_diff()`. """
Expand Down
3 changes: 0 additions & 3 deletions pybamm/parameters/parameter_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,6 @@ def _process_symbol(self, symbol):
else:
# otherwise evaluate the function to create a new PyBaMM object
function = function_name(*new_children)
# this might return a scalar, in which case convert to a pybamm scalar
if isinstance(function, numbers.Number):
function = pybamm.Scalar(function, name=symbol.name)
# Differentiate if necessary
if symbol.diff_variable is None:
function_out = function
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/test_expression_tree/test_binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,14 @@ def test_heaviside(self):
self.assertEqual(heav.evaluate(y=np.array([2])), 1)
self.assertEqual(heav.evaluate(y=np.array([1])), 0)
self.assertEqual(heav.evaluate(y=np.array([0])), 0)
self.assertEqual(str(heav), "1.0 < y[0:1]")

heav = a <= b
heav = a >= b
self.assertTrue(heav.equal)
self.assertEqual(heav.evaluate(y=np.array([2])), 1)
self.assertEqual(heav.evaluate(y=np.array([2])), 0)
self.assertEqual(heav.evaluate(y=np.array([1])), 1)
self.assertEqual(heav.evaluate(y=np.array([0])), 0)
self.assertEqual(heav.evaluate(y=np.array([0])), 1)
self.assertEqual(str(heav), "y[0:1] <= 1.0")


class TestIsZero(unittest.TestCase):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_expression_tree/test_operations/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,20 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

# test something with a heaviside
a = pybamm.Vector(np.array([1, 2]))
expr = a <= pybamm.StateVector(slice(0, 2))
evaluator = pybamm.EvaluatorPython(expr)
for t, y in zip(t_tests, y_tests):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

expr = a > pybamm.StateVector(slice(0, 2))
evaluator = pybamm.EvaluatorPython(expr)
for t, y in zip(t_tests, y_tests):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

# test something with an index
expr = pybamm.Index(A @ pybamm.StateVector(slice(0, 2)), 0)
evaluator = pybamm.EvaluatorPython(expr)
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,16 @@ def test_jac_of_inner(self):
jac = pybamm.inner(a * vec, b * vec).jac(vec).evaluate(y=np.ones(2)).toarray()
np.testing.assert_array_equal(jac, 4 * np.eye(2))

def test_jac_of_heaviside(self):
a = pybamm.Scalar(1)
y = pybamm.StateVector(slice(0, 5))
np.testing.assert_array_equal(
((a < y) * y ** 2).jac(y).evaluate(y=5 * np.ones(5)), 10 * np.eye(5)
)
np.testing.assert_array_equal(
((a < y) * y ** 2).jac(y).evaluate(y=-5 * np.ones(5)), 0
)

def test_jac_of_domain_concatenation(self):
# create mesh
mesh = get_mesh_for_testing()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_symbol_evaluates_to_number(self):
def test_symbol_repr(self):
"""
test that __repr___ returns the string
`__class__(id, name, parent expression)`
`__class__(id, name, children, domain, auxiliary_domains)`
"""
a = pybamm.Symbol("a")
b = pybamm.Symbol("b")
Expand Down

0 comments on commit 6405d0a

Please sign in to comment.