-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement logcdf for TruncatedNormal #7034
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just suggesting a more direct test
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #7034 +/- ##
==========================================
- Coverage 92.19% 89.86% -2.34%
==========================================
Files 101 101
Lines 16893 16904 +11
==========================================
- Hits 15575 15190 -385
- Misses 1318 1714 +396
|
I've rewritten the function and it returns the correct answers but not at the precision we would like, it is correct to 3 decimal places but past that the tests fail. Any idea why this might be occuring? |
We are still probably testing too extreme values, did it fix the nan thing at least? |
Yeah the nan problem is sorted, shall I just change the domains in the tests then? The domains in the check_logp(
pm.TruncatedNormal,
R,
{"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig},
scipy_logp,
decimal=select_by_precision(float64=6, float32=1),
skip_paramdomain_outside_edge_test=True,
) |
Yup, sounds good |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great
Any idea why these tests are failing? They all pass locally for me on my Mac |
It was a PyTensor issue that we already fixed, should work fine if we rerun them (I'm restarting them) |
Closes #7003 first highlighted by a user on discord (https://discourse.pymc.io/t/censoring-a-truncated-distribution/13261) when trying to use the following model,
By refactoring the truncated
truncated_logcdf
to take abase_rv_op
we can now call it fromTruncatedNormal
. I've also moved thetruncated_logcdf_from_base_rv
out of truncated.py and into dist_math.py as I was having curcular import errors.Checklist
New features
-TruncatedNormal logcdf
📚 Documentation preview 📚: https://pymc--7034.org.readthedocs.build/en/7034/