Skip to content

Commit

Permalink
Add regression test for Truncated Gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 2, 2023
1 parent 4fb9bb6 commit b53c3de
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytensor.tensor.random.basic import GeometricRV, NormalRV

from pymc import Censored, Model, draw, find_MAP
from pymc.distributions.continuous import Exponential, TruncatedNormalRV
from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV
from pymc.distributions.shape_utils import change_dist_size
from pymc.distributions.transforms import _default_transform
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
Expand Down Expand Up @@ -392,3 +392,33 @@ def test_truncated_inference():
map = find_MAP(progressbar=False)

assert np.isclose(map["lam"], lam_true, atol=0.1)


def test_truncated_gamma():
# Regression test for https://github.com/pymc-devs/pymc/issues/6931
alpha = 3.0
beta = 3.0
upper = 2.5
x = np.linspace(0.0, upper + 0.5, 100)

gamma_scipy = scipy.stats.gamma(a=alpha, scale=1.0 / beta)
logp_scipy = gamma_scipy.logpdf(x) - gamma_scipy.logcdf(upper)
logp_scipy[x > upper] = -np.inf

gamma_trunc_pymc = Truncated.dist(
Gamma.dist(alpha=alpha, beta=beta),
upper=upper,
)
logp_pymc = logp(gamma_trunc_pymc, x).eval()
np.testing.assert_allclose(
logp_pymc,
logp_scipy,
)

# # Changing the size used to invert the beta Gamma parameter again
resized_gamma_trunc_pymc = change_dist_size(gamma_trunc_pymc, new_size=x.shape)
logp_pymc = logp(resized_gamma_trunc_pymc, x).eval()
np.testing.assert_allclose(
logp_pymc,
logp_scipy,
)

0 comments on commit b53c3de

Please sign in to comment.