diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 08102d2022..2d023dbfbf 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -45,7 +45,7 @@ ) from pymc3.distributions.distribution import Continuous, draw_values, generate_samples from pymc3.distributions.special import log_i0 -from pymc3.math import invlogit, logdiffexp, logit +from pymc3.math import invlogit, log1mexp, logdiffexp, logit from pymc3.theanof import floatX __all__ = [ @@ -1513,12 +1513,6 @@ def logcdf(self, value): Compute the log of cumulative distribution function for the Exponential distribution at the specified value. - References - ---------- - .. [Machler2012] Martin Mächler (2012). - "Accurately computing :math:`\log(1-\exp(-\mid a \mid))` Assessed by the Rmpfr - package" - Parameters ---------- value: numeric @@ -1533,9 +1527,9 @@ def logcdf(self, value): lam = self.lam a = lam * value return tt.switch( - tt.le(value, 0.0), + tt.le(value, 0.0) | tt.le(lam, 0), -np.inf, - tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))), + log1mexp(a), ) @@ -2806,12 +2800,6 @@ def logcdf(self, value): Compute the log of the cumulative distribution function for Weibull distribution at the specified value. - References - ---------- - .. [Machler2012] Martin Mächler (2012). - "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr - package" - Parameters ---------- value: numeric @@ -2828,7 +2816,7 @@ def logcdf(self, value): return tt.switch( tt.le(value, 0.0), -np.inf, - tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))), + log1mexp(a), ) diff --git a/pymc3/math.py b/pymc3/math.py index 17f286ceab..17a7a7f67d 100644 --- a/pymc3/math.py +++ b/pymc3/math.py @@ -219,14 +219,21 @@ def log1pexp(x): def log1mexp(x): - """Return log(1 - exp(-x)). + r"""Return log(1 - exp(-x)). This function is numerically more stable than the naive approach. For details, see https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf + + References + ---------- + .. [Machler2012] Martin Mächler (2012). + "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr + package" + """ - return tt.switch(tt.lt(x, 0.683), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x))) + return tt.switch(tt.lt(x, 0.6931471805599453), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x))) def log1mexp_numpy(x): @@ -235,7 +242,7 @@ def log1mexp_numpy(x): For details, see https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf """ - return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x))) + return np.where(x < 0.6931471805599453, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x))) def flatten_list(tensors):