Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update log1mexp and remove redundant local reimplementations in the library #4394

Merged
merged 5 commits into from
Jan 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples
from pymc3.distributions.special import log_i0
from pymc3.math import invlogit, logdiffexp, logit
from pymc3.math import invlogit, log1mexp, logdiffexp, logit
from pymc3.theanof import floatX

__all__ = [
Expand Down Expand Up @@ -1513,12 +1513,6 @@ def logcdf(self, value):
Compute the log of cumulative distribution function for the Exponential distribution
at the specified value.

References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing :math:`\log(1-\exp(-\mid a \mid))` Assessed by the Rmpfr
package"

Parameters
----------
value: numeric
Expand All @@ -1533,9 +1527,9 @@ def logcdf(self, value):
lam = self.lam
a = lam * value
return tt.switch(
tt.le(value, 0.0),
tt.le(value, 0.0) | tt.le(lam, 0),
-np.inf,
tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))),
log1mexp(a),
)


Expand Down Expand Up @@ -2806,12 +2800,6 @@ def logcdf(self, value):
Compute the log of the cumulative distribution function for Weibull distribution
at the specified value.

References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr
package"

Parameters
----------
value: numeric
Expand All @@ -2828,7 +2816,7 @@ def logcdf(self, value):
return tt.switch(
tt.le(value, 0.0),
-np.inf,
tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))),
log1mexp(a),
)


Expand Down
13 changes: 10 additions & 3 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,21 @@ def log1pexp(x):


def log1mexp(x):
"""Return log(1 - exp(-x)).
r"""Return log(1 - exp(-x)).

This function is numerically more stable than the naive approach.

For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr
package"

"""
return tt.switch(tt.lt(x, 0.683), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))
return tt.switch(tt.lt(x, 0.6931471805599453), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))


def log1mexp_numpy(x):
Expand All @@ -235,7 +242,7 @@ def log1mexp_numpy(x):
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
"""
return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))
return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))


def flatten_list(tensors):
Expand Down