Implement logcdf for TruncatedNormal #7003
Hey, I've been having a crack at this but am hitting a roadblock with the refactor. Using the example given on Discord:

```python
from copy import copy

import numpy as np
import pymc as pm

N = 1  # one trial is enough to illustrate the problem
mu, sigma = 1, 2
upper_bounds = np.array([1])  # changed from Discord for clarity
y = np.random.normal(loc=mu, scale=sigma, size=N)


def right_censor(y, upper_bounds):
    # Replace observations at or above their upper bound with that bound
    cy = copy(y)
    censor_index = cy >= upper_bounds
    cy[censor_index] = upper_bounds[censor_index]
    return cy


cy = right_censor(y, upper_bounds)

with pm.Model() as m1:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)  # sigma must be positive
    normal_ = pm.Normal.dist(mu, sigma)
    trunc_normal = pm.Truncated.dist(normal_, lower=0)
    censored = pm.Censored("censored", trunc_normal, lower=None, upper=upper_bounds, observed=cy)

with m1:
    i_data = pm.sample()
```

As expected, I can get the example to sample if I copy the truncated logcdf over and explicitly call `normal_lcdf`:

```python
def logcdf(value, mu, sigma, lower, upper):
    logcdf = normal_lcdf(mu, sigma, value)
    lower_logcdf = normal_lcdf(mu, sigma, lower)
    upper_logcdf = normal_lcdf(mu, sigma, upper)
    # the rest of the function is left unchanged
```

But we don't actually want that. I run into issues when trying to import the function instead. When I do:

```python
from pymc.distributions.continuous import TruncatedNormal
from pymc.distributions.truncated import truncated_logcdf
from pymc.logprob.abstract import _logcdf

@_logcdf.register(TruncatedNormal)
def truncated_normal_logcdf(op, value, *inputs, **kwargs):
    return truncated_logcdf(op, value, *inputs, **kwargs)
```

Using this code
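The idea would be to refactor this code (pymc/pymc/distributions/truncated.py, lines 362 to 397 in 547bcb4) into its own function whose signature is `truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper)`. The main body of the existing `truncated_logcdf` would then reduce to:

```python
@_logcdf.register(TruncatedRV)
def truncated_logcdf(op, value, *inputs, **kwargs):
    base_rv_op = op.base_rv_op
    # TruncatedRV inputs: the base RV's parameters, then lower, upper, rng
    *base_rv_inputs, lower, upper, rng = inputs
    base_rv_inputs = [rng, *base_rv_inputs]
    return truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper)
```

Then your `truncated_normal_logcdf` would look something like:

```python
import pymc as pm
from pymc.distributions.continuous import TruncatedNormal
from pymc.distributions.truncated import truncated_logcdf_from_base_rv  # the proposed helper
from pymc.logprob.abstract import _logcdf

@_logcdf.register(TruncatedNormal)
def truncated_normal_logcdf(op, value, *inputs, **kwargs):
    base_rv_op = pm.Normal.rv_op
    # TruncatedNormal inputs: mu, sigma, lower, upper (no rng to strip)
    *base_rv_inputs, lower, upper = inputs
    return truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper)
```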
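For reference, the quantity the shared helper has to compute is the base log-CDF renormalized to the truncation interval: for `a <= x < b`, `log F_T(x) = log(F(x) - F(a)) - log(F(b) - F(a))`, with `F_T = 0` below `a` and `F_T = 1` at or above `b`. The sketch below illustrates that formula with NumPy and the standard library only; it is not PyMC's implementation. `truncated_logcdf_from_base` and `normal_logcdf` are hypothetical names, and the base distribution is passed in as a plain log-CDF callable rather than an RV op plus inputs.

```python
import math

import numpy as np

# Vectorized complementary error function built from the standard library
_erfc = np.vectorize(math.erfc, otypes=[float])


def normal_logcdf(x, mu, sigma):
    """Log-CDF of Normal(mu, sigma); no special handling for extreme tails."""
    z = (np.asarray(x, dtype=float) - mu) / sigma
    with np.errstate(divide="ignore"):
        return np.log(0.5 * _erfc(-z / math.sqrt(2.0)))


def truncated_logcdf_from_base(base_logcdf, value, lower, upper):
    """Log-CDF of a base distribution truncated to [lower, upper].

    Hypothetical helper: `base_logcdf` maps values to the base log-CDF.
    Assumes lower < upper; either bound may be infinite.
    """
    value = np.asarray(value, dtype=float)
    lcdf_value = base_logcdf(value)
    lcdf_lower = base_logcdf(lower)
    lcdf_upper = base_logcdf(upper)
    with np.errstate(divide="ignore", invalid="ignore"):
        # log(F(x) - F(a)) = log F(x) + log1p(-F(a) / F(x)); same for F(b)
        log_num = lcdf_value + np.log1p(-np.exp(lcdf_lower - lcdf_value))
        log_den = lcdf_upper + np.log1p(-np.exp(lcdf_lower - lcdf_upper))
    out = log_num - log_den
    # Outside the support the truncated CDF is exactly 0 (below) or 1 (at/above upper)
    out = np.where(value < lower, -np.inf, out)
    out = np.where(value >= upper, 0.0, out)
    return out
```

Because the helper only needs the base log-CDF, the same function covers a generic truncated distribution and the truncated normal, which is the point of the proposed refactor; e.g. `truncated_logcdf_from_base(lambda v: normal_logcdf(v, 1.0, 2.0), x, 0.0, np.inf)` evaluates the truncated normal from the Discord example.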
Ahh, got it.
Description
We implemented the logcdf for the general Truncated cases in #6690, but TruncatedNormal is its own distribution and doesn't have a logcdf currently implemented.
We should refactor the logcdf function added in that PR so that it can be called from the logcdf of TruncatedNormal as well as from the general Truncated case (there's no point in duplicating the code).
Reported in https://discourse.pymc.io/t/censoring-a-truncated-distribution/13261
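The thread combines `pm.Censored` with a truncated normal, which is why the missing logcdf matters. For an observation censored at its upper bound `u`, the likelihood term is `log P(Y >= u) = log(1 - F_T(u))`, which needs the truncated CDF `F_T`, while uncensored points only need the truncated density. Below is a standard-library sketch of that per-observation log-likelihood for a Normal truncated below at `lower` and right-censored, as in the Discord example; the function names are hypothetical and it is not meant to mirror how PyMC builds the censored logp internally.

```python
import math

SQRT2 = math.sqrt(2.0)


def normal_cdf(x, mu, sigma):
    return 0.5 * math.erfc(-(x - mu) / (sigma * SQRT2))


def normal_logpdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma * math.sqrt(2.0 * math.pi))


def censored_truncnorm_loglike(y, upper_bound, mu, sigma, lower):
    """Hypothetical helper: log-likelihood of one observation of Normal(mu, sigma)
    truncated below at `lower` and right-censored at `upper_bound`."""
    if y < lower:
        return -math.inf  # outside the truncated support
    cdf_lower = normal_cdf(lower, mu, sigma)
    if y >= upper_bound:
        # Censored point: requires the truncated CDF F_T(u)
        trunc_cdf_u = (normal_cdf(upper_bound, mu, sigma) - cdf_lower) / (1.0 - cdf_lower)
        return math.log1p(-trunc_cdf_u)
    # Uncensored point: base density renormalized by P(X >= lower)
    return normal_logpdf(y, mu, sigma) - math.log1p(-cdf_lower)
```

The censored branch is exactly where a TruncatedNormal logcdf is needed; the uncensored branch works with the existing truncated logp.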