Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1183 added erf and erfc #1184

Merged
merged 8 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Features

- Added an example script to check conservation of lithium ([#1186](https://github.com/pybamm-team/PyBaMM/pull/1186))
- Added an example script to check conservation of lithium ([#1186](https://github.com/pybamm-team/PyBaMM/pull/1186))
- Added `erf` and `erfc` functions ([#1184](https://github.com/pybamm-team/PyBaMM/pull/1184))

## Optimizations

Expand Down
22 changes: 22 additions & 0 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import autograd
import numbers
import numpy as np
from scipy import special
import pybamm


Expand Down Expand Up @@ -450,3 +451,24 @@ def _function_diff(self, children, idx):
def arctan(child):
" Returns hyperbolic tan function of child. "
return pybamm.simplify_if_constant(Arctan(child), keep_domains=True)


class Erf(SpecificFunction):
""" Error function """

def __init__(self, child):
super().__init__(special.erf, child)

def _function_diff(self, children, idx):
""" See :meth:`pybamm.Function._function_diff()`. """
return 2 / np.sqrt(np.pi) * Exponential(-children[0] ** 2)


def erf(child):
" Returns error function of child. "
return pybamm.simplify_if_constant(Erf(child), keep_domains=True)


def erfc(child):
" Returns complementary error function of child. "
return pybamm.simplify_if_constant(1 - Erf(child), keep_domains=True)
3 changes: 3 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import casadi
import numpy as np
from scipy.interpolate import PchipInterpolator, CubicSpline
from scipy import special


class CasadiConverter(object):
Expand Down Expand Up @@ -129,6 +130,8 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.log(*converted_children)
elif symbol.function == np.sign:
return casadi.sign(*converted_children)
elif symbol.function == special.erf:
return casadi.erf(*converted_children)
elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)):
return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)(
*converted_children
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import unittest
import numpy as np
from scipy.interpolate import interp1d
from scipy import special


def test_function(arg):
Expand Down Expand Up @@ -322,6 +323,38 @@ def test_tanh(self):
places=5,
)

def test_erf(self):
a = pybamm.InputParameter("a")
fun = pybamm.erf(a)
self.assertEqual(fun.evaluate(inputs={"a": 3}), special.erf(3))
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(inputs={"a": 3}),
(
pybamm.erf(pybamm.Scalar(3 + h)).evaluate()
- fun.evaluate(inputs={"a": 3})
)
/ h,
places=5,
)

def test_erfc(self):
a = pybamm.InputParameter("a")
fun = pybamm.erfc(a)
self.assertAlmostEqual(
fun.evaluate(inputs={"a": 3}), special.erfc(3), places=15
)
h = 0.0000001
self.assertAlmostEqual(
fun.diff(a).evaluate(inputs={"a": 3}),
(
pybamm.erfc(pybamm.Scalar(3 + h)).evaluate()
- fun.evaluate(inputs={"a": 3})
)
/ h,
places=5,
)


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pybamm
import unittest
from tests import get_mesh_for_testing, get_1p1d_discretisation_for_testing
from scipy import special


class TestCasadiConverter(unittest.TestCase):
Expand All @@ -15,6 +16,15 @@ def assert_casadi_equal(self, a, b, evalf=False):
else:
self.assertTrue((a - b).is_zero())

def assert_casadi_almost_equal(self, a, b, decimal=7, evalf=False):
tol = 1.5 * 10**(-decimal)
if evalf is True:
self.assertTrue(
(casadi.fabs(casadi.evalf(a) - casadi.evalf(b)) < tol).is_one()
)
else:
self.assertTrue((casadi.fabs(a - b) < tol).is_one())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nicely done!


def test_convert_scalar_symbols(self):
a = pybamm.Scalar(0)
b = pybamm.Scalar(1)
Expand Down Expand Up @@ -104,6 +114,8 @@ def test_special_functions(self):
self.assert_casadi_equal(
pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True
)

# test functions with assert_casadi_equal
for np_fun in [
np.sqrt,
np.tanh,
Expand All @@ -121,6 +133,17 @@ def test_special_functions(self):
pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True
)

# test functions with assert_casadi_almost_equal
for np_fun in [
special.erf,
]:
self.assert_casadi_almost_equal(
pybamm.Function(np_fun, c).to_casadi(),
casadi.MX(np_fun(3)),
decimal=15,
evalf=True,
)

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]
y = pybamm.StateVector(slice(0, 2))
Expand Down