pm.Bound and TruncatedNormal generate wrong gradients #4417
I made some progress. It seems that while I am getting the right logp, I am not getting the right gradient 😮

# create data
np.random.seed(451)
x = np.random.exponential(3, size=500)
minx=1
maxx=20
obs = x[np.where(~((x<minx) | (x>maxx)))] # remove values outside range
print(obs.size)
357

with pm.Model() as m:
mu = pm.Exponential("mu", lam=1/5)
x = pm.Exponential('x', lam=1/mu, observed=obs)
exp_dist = pm.Exponential.dist(lam=1/mu) # this is not part of the model, just used to get the logcdf
norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size)
print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))
-732.573797660396
[-12.68406707]

with pm.Model() as m:
mu = pm.Exponential("mu", lam=1/5)
x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs)
print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))
-732.5737976603965 # Same logp
[457.97095304] # Very different dlogp!

with pm.Model() as m:
mu = pm.Exponential("mu", lam=1/5)
x = pm.Bound(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs)
print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))
-852.2084303799146
[458.2930177] # Almost the same as when using Truncated

In fact, regardless of the sample size, the difference in gradient between the Bound and Truncated versions is always minimal. And if I sample with Metropolis instead of NUTS, everything looks fine:

with pm.Model() as m:
mu = pm.Exponential("mu", lam=1/5)
x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs)
trace = pm.sample(return_inferencedata=True, step=[pm.Metropolis([mu])])
az.summary(trace)
Multiprocess sampling (4 chains in 4 jobs)
Metropolis: [mu]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
The number of effective samples is smaller than 25% for some parameters.
     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu  2.896  0.155   2.634    3.201      0.007    0.005     475.0   469.0     494.0     408.0   1.01

Whereas with NUTS I get (it actually changes widely from run to run):

     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu  3.293  0.997    2.18    4.944      0.494    0.378       4.0     4.0       4.0      11.0   3.69
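As a side note, one way to sanity-check which gradient is right is a crude finite-difference approximation of the model logp. This is my own sketch (not part of the original report), reusing the model m from the snippets above; the step size is arbitrary:

eps = 1e-6
logp0 = m.logp({'mu_log__': np.log(3)})
logp1 = m.logp({'mu_log__': np.log(3) + eps})
print((logp1 - logp0) / eps)  # compare against m.dlogp_array(np.log(3))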
And... This problem seems to also be present in the implemented TruncatedNormal!

np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]
# obs = x
obs.size
2641

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
print(m.logp({'mu': 0}))
print(m.dlogp_array(0))
-2884.7483443333817
[285.5950243]

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.Truncated(pm.Normal, lower=-1, upper=2)('x', mu=mu, sigma=2, observed=obs)
print(m.logp({'mu': 0}))
print(m.dlogp_array(0))
-2884.7483443333817
[285.5950243]

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.Bound(pm.Normal, lower=-1, upper=2)('x', mu=mu, sigma=2, observed=obs)
print(m.logp({'mu': 0}))
print(m.dlogp_array(0))
-4547.510374445674
[285.69833991] # Again suspiciously similar to the Truncated versions

with pm.Model() as m:
mu = pm.Normal("mu", 0, 5)
x = pm.Normal('x', mu=mu, sigma=2, observed=obs)
norm_dist = pm.Normal.dist(mu=mu, sigma=2) # this is not part of the model, just used to get the logcdf
norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(norm_dist.logcdf(2), norm_dist.logcdf(-1)) * x.size)
print(m.logp({'mu': 0}))
print(m.dlogp_array(0))
-2884.7483443333813
[12.84180893] # This gradient is probably the correct one

And sampling fails with NUTS, but not with Metropolis:

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
trace = pm.sample(return_inferencedata=True)
print(az.summary(trace))
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu  -0.149  0.249  -0.542    0.239      0.123    0.094       4.0     4.0       4.0      11.0    3.4

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
trace = pm.sample(return_inferencedata=True, step=[pm.Metropolis([mu])])
print(az.summary(trace))
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu   0.113   0.09   -0.07    0.274      0.003    0.002     662.0   648.0     647.0     676.0   1.01
Following @ColCarroll's suggestion, it seems that indeed the gradient should be around 12:

# model using native TruncatedNormal shown above
print((m.logp({'mu': 1e-5}) - m.logp({'mu': 0})) / 1e-5)
print(m.dlogp_array(0))
12.841238685723509 # Correct
[285.5950243] # Wrong!
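For what it's worth, the finite-difference value agrees with the closed-form gradient of the truncated-normal log-likelihood, d/dmu sum_i logp(x_i) = sum_i (x_i - mu)/sigma^2 + n * (phi(beta) - phi(alpha)) / (sigma * (Phi(beta) - Phi(alpha))), where alpha and beta are the standardized bounds. A quick scipy check of my own (not from the thread; the Normal(0, 5) prior contributes zero gradient at mu = 0):

import numpy as np
from scipy import stats

np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]  # same data as above

s, a, b = 2.0, -1.0, 2.0
alpha, beta = (a - 0.0) / s, (b - 0.0) / s  # standardized bounds at mu = 0
grad = obs.sum() / s**2 + obs.size * (stats.norm.pdf(beta) - stats.norm.pdf(alpha)) / (
    s * (stats.norm.cdf(beta) - stats.norm.cdf(alpha))
)
print(grad)  # ~12.84, matching the finite-difference estimate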
Something a bit more worrying. Edit: I am more inclined to think this is the correct behavior, but if someone can confirm, that would be great :)

In [28]: np.random.seed(2021)
...: x = np.random.normal(0, 2, size=5000)
...: obs = x[(x >= -1) & (x <= 2)]
In [29]: for edge in (None, 2, 20, 200):
...: lower = None if edge is None else - edge
...: upper = edge
...: with pm.Model() as m:
...: mu = pm.Bound(pm.Normal, lower=lower, upper=upper)('mu', 0, 5)
...: x = pm.Normal('x', mu=mu, sigma=2, observed=obs)
...:
...: print(f'\n[{lower=}, {upper=}]')
...: if edge is None:
...: print(m.logp({'mu': 0}))
...: else:
...: print(m.logp({'mu_interval__': 0}))
...: print(m.dlogp_array(0))
...:
[lower=None, upper=None]
-4547.510374445674
[285.69833991]
[lower=-2, upper=2]
-4547.510374445674
[285.69833991]
[lower=-20, upper=20]
-4545.207789352679
[2856.98339912]
[lower=-200, upper=200]
-4542.905204259686
[28569.83399118]

In [30]: with pm.Model() as m:
...: mu = pm.Normal('mu', 0, 5)
...: x = pm.Normal('x', mu=mu, sigma=2, observed=obs)
...:
...: print(m.logp({'mu': 0}))
...: print(m.dlogp_array(0))
-4547.510374445674
[285.69833991]
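For the record, the scaling pattern above is what the chain rule through PyMC3's default interval transform predicts: the gradient is taken with respect to mu_interval__, and the backward transform mu = lower + (upper - lower) * sigmoid(y) has Jacobian (upper - lower) * sigmoid(y) * (1 - sigmoid(y)), i.e. exactly 1 at y = 0 for (-2, 2) and 10x larger for each 10x wider interval (the gradient of the transform's log-Jacobian term vanishes at y = 0, so the match is exact). A small numeric check of my own, assuming that backward-transform form:

import numpy as np

def interval_jacobian(lower, upper, y=0.0):
    # d mu / d y for mu = lower + (upper - lower) * sigmoid(y)
    s = 1.0 / (1.0 + np.exp(-y))
    return (upper - lower) * s * (1.0 - s)

for edge in (2, 20, 200):
    # gradient w.r.t. mu itself (285.698...) times the transform Jacobian at y = 0
    print(edge, 285.69833991 * interval_jacobian(-edge, edge))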
Don't even know where to start looking here. @brandonwillard any ideas?
Surprisingly enough, there was no issue with this TruncatedNormal regression example in pymc-devs/pymc-examples#30. Indeed, if I add zeros to mu in the shape of the observed data, the problem seems to go away entirely (as well as in my generic Truncated class):

with pm.Model() as m1:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs, shape=len(obs))
trace1 = pm.sample()
print(az.summary(trace1))
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu  -0.301  0.411  -0.826    0.221      0.204    0.156       4.0     4.0       4.0      12.0   3.45

with pm.Model() as m2:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu+np.zeros_like(obs), sigma=2, lower=-1, upper=2, observed=obs, shape=len(obs))
trace2 = pm.sample()
print(az.summary(trace2))
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mu    0.11  0.095  -0.067    0.287      0.002    0.002    1818.0  1818.0    1817.0    2965.0    1.0

I don't know how to request the gradient anymore (I get an error if I try to use dlogp_array as before), but I would bet it is now correct. What could be the cause? Some wild Theano optimization / broadcasting issue?
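If it helps, the gradient of m2 can probably still be pulled out with Model.dlogp, the same method used further below in this thread (my guess; assuming the point only needs the free variable mu):

print(m2.dlogp([m2.mu])({'mu': 0.0}))  # I would expect something close to the ~12.84 value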
In addition, a FreeRV TruncatedNormal also works fine when conditioned on the same values as the obs:

with pm.Model() as m3:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, shape=(len(obs)))
m3.dlogp([m3.mu])({'mu': 0, 'x_interval__': m3.x.transformation.forward_val(obs)})
12.8412... # The correct gradient
Possibly related TruncatedNormal discourse issues:
It seems to be fixed in Theano 1.1.2! 😮

np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]

Theano 1.1.2:

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
print(m.logp({'mu': 0}))
print(m.dlogp([mu])({'mu':0}))
-2884.7483443333817
[12.84180893]

Theano 1.1.0:

with pm.Model() as m:
mu = pm.Normal('mu', 0, 5)
x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
print(m.logp({'mu': 0}))
print(m.dlogp([mu])({'mu':0}))
-2884.7483443333817
[285.5950243]

Should we do anything here? Add a unit test to check that the gradient does not change in future versions?
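A regression test could be as simple as comparing the compiled gradient against a finite difference of the model logp; a rough sketch (the test name, tolerance, and step size are my own, not what eventually landed):

import numpy as np
import pymc3 as pm

def test_truncated_normal_dlogp_regression():
    rng = np.random.RandomState(2021)
    data = rng.normal(0, 2, size=5000)
    obs = data[(data >= -1) & (data <= 2)]
    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 5)
        pm.TruncatedNormal("x", mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
    eps = 1e-5
    fd = (m.logp({"mu": eps}) - m.logp({"mu": 0.0})) / eps  # finite-difference reference
    grad = m.dlogp([mu])({"mu": 0.0})
    np.testing.assert_allclose(grad, fd, rtol=1e-2)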
That is bizarre. @michaelosthege any idea what we might have fixed here relating to gradients? Anyway, great that it's fixed. I suppose this manual gradient test would make for a great unittest.
Maybe it was this one: pymc-devs/pytensor@2b06fa1?
@ricardoV94 did you compare with
I compared with 1.1.2 (after rebasing pymc3 to include PR #4444). Might it have been the one related to tt.switch then? pymc-devs/pytensor@b379b0f (I thought that had been incorporated in 1.1.0 already).

Edit: That one would make more sense actually. I see it was in 1.1.1, which we didn't pair with any pymc3 release. That's why I was not considering it.
@ricardoV94 That's good, because we don't need another Theano-PyMC release then, and this issue was accidentally fixed by #4444. If you can make a small PR that adds a regression test and a mention in the release notes, that'd be great.
Will do
Edit: I came across issues in the model.dlogp (and sampling) of models using pm.Bound and TruncatedNormal while trying to sketch a generic pm.Truncated class. These problems are described in the messages at the end of the thread.

I was playing around to see if I could implement a generic Truncated class similar to the Bound but taking into consideration the extra normalization term. Everything is almost the same as the Bound except for the logp method and that it allows for observed values. You can check all the changes in my fork: master...ricardoV94:truncated
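The core of the idea is that the logp of a truncated distribution is the base distribution's logp minus the log of the probability mass between the bounds, something along these lines (a rough sketch of my own, not the actual code in the fork; it assumes the underlying distribution implements logcdf):

import numpy as np
import theano.tensor as tt
import pymc3 as pm

def truncated_logp(dist, value, lower, upper):
    # normalization: log P(lower <= X <= upper) under the untruncated distribution
    norm = pm.math.logdiffexp(dist.logcdf(upper), dist.logcdf(lower))
    in_bounds = tt.ge(value, lower) & tt.le(value, upper)
    return tt.switch(in_bounds, dist.logp(value) - norm, -np.inf)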
Everything seems to work fine, but when I actually try to sample, something is definitely off. You can find my Notebook here: https://gist.github.com/ricardoV94/269f07b016a5136f52a1e0238d0ec4e6
First is a manual implementation using Potential:

And now with Truncated:

A lot of divergences and non-convergence!

Everything seems to be working well when looking at the check_test_point: the difference between the manual and the trunc models is exactly the correction term (I assume the potential term is ignored in check_test_point).

I think I must be missing something obvious that happens during the sampling. Why would the model evaluate correctly but fail to sample?