Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement logcdf for TruncatedNormal #7034

Merged
merged 1 commit into from
Dec 21, 2023
Merged

Conversation

LukeLB
Copy link
Contributor

@LukeLB LukeLB commented Nov 26, 2023

Closes #7003 first highlighted by a user on discord (https://discourse.pymc.io/t/censoring-a-truncated-distribution/13261) when trying to use the following model,

with pm.Model() as m:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.Normal("sigma", 0, 1)
    normal_ = pm.Normal.dist(mu, sigma)
    trunc_normal = pm.Truncated.dist(normal_, lower=0)
    censored = pm.Censored("censored", trunc_normal, lower=None, upper=upper_bounds, observed=cy)

By refactoring the truncated truncated_logcdf to take a base_rv_op we can now call it from TruncatedNormal. I've also moved the truncated_logcdf_from_base_rv out of truncated.py and into dist_math.py as I was having curcular import errors.

Checklist

New features

-TruncatedNormal logcdf


📚 Documentation preview 📚: https://pymc--7034.org.readthedocs.build/en/7034/

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just suggesting a more direct test

pymc/distributions/dist_math.py Outdated Show resolved Hide resolved
tests/distributions/test_truncated.py Outdated Show resolved Hide resolved
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link

codecov bot commented Nov 29, 2023

Codecov Report

Merging #7034 (81dd599) into main (2e05854) will decrease coverage by 2.34%.
Report is 2 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7034      +/-   ##
==========================================
- Coverage   92.19%   89.86%   -2.34%     
==========================================
  Files         101      101              
  Lines       16893    16904      +11     
==========================================
- Hits        15575    15190     -385     
- Misses       1318     1714     +396     
Files Coverage Δ
pymc/distributions/continuous.py 97.79% <100.00%> (+0.02%) ⬆️

... and 9 files with indirect coverage changes

@LukeLB LukeLB closed this Nov 30, 2023
@LukeLB LukeLB reopened this Nov 30, 2023
@ricardoV94 ricardoV94 changed the title Issue #7003 Implement logcdf for TruncatedNormal Dec 1, 2023
@LukeLB
Copy link
Contributor Author

LukeLB commented Dec 11, 2023

I've rewritten the function and it returns the correct answers but not at the precision we would like, it is correct to 3 decimal places but past that the tests fail. Any idea why this might be occuring?

@ricardoV94
Copy link
Member

I've rewritten the function and it returns the correct answers but not at the precision we would like, it is correct to 3 decimal places but past that the tests fail. Any idea why this might be occuring?

We are still probably testing too extreme values, did it fix the nan thing at least?

@LukeLB
Copy link
Contributor Author

LukeLB commented Dec 11, 2023

Yeah the nan problem is sorted, shall I just change the domains in the tests then? The domains in the check_logp tests seem to work fine, e.g.,

check_logp(
    pm.TruncatedNormal,
    R,
    {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig},
    scipy_logp,
    decimal=select_by_precision(float64=6, float32=1),
    skip_paramdomain_outside_edge_test=True,
        )

@ricardoV94
Copy link
Member

Yup, sounds good

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great

@LukeLB
Copy link
Contributor Author

LukeLB commented Dec 13, 2023

Any idea why these tests are failing? They all pass locally for me on my Mac

@ricardoV94
Copy link
Member

It was a PyTensor issue that we already fixed, should work fine if we rerun them (I'm restarting them)

@ricardoV94 ricardoV94 merged commit 986738f into pymc-devs:main Dec 21, 2023
21 of 22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement logcdf for TruncatedNormal
2 participants