Skip to content

Commit

Permalink
Merge pull request #3999 from pybamm-team/normal-cdf
Browse files Browse the repository at this point in the history
Add normal pdf and cdf functions
  • Loading branch information
valentinsulzer authored Apr 12, 2024
2 parents d764a18 + 9fcd56b commit 51c1f76
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Added functions for normal probability density function (`pybamm.normal_pdf`) and cumulative distribution function (`pybamm.normal_cdf`) ([#3999](https://github.com/pybamm-team/PyBaMM/pull/3999))
- Updates multiprocess `Pool` in `BaseSolver.solve()` to be constructed with context `fork`. Adds small example for multiprocess inputs. ([#3974](https://github.com/pybamm-team/PyBaMM/pull/3974))
- Added custom experiment steps ([#3835](https://github.com/pybamm-team/PyBaMM/pull/3835))
- Added support for macOS arm64 (M-series) platforms. ([#3789](https://github.com/pybamm-team/PyBaMM/pull/3789))
Expand Down
4 changes: 4 additions & 0 deletions docs/source/api/expression_tree/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ Functions
:members:

.. autofunction:: pybamm.tanh

.. autofunction:: pybamm.normal_pdf

.. autofunction:: pybamm.normal_cdf
46 changes: 46 additions & 0 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,3 +655,49 @@ def _function_diff(self, children, idx):
def tanh(child: pybamm.Symbol):
"""Returns hyperbolic tan function of child."""
return simplified_function(Tanh, child)


def normal_pdf(
x: pybamm.Symbol, mu: pybamm.Symbol | float, sigma: pybamm.Symbol | float
):
"""
Returns the normal probability density function at x.
Parameters
----------
x : pybamm.Symbol
The value at which to evaluate the normal distribution
mu : pybamm.Symbol or float
The mean of the normal distribution
sigma : pybamm.Symbol or float
The standard deviation of the normal distribution
Returns
-------
pybamm.Symbol
The value of the normal distribution at x
"""
return 1 / (np.sqrt(2 * np.pi) * sigma) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)


def normal_cdf(
x: pybamm.Symbol, mu: pybamm.Symbol | float, sigma: pybamm.Symbol | float
):
"""
Returns the normal cumulative distribution function at x.
Parameters
----------
x : pybamm.Symbol
The value at which to evaluate the normal distribution
mu : pybamm.Symbol or float
The mean of the normal distribution
sigma : pybamm.Symbol or float
The standard deviation of the normal distribution
Returns
-------
pybamm.Symbol
The value of the normal distribution at x
"""
return 0.5 * (1 + special.erf((x - mu) / (sigma * np.sqrt(2))))
43 changes: 43 additions & 0 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,49 @@ def test_erfc(self):
)


class TestNonObjectFunctions(TestCase):
def test_normal_pdf(self):
x = pybamm.InputParameter("x")
mu = pybamm.InputParameter("mu")
sigma = pybamm.InputParameter("sigma")
fun = pybamm.normal_pdf(x, mu, sigma)
self.assertEqual(
fun.evaluate(inputs={"x": 0, "mu": 0, "sigma": 1}), 1 / np.sqrt(2 * np.pi)
)
self.assertEqual(
fun.evaluate(inputs={"x": 2, "mu": 2, "sigma": 10}),
1 / np.sqrt(2 * np.pi) / 10,
)
self.assertAlmostEqual(fun.evaluate(inputs={"x": 100, "mu": 0, "sigma": 1}), 0)
self.assertAlmostEqual(fun.evaluate(inputs={"x": -100, "mu": 0, "sigma": 1}), 0)
self.assertGreater(
fun.evaluate(inputs={"x": 1, "mu": 0, "sigma": 1}),
fun.evaluate(inputs={"x": 1, "mu": 0, "sigma": 2}),
)
self.assertGreater(
fun.evaluate(inputs={"x": -1, "mu": 0, "sigma": 1}),
fun.evaluate(inputs={"x": -1, "mu": 0, "sigma": 2}),
)

def test_normal_cdf(self):
x = pybamm.InputParameter("x")
mu = pybamm.InputParameter("mu")
sigma = pybamm.InputParameter("sigma")
fun = pybamm.normal_cdf(x, mu, sigma)
self.assertEqual(fun.evaluate(inputs={"x": 0, "mu": 0, "sigma": 1}), 0.5)
self.assertEqual(fun.evaluate(inputs={"x": 2, "mu": 2, "sigma": 10}), 0.5)
self.assertAlmostEqual(fun.evaluate(inputs={"x": 100, "mu": 0, "sigma": 1}), 1)
self.assertAlmostEqual(fun.evaluate(inputs={"x": -100, "mu": 0, "sigma": 1}), 0)
self.assertGreater(
fun.evaluate(inputs={"x": 1, "mu": 0, "sigma": 1}),
fun.evaluate(inputs={"x": 1, "mu": 0, "sigma": 2}),
)
self.assertLess(
fun.evaluate(inputs={"x": -1, "mu": 0, "sigma": 1}),
fun.evaluate(inputs={"x": -1, "mu": 0, "sigma": 2}),
)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down

0 comments on commit 51c1f76

Please sign in to comment.