Implement logcdf for TruncatedNormal #7003

Closed
Tracked by #7053
ricardoV94 opened this issue Nov 9, 2023 · 3 comments · Fixed by #7034

Comments

@ricardoV94
Member

ricardoV94 commented Nov 9, 2023

Description

We implemented the logcdf for the general Truncated case in #6690, but TruncatedNormal is its own distribution and doesn't currently have a logcdf implemented.

We should refactor the logcdf function added in that PR so that it can be called from the logcdf of TruncatedNormal as well as from the general Truncated (there's no point in duplicating the code).
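For reference, the quantity that both the general Truncated path and TruncatedNormal need is the log-CDF of a base distribution with CDF F truncated to [a, b]. This is the standard truncation identity, written in the form the existing code evaluates in log space:

```latex
\log F_T(x) =
\begin{cases}
-\infty & x < a \\
\log\bigl(F(x) - F(a)\bigr) - \log\bigl(F(b) - F(a)\bigr) & a \le x \le b \\
0 & x > b
\end{cases}
```

An infinite a gives F(a) = 0 and an infinite b gives F(b) = 1, which is why the code can drop the corresponding term (or use log1mexp of log F(a) as the normalizer) when a bound is absent.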

Reported in https://discourse.pymc.io/t/censoring-a-truncated-distribution/13261

@LukeLB
Contributor

LukeLB commented Nov 18, 2023

Hey, I've been having a crack at this but I'm hitting a roadblock with the refactor. Using the example given on Discourse:

import numpy as np
import pymc as pm
from copy import copy

N = 1  # one trial is enough to illustrate the problem
mu, sigma = 1, 2
upper_bounds = np.array([1])  # changed from the Discourse example for clarity
y = np.random.normal(loc=mu, scale=sigma, size=N)

def right_censor(y, upper_bounds):
    cy = copy(y)
    censor_index = cy >= upper_bounds
    cy[censor_index] = upper_bounds[censor_index]
    return cy

cy = right_censor(y, upper_bounds)

with pm.Model() as m1:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)  # sigma must be positive
    normal_ = pm.Normal.dist(mu, sigma)
    trunc_normal = pm.Truncated.dist(normal_, lower=0)
    censored = pm.Censored("censored", trunc_normal, lower=None, upper=upper_bounds, observed=cy)

with m1:
    i_data = pm.sample()

As expected, I can get the example to sample if I copy the truncated logcdf over and explicitly call the normal_lcdf function:

from pymc.distributions.dist_math import normal_lcdf

def logcdf(value, mu, sigma, lower, upper):
    logcdf = normal_lcdf(mu, sigma, value)
    lower_logcdf = normal_lcdf(mu, sigma, lower)
    upper_logcdf = normal_lcdf(mu, sigma, upper)
    # the rest of the function is left unchanged
    ...
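The hard-coded Normal version can be sanity-checked outside PyTensor. Below is a minimal NumPy/SciPy sketch (an illustration only, not PyMC code) that mirrors the two-sided branch, with normal_lcdf rewritten via scipy.special.log_ndtr and a hand-written logdiffexp. It assumes both bounds are finite and compares the result against scipy.stats.truncnorm.

```python
import numpy as np
from scipy.special import log_ndtr
from scipy.stats import truncnorm


def normal_lcdf(mu, sigma, x):
    # Log-CDF of Normal(mu, sigma) via the numerically stable log_ndtr
    return log_ndtr((np.asarray(x, dtype=float) - mu) / sigma)


def logdiffexp(a, b):
    # log(exp(a) - exp(b)), computed in log space; meaningful when a >= b
    return a + np.log1p(-np.exp(b - a))


def truncated_normal_logcdf(value, mu, sigma, lower, upper):
    value = np.asarray(value, dtype=float)
    logcdf = normal_lcdf(mu, sigma, value)
    lower_logcdf = normal_lcdf(mu, sigma, lower)
    upper_logcdf = normal_lcdf(mu, sigma, upper)
    # Both bounds are finite in this sketch, so only the two-sided normalizer is needed
    lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Values below `lower` give NaN/-inf here; they are masked just below
        numerator = logdiffexp(logcdf, lower_logcdf)
    out = numerator - lognorm
    out = np.where(value < lower, -np.inf, out)
    out = np.where(value <= upper, out, 0.0)
    return out


# Usage: compare against SciPy's truncated normal (mu=1, sigma=2, bounds [0, 4])
xs = np.array([0.5, 1.0, 2.0, 3.0])
ours = truncated_normal_logcdf(xs, mu=1, sigma=2, lower=0, upper=4)
reference = truncnorm.logcdf(xs, (0 - 1) / 2, (4 - 1) / 2, loc=1, scale=2)
```

Note that truncnorm takes its bounds in standardized units, (bound - mu) / sigma, whereas the PyMC-style signature takes them on the original scale.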

But we don't actually want that. I run into issues when I try to import the function instead:

from pymc.distributions.continuous import TruncatedNormal
from pymc.distributions.truncated import truncated_logcdf
from pymc.logprob.abstract import _logcdf

@_logcdf.register(TruncatedNormal)
def truncated_normal_logcdf(op, value, *inputs, **kwargs):
    return truncated_logcdf(op, value, *inputs, **kwargs)

With this code, op is a TruncatedNormal op. My thinking is that I need to access the underlying Normal distribution from op to pass to _logcdf, but I can't figure out how. Is this the correct approach? If so, how do I access it?

@ricardoV94
Member Author

ricardoV94 commented Nov 20, 2023

The idea would be to refactor the following code into its own function with the signature truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper), where base_rv_inputs is passed as a single list (as in the calls further down):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import TensorConstant

from pymc.distributions.dist_math import check_parameters
from pymc.logprob.abstract import _logcdf
from pymc.math import logdiffexp


def truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper, **kwargs):
    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
    # For left truncated discrete RVs, we don't want to include the lower bound in the
    # normalization term
    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
    lower_logcdf = _logcdf(base_rv_op, lower_value, *base_rv_inputs, **kwargs)
    upper_logcdf = _logcdf(base_rv_op, upper, *base_rv_inputs, **kwargs)

    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

    lognorm = 0
    if is_lower_bounded and is_upper_bounded:
        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    elif is_lower_bounded:
        lognorm = pt.log1mexp(lower_logcdf)
    elif is_upper_bounded:
        lognorm = upper_logcdf

    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
    logcdf_trunc = logcdf_numerator - lognorm

    if is_lower_bounded:
        logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)

    if is_upper_bounded:
        logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)

    if is_lower_bounded and is_upper_bounded:
        logcdf_trunc = check_parameters(
            logcdf_trunc,
            pt.le(lower, upper),
            msg="lower_bound <= upper_bound",
        )

    return logcdf_trunc
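To see the branching logic in isolation, here is a plain NumPy mirror of the function above. It is a sketch, not PyMC code: truncated_logcdf_np and its discrete flag are hypothetical names, the base distribution is passed as a log-CDF callable instead of an op, and infinite bounds stand in for the TensorConstant checks. The usage at the bottom truncates a Poisson(3) base to {2, ..., 6}, where the discrete branch normalizes with F(lower - 1) so that P(X = 2) stays in the support.

```python
import numpy as np
from scipy.stats import poisson


def logdiffexp(a, b):
    # log(exp(a) - exp(b)), meaningful when a >= b
    return a + np.log1p(-np.exp(b - a))


def log1mexp(x):
    # log(1 - exp(x)) for x <= 0
    return np.log1p(-np.exp(x))


def truncated_logcdf_np(base_logcdf, value, lower, upper, discrete=False):
    """Hypothetical NumPy mirror of truncated_logcdf_from_base_rv.

    lower / upper may be -np.inf / np.inf to mean "unbounded" on that side.
    """
    value = np.asarray(value, dtype=float)
    is_lower_bounded = not np.isneginf(lower)
    is_upper_bounded = not np.isinf(upper)

    logcdf = base_logcdf(value)
    # Discrete case: normalize with F(lower - 1) so the lower bound keeps its mass
    lower_logcdf = base_logcdf(lower - 1 if discrete else lower)
    upper_logcdf = base_logcdf(upper)

    with np.errstate(divide="ignore", invalid="ignore"):
        if is_lower_bounded and is_upper_bounded:
            lognorm = logdiffexp(upper_logcdf, lower_logcdf)
        elif is_lower_bounded:
            lognorm = log1mexp(lower_logcdf)
        elif is_upper_bounded:
            lognorm = upper_logcdf
        else:
            lognorm = 0.0
        numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf

    out = numerator - lognorm
    if is_lower_bounded:
        out = np.where(value < lower, -np.inf, out)
    if is_upper_bounded:
        out = np.where(value <= upper, out, 0.0)
    return out


def poisson3_logcdf(x):
    # Base log-CDF: Poisson with rate 3
    return poisson.logcdf(x, 3)


# Usage: P(X <= 4) for Poisson(3) truncated to {2, ..., 6}, and upper-only truncation
two_sided = truncated_logcdf_np(poisson3_logcdf, 4, 2, 6, discrete=True)
upper_only = truncated_logcdf_np(poisson3_logcdf, 4, -np.inf, 6, discrete=True)
```

Without the lower - 1 adjustment, the two-sided result would silently drop P(X = 2) from both numerator and normalizer, which is exactly what the comment in the PyMC code guards against.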

The main body of truncated_logcdf would look like:

@_logcdf.register(TruncatedRV)
def truncated_logcdf(op, value, *inputs, **kwargs):
    base_rv_op = op.base_rv_op
    *base_rv_inputs, lower, upper, rng = inputs
    base_rv_inputs = [rng, *base_rv_inputs]
    return truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper)

Then your truncated_normal_logcdf would look something like:

import pymc as pm
from pymc.distributions.continuous import TruncatedNormal
from pymc.distributions.truncated import truncated_logcdf_from_base_rv
from pymc.logprob.abstract import _logcdf

@_logcdf.register(TruncatedNormal)
def truncated_normal_logcdf(op, value, *inputs, **kwargs):
    base_rv_op = pm.Normal.rv_op
    *base_rv_inputs, lower, upper = inputs
    return truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper)
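The dispatch pattern proposed here can be illustrated without PyTensor using functools.singledispatch. In this sketch NormalOp, TruncatedRVOp and TruncatedNormalOp are hypothetical stand-ins for the real PyMC ops, rng handling is omitted, and the base logcdf returns placeholder strings so the routing is visible. The point is that the TruncatedNormal handler never needs to extract a base RV from op: the base op is known statically, so it is passed to the shared helper directly, while the generic Truncated handler reads it off the op.

```python
from functools import singledispatch


class NormalOp:
    """Hypothetical stand-in for the base Normal RV op."""


class TruncatedRVOp:
    """Hypothetical stand-in for the generic Truncated op; it stores its base op."""

    def __init__(self, base_rv_op):
        self.base_rv_op = base_rv_op


class TruncatedNormalOp:
    """Hypothetical stand-in for TruncatedNormal, which does not store a base op."""


@singledispatch
def logcdf(op, value, *inputs):
    raise NotImplementedError(f"No logcdf registered for {type(op).__name__}")


@logcdf.register(NormalOp)
def _(op, value, mu, sigma):
    # Placeholder result that records which base logcdf call was made
    return f"normal_lcdf({mu}, {sigma}, {value})"


def truncated_logcdf_from_base_rv(base_rv_op, value, base_rv_inputs, lower, upper):
    # Shared helper: needs only the base op and its inputs, never the truncated op
    return (
        logcdf(base_rv_op, value, *base_rv_inputs),
        logcdf(base_rv_op, lower, *base_rv_inputs),
        logcdf(base_rv_op, upper, *base_rv_inputs),
    )


@logcdf.register(TruncatedRVOp)
def _(op, value, *inputs):
    *base_rv_inputs, lower, upper = inputs
    # Generic case: the base op is read off the truncated op
    return truncated_logcdf_from_base_rv(op.base_rv_op, value, base_rv_inputs, lower, upper)


@logcdf.register(TruncatedNormalOp)
def _(op, value, *inputs):
    *base_rv_inputs, lower, upper = inputs
    # Special case: the base op is known statically, so supply it directly
    return truncated_logcdf_from_base_rv(NormalOp(), value, base_rv_inputs, lower, upper)


# Usage: both handlers route to the same base logcdf calls
generic = logcdf(TruncatedRVOp(NormalOp()), 0.5, 1, 2, 0, 4)
special = logcdf(TruncatedNormalOp(), 0.5, 1, 2, 0, 4)
```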

@LukeLB
Contributor

LukeLB commented Nov 20, 2023

Ah, got it.
