Skip to content

Commit

Permalink
Add log1pexp_numpy helper
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 5, 2022
1 parent cb76083 commit f28ea64
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,25 @@ def log1mexp_numpy(x, *, negative_input=False):
return out


def log1pexp_numpy(x):
"""Return log(1 + exp(x)).
This function is numerically more stable than the naive approach.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
return np.where(
x < -37.0,
np.exp(x),
np.where(
x < 33.3,
np.log1p(np.exp(x)),
x,
),
)


def flatten_list(tensors):
return at.concatenate([var.ravel() for var in tensors])

Expand Down
13 changes: 13 additions & 0 deletions pymc/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import numpy.testing as npt
import pytest

from aesara.tensor import log1pexp

from pymc.aesaraf import floatX
from pymc.math import (
LogDet,
Expand All @@ -30,6 +32,7 @@
kronecker,
log1mexp,
log1mexp_numpy,
log1pexp_numpy,
log_softmax,
logdet,
logdiffexp,
Expand Down Expand Up @@ -188,6 +191,16 @@ def test_log1mexp_deprecation_warnings():
assert np.isclose(res_neg_at, res_neg)


def test_log1pexp_numpy():
a = np.array([-40, -30, 0, 30, 40])
assert np.all(
np.isclose(
log1pexp_numpy(a),
log1pexp(a).eval(),
)
)


def test_logdiffexp():
a = np.log([1, 2, 3, 4])
b = np.log([0, 1, 2, 3])
Expand Down

0 comments on commit f28ea64

Please sign in to comment.