Skip to content

Commit

Permalink
Merge pull request #1184 from brosaplanella/issue-1183-erf
Browse files Browse the repository at this point in the history
#1183 added erf and erfc
  • Loading branch information
valentinsulzer authored Oct 16, 2020
2 parents 9aa93b8 + 734745b commit 9b4c54d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Features

- Added an example script to check conservation of lithium ([#1186](https://github.com/pybamm-team/PyBaMM/pull/1186))
- Added an example script to check conservation of lithium ([#1186](https://github.com/pybamm-team/PyBaMM/pull/1186))
- Added `erf` and `erfc` functions ([#1184](https://github.com/pybamm-team/PyBaMM/pull/1184))

## Optimizations

Expand Down
22 changes: 22 additions & 0 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import autograd
import numbers
import numpy as np
from scipy import special
import pybamm


Expand Down Expand Up @@ -450,3 +451,24 @@ def _function_diff(self, children, idx):
def arctan(child):
" Returns hyperbolic tan function of child. "
return pybamm.simplify_if_constant(Arctan(child), keep_domains=True)


class Erf(SpecificFunction):
""" Error function """

def __init__(self, child):
super().__init__(special.erf, child)

def _function_diff(self, children, idx):
""" See :meth:`pybamm.Function._function_diff()`. """
return 2 / np.sqrt(np.pi) * Exponential(-children[0] ** 2)


def erf(child):
" Returns error function of child. "
return pybamm.simplify_if_constant(Erf(child), keep_domains=True)


def erfc(child):
" Returns complementary error function of child. "
return pybamm.simplify_if_constant(1 - Erf(child), keep_domains=True)
3 changes: 3 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import casadi
import numpy as np
from scipy.interpolate import PchipInterpolator, CubicSpline
from scipy import special


class CasadiConverter(object):
Expand Down Expand Up @@ -129,6 +130,8 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.log(*converted_children)
elif symbol.function == np.sign:
return casadi.sign(*converted_children)
elif symbol.function == special.erf:
return casadi.erf(*converted_children)
elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)):
return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)(
*converted_children
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import unittest
import numpy as np
from scipy.interpolate import interp1d
from scipy import special


def test_function(arg):
Expand Down Expand Up @@ -322,6 +323,38 @@ def test_tanh(self):
places=5,
)

def test_erf(self):
a = pybamm.InputParameter("a")
fun = pybamm.erf(a)
self.assertEqual(fun.evaluate(inputs={"a": 3}), special.erf(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(inputs={"a": 3}),
(
pybamm.erf(pybamm.Scalar(3 + h)).evaluate()
- fun.evaluate(inputs={"a": 3})
)
/ h,
places=5,
)

def test_erfc(self):
a = pybamm.InputParameter("a")
fun = pybamm.erfc(a)
self.assertAlmostEqual(
fun.evaluate(inputs={"a": 3}), special.erfc(3), places=15
)
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(inputs={"a": 3}),
(
pybamm.erfc(pybamm.Scalar(3 + h)).evaluate()
- fun.evaluate(inputs={"a": 3})
)
/ h,
places=5,
)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pybamm
import unittest
from tests import get_mesh_for_testing, get_1p1d_discretisation_for_testing
from scipy import special


class TestCasadiConverter(unittest.TestCase):
Expand All @@ -15,6 +16,15 @@ def assert_casadi_equal(self, a, b, evalf=False):
else:
self.assertTrue((a - b).is_zero())

def assert_casadi_almost_equal(self, a, b, decimal=7, evalf=False):
tol = 1.5 * 10**(-decimal)
if evalf is True:
self.assertTrue(
(casadi.fabs(casadi.evalf(a) - casadi.evalf(b)) < tol).is_one()
)
else:
self.assertTrue((casadi.fabs(a - b) < tol).is_one())

def test_convert_scalar_symbols(self):
a = pybamm.Scalar(0)
b = pybamm.Scalar(1)
Expand Down Expand Up @@ -104,6 +114,8 @@ def test_special_functions(self):
self.assert_casadi_equal(
pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True
)

# test functions with assert_casadi_equal
for np_fun in [
np.sqrt,
np.tanh,
Expand All @@ -121,6 +133,17 @@ def test_special_functions(self):
pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True
)

# test functions with assert_casadi_almost_equal
for np_fun in [
special.erf,
]:
self.assert_casadi_almost_equal(
pybamm.Function(np_fun, c).to_casadi(),
casadi.MX(np_fun(3)),
decimal=15,
evalf=True,
)

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]
y = pybamm.StateVector(slice(0, 2))
Expand Down

0 comments on commit 9b4c54d

Please sign in to comment.