Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Truncation of CustomDist #6947

Merged
merged 3 commits into from
Apr 8, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Oct 11, 2023

This is now possible:

import pymc as pm
import numpy as np

def maxwell_dist(scale, size):
  return pm.math.sqrt(pm.ChiSquared.dist(nu=3, size=size)) * scale

scale = 5.0
x = pm.CustomDist.dist(scale, dist=maxwell_dist)

trunc_x = pm.Truncated.dist(x, lower=0, upper=2, size=(1000,))
assert np.all(pm.draw(trunc_x) < 2)

trunc_x = pm.Truncated.dist(x, lower=0, upper=5, size=())
assert pm.logp(trunc_x, 3.0).eval() > pm.logp(x, 3.0).eval()

This required cleaning up the interface of SymbolicRandomVariables (mainly circumventing pymc-devs/pytensor#473) so that we can safely "box" the base RVs in the inner OpFromGraph (i.e., recreate them with new shared inputs).

This challenge is very specific to Truncated which needs to "resample" the base RV for the rejection based algorithm.
No other SymbolicRandomVariable needs to do this, and they have avoided the need to box the base RVs by simply resizing them to the total size and using the resized RVs as explicit inputs to the inner graph.

For instance, Mixture will resize the component RVs to the "total size" and then scholastically index them based on its internal Categorical RV. ZeroSumNormal will create Normals as inputs and simply subtraction the mean.

Such an approach, however, makes it tricky for Truncated to know exactly what constitutes the "true" inputs of underlying SymbolicRandomVariables, and for this reason it rejected and still rejects arbitrary SymbolicRandomVariables. The exception, are the SymbolicRandomVariables created via CustomDist because for those are already "pre-boxed" in a sense. We know the relevant graph must start at dist.owner.inputs. Now that our class can safely manage and replace shared RNGs inputs, we can allow Truncated to handle such RVs, even if they require a bunch of shared RNGs.

Related to #6905 (comment)

TODO

@codecov
Copy link

codecov bot commented Oct 11, 2023

Codecov Report

Attention: Patch coverage is 91.08280% with 14 lines in your changes are missing coverage. Please review.

Project coverage is 87.75%. Comparing base (abe7bc9) to head (374e4e3).

❗ Current head 374e4e3 differs from pull request most recent head 92efe38. Consider uploading reports for the commit 92efe38 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6947      +/-   ##
==========================================
- Coverage   92.30%   87.75%   -4.55%     
==========================================
  Files         100      100              
  Lines       16888    16958      +70     
==========================================
- Hits        15588    14882     -706     
- Misses       1300     2076     +776     
Files Coverage Δ
pymc/distributions/timeseries.py 94.40% <100.00%> (-0.18%) ⬇️
pymc/distributions/truncated.py 99.44% <100.00%> (+0.03%) ⬆️
pymc/distributions/distribution.py 95.58% <91.54%> (+1.29%) ⬆️
pymc/pytensorf.py 90.71% <66.66%> (-0.58%) ⬇️

... and 76 files with indirect coverage changes

@ricardoV94
Copy link
Member Author

ricardoV94 commented Mar 25, 2024

This PR depends on #7227

@zaxtax
Copy link
Contributor

zaxtax commented Mar 28, 2024

Looks like some changes still needed till tests pass.

@ricardoV94
Copy link
Member Author

Looks like some changes still needed till tests pass.

See my comment above, it needs the pytensor dependency bump which is happening in a separate PR

@zaxtax
Copy link
Contributor

zaxtax commented Mar 28, 2024 via email

@ricardoV94
Copy link
Member Author

ricardoV94 commented Mar 28, 2024

Oh ok. I know it's not the most essential for this PR, but why does the Scan require a shared variable as an argument?

It's a limitation in the original implementation of Scan, where RNG variables must be shared. It's one of the things that we are hoping to solve with pymc-devs/pytensor#191

Copy link
Contributor

@zaxtax zaxtax left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ricardoV94 ricardoV94 force-pushed the truncated_custom_dist branch 2 times, most recently from 64aa904 to 48a5c39 Compare April 8, 2024 09:26
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment: What do you think about a warning if the Truncated distribution has to fall back to rejection sampling? This will introduce a while scan into the graph that could be quite surprising to users (JAX mode no longer possible, potentially big performance hit)

pymc/distributions/distribution.py Show resolved Hide resolved
pymc/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc/distributions/truncated.py Outdated Show resolved Hide resolved
pymc/distributions/truncated.py Show resolved Hide resolved
pymc/pytensorf.py Show resolved Hide resolved
@ricardoV94 ricardoV94 marked this pull request as ready for review April 8, 2024 11:54
@ricardoV94
Copy link
Member Author

ricardoV94 commented Apr 8, 2024

General comment: What do you think about a warning if the Truncated distribution has to fall back to rejection sampling? This will introduce a while scan into the graph that could be quite surprising to users (JAX mode no longer possible, potentially big performance hit)

I wouldn't add a warning, because there is nothing for the user to do instead. Can add a note in the docstrings if there's no mention of it yet

@ricardoV94 ricardoV94 merged commit 5f95fc2 into pymc-devs:main Apr 8, 2024
19 of 21 checks passed
@ricardoV94 ricardoV94 deleted the truncated_custom_dist branch April 8, 2024 17:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants