Skip to content

Commit

Permalink
fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke LB committed Nov 26, 2023
1 parent c800f84 commit f4aeff6
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Normal,
TruncatedNormal,
TruncatedNormalRV,
Uniform,
)
from pymc.distributions.shape_utils import change_dist_size
from pymc.distributions.transforms import _default_transform
Expand Down Expand Up @@ -404,25 +405,19 @@ def test_truncated_inference():


def test_truncated_normal_logcdf_inference():
N = 1
mu, sigma = 1, 2
upper_bounds = np.array([1])
y = np.random.normal(loc=mu, scale=sigma, size=N)

def right_censor(y, upper_bounds):
cy = copy(y)
censor_index = cy >= upper_bounds
cy[censor_index] = upper_bounds[censor_index]
return cy
lower = 0
upper = 1

cy = right_censor(y, upper_bounds)
rng = np.random.default_rng(260)
x = rng.normal(0, size=5000)
obs = x[np.where(~((x < lower) | (x > upper)))]

with Model() as m:
mu = Normal("mu", 0, 1)
sigma = Normal("sigma", 0, 1)
mu = Normal("mu", 10, 1)
sigma = Uniform("sigma", 1, 10)
normal_ = Normal.dist(mu, sigma)
trunc_normal = Truncated.dist(normal_, lower=0)
censored = Censored("censored", trunc_normal, lower=None, upper=upper_bounds, observed=cy)
censored = Censored("censored", trunc_normal, lower=None, upper=upper, observed=obs)

ip = m.initial_point()
norm_logp, trunc_logp, cens_logp = m.compile_logp(sum=False)(ip)
Expand Down

0 comments on commit f4aeff6

Please sign in to comment.