pm.Bound and TruncatedNormal generate wrong gradients #4417

Closed
ricardoV94 opened this issue Jan 14, 2021 · 15 comments · Fixed by #4448

ricardoV94 commented Jan 14, 2021

Edit: I came across issues in the model.dlogp (and sampling) of models using pm.Bound and TruncatedNormal while trying to sketch a generic pm.Truncated class. These problems are described in the messages at the end of the thread.


I was playing around to see if I could implement a generic Truncated class, similar to Bound but taking the extra normalization term into account. Everything is almost the same as Bound, except for the logp method and the fact that it allows observed values:

class _Truncated(Distribution):
    ...

    def _normalization(self):
        if self.lower is not None and self.upper is not None:
            lcdf_upper = self._wrapped.logcdf(self.upper)
            lcdf_lower = self._wrapped.logcdf(self.lower)
            return logdiffexp(lcdf_upper, lcdf_lower)

        if self.lower is not None:
            return log1mexp(-self._wrapped.logcdf(self.lower))

        if self.upper is not None:
            return self._wrapped.logcdf(self.upper)

        return 0

    def logp(self, value):
        logp = self._wrapped.logp(value) - self._normalization()
        bounds = []
        if self.lower is not None:
            bounds.append(value >= self.lower)
        if self.upper is not None:
            bounds.append(value <= self.upper)
        if len(bounds) > 0:
            return bound(logp, *bounds)
        else:
            return logp
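As a standalone sanity check on the normalization logic (a sketch using SciPy rather than PyMC; the scale and bounds below are arbitrary example values), the double-truncation branch should satisfy exp(logdiffexp(logcdf(upper), logcdf(lower))) == CDF(upper) - CDF(lower):

```python
import numpy as np
from scipy import stats

# Standalone sketch (SciPy, not PyMC): verify the double-truncation
# normalization term log(CDF(upper) - CDF(lower)), computed in log-space
# the same way as logdiffexp(lcdf_upper, lcdf_lower) in the class above.
scale = 5.0               # arbitrary example scale (lam = 1/5)
lower, upper = 1.0, 20.0  # arbitrary example bounds
dist = stats.expon(scale=scale)

lcdf_upper = dist.logcdf(upper)
lcdf_lower = dist.logcdf(lower)
# logdiffexp(a, b) = a + log1p(-exp(b - a))
log_norm = lcdf_upper + np.log1p(-np.exp(lcdf_lower - lcdf_upper))

assert np.isclose(np.exp(log_norm), dist.cdf(upper) - dist.cdf(lower))
```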

You can check all the changes in my fork: master...ricardoV94:truncated

Everything seems to work fine, but when I actually try to sample, something is definitely off. You can find my notebook here: https://gist.github.com/ricardoV94/269f07b016a5136f52a1e0238d0ec4e6

First is a manual implementation using Potential:

# create data
np.random.seed(451)
x = np.random.exponential(3, size=5000)
minx=1
maxx=20

obs = x[np.where(~((x<minx) | (x>maxx)))] # remove values outside range

with pm.Model() as manual_model:
    λ = pm.Exponential("λ", lam=1/5)  # prior exponential with mean of 5
    x = pm.Exponential('x', lam=1/λ, observed=obs) # obs exponential with mean λ

    exp_dist = pm.Exponential.dist(lam=1/λ) # this is not part of the model, just used to get the logcdf
    norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size)

    trace_manual= pm.sample(2000, tune=1000, return_inferencedata=True)

az.summary(trace_manual)

[image: az.summary output]

And now with Truncated

with pm.Model() as trunc_model:
    λ = pm.Exponential("λ", lam=1/5)
    x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/λ, observed=obs)

    trace_trunc = pm.sample(2000, tune=1000, return_inferencedata=True)
az.summary(trace_trunc)

[image: az.summary output]

Lots of divergences and no convergence!

Everything seems to be working well when looking at the check_test_point:

in[5]
trunc_point = trunc_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
trunc_point
out[5]
λ_log__      -1.00
x         -7936.43
Name: Log-probability of test_point, dtype: float64

in[6]
manual_point = manual_model.check_test_point(
    test_point={'λ_log__': np.log(5)}
)
manual_point
out[6]
λ_log__      -1.00
x         -8749.23
Name: Log-probability of test_point, dtype: float64

The difference between the manual and the trunc models is exactly the correction term (I assume the Potential term is ignored in check_test_point).

in[7]
model_diff = manual_point - trunc_point
model_diff
out[7]
λ_log__      0.0
x         -812.8
Name: Log-probability of test_point, dtype: float64

in[8]
exp_dist = pm.Exponential.dist(1/5)
norm_term = (-pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size).eval()
norm_term
out[8]
array(812.80311981)
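The same 812.80 total can be re-derived without PyMC at all (a sketch using only SciPy, with the same seed, scale, and bounds as above; it assumes NumPy's legacy RandomState reproducibility guarantee, so the draws match exactly):

```python
import numpy as np
from scipy import stats

# Re-derive the 812.80 normalization total with SciPy only,
# using the same seed, scale, and truncation bounds as above.
np.random.seed(451)
x = np.random.exponential(3, size=5000)
obs = x[(x >= 1) & (x <= 20)]                 # same filter as earlier

dist = stats.expon(scale=5.0)                 # Exponential with lam = 1/5
log_Z = np.log(dist.cdf(20) - dist.cdf(1))    # per-observation correction
norm_total = -log_Z * obs.size

assert np.isclose(norm_total, 812.8031198123141)
```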

in[9]
trunc_dist = pm.Truncated(pm.Exponential, lower=minx, upper=maxx).dist(lam=1/5)

in[10]
trunc_dist.logp(obs).eval().sum()
out[10]
-7936.430976274112

in[11]
trunc_dist._normalization().eval() * len(obs)
out[11]
-812.8031198123141

I think I must be missing something obvious that happens during the sampling. Why would the model evaluate correctly but fail to sample?
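One general way to probe a mismatch like this, independently of PyMC, is to compare the reported gradient against a central finite difference of the log-probability. A generic sketch (toy Gaussian log-likelihood; all names here are illustrative, not part of the PyMC API):

```python
import numpy as np

# Central finite-difference check (sketch): compare any reported gradient
# against (f(x + h) - f(x - h)) / (2h) for a scalar parameter.
def finite_diff_grad(logp, x0, h=1e-5):
    return (logp(x0 + h) - logp(x0 - h)) / (2 * h)

# Toy example: Gaussian log-likelihood (sigma = 2) as a function of mu.
data = np.array([-0.5, 0.3, 1.2])

def logp(mu):
    return np.sum(-0.5 * ((data - mu) / 2.0) ** 2)

g_fd = finite_diff_grad(logp, 0.0)
g_exact = np.sum(data) / 4.0   # analytic d/dmu at mu = 0
assert np.isclose(g_fd, g_exact)
```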


ricardoV94 commented Jan 21, 2021

I made some progress. It seems that while I am getting the right logp, I am not getting the right gradient 😮

# create data
np.random.seed(451)
x = np.random.exponential(3, size=500)
minx=1
maxx=20

obs = x[np.where(~((x<minx) | (x>maxx)))] # remove values outside range
print(obs.size)

357

with pm.Model() as m:
    mu = pm.Exponential("mu", lam=1/5)
    x = pm.Exponential('x', lam=1/mu, observed=obs)

    exp_dist = pm.Exponential.dist(lam=1/mu) # this is not part of the model, just used to get the logcdf
    norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(exp_dist.logcdf(maxx), exp_dist.logcdf(minx)) * x.size)

print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))

-732.573797660396 
[-12.68406707]

with pm.Model() as m:
    mu = pm.Exponential("mu", lam=1/5)
    x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs) 

print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))

-732.5737976603965  # Same logp
[457.97095304]  # Very different dlogp!

with pm.Model() as m:
    mu = pm.Exponential("mu", lam=1/5)
    x = pm.Bound(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs)

print(m.logp({'mu_log__': np.log(3)}))
print(m.dlogp_array(np.log(3)))

-852.2084303799146
[458.2930177]  # Almost the same as when using Truncated

In fact, regardless of the sample size, the difference in gradient between the Bound and Truncated versions is always minimal.

And if I sample with Metropolis instead of NUTS, everything looks fine

with pm.Model() as m:
    mu = pm.Exponential("mu", lam=1/5)
    x = pm.Truncated(pm.Exponential, lower=minx, upper=maxx)('x', lam=1/mu, observed=obs)
    trace = pm.sample(return_inferencedata=True, step=[pm.Metropolis([mu])])

az.summary(trace)

Multiprocess sampling (4 chains in 4 jobs)
Metropolis: [mu]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1 seconds.
The number of effective samples is smaller than 25% for some parameters.


     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu  2.896  0.155   2.634    3.201      0.007    0.005     475.0   469.0   

    ess_bulk  ess_tail  r_hat  
mu     494.0     408.0   1.01  

Whereas with NUTS I get (it actually varies wildly from run to run):

     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu  3.293  0.997    2.18    4.944      0.494    0.378       4.0     4.0   

    ess_bulk  ess_tail  r_hat  
mu       4.0      11.0   3.69  


ricardoV94 commented Jan 21, 2021

And... this problem also seems to be present in the existing TruncatedNormal!

np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]
# obs = x
obs.size

2641

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)

print(m.logp({'mu': 0}))
print(m.dlogp_array(0))

-2884.7483443333817
[285.5950243]

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.Truncated(pm.Normal, lower=-1, upper=2)('x', mu=mu, sigma=2, observed=obs)

print(m.logp({'mu': 0}))
print(m.dlogp_array(0))

-2884.7483443333817
[285.5950243]

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.Bound(pm.Normal, lower=-1, upper=2)('x', mu=mu, sigma=2, observed=obs)

print(m.logp({'mu': 0}))
print(m.dlogp_array(0))

-4547.510374445674
[285.69833991]   # Again suspiciously similar to the Truncated versions

with pm.Model() as m:
    mu = pm.Normal("mu", 0, 5)
    x = pm.Normal('x', mu=mu, sigma=2, observed=obs)

    norm_dist = pm.Normal.dist(mu=mu, sigma=2) # this is not part of the model, just used to get the logcdf
    norm_term = pm.Potential("norm_term", -pm.math.logdiffexp(norm_dist.logcdf(2), norm_dist.logcdf(-1)) * x.size)

print(m.logp({'mu': 0}))
print(m.dlogp_array(0))

-2884.7483443333813
[12.84180893]   # This gradient is probably the correct one

And sample fails with NUTS, but not with Metropolis

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
    trace = pm.sample(return_inferencedata=True)

print(az.summary(trace))


     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu -0.149  0.249  -0.542    0.239      0.123    0.094       4.0     4.0   

    ess_bulk  ess_tail  r_hat  
mu       4.0      11.0    3.4 
with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)
    trace = pm.sample(return_inferencedata=True, step=[pm.Metropolis([mu])])

print(az.summary(trace))


     mean    sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu  0.113  0.09   -0.07    0.274      0.003    0.002     662.0   648.0   

    ess_bulk  ess_tail  r_hat  
mu     647.0     676.0   1.01  


ricardoV94 commented Jan 21, 2021

Following @ColCarroll's suggestion, it seems that the gradient should indeed be around 12:

# model using native TruncatedNormal shown above
print((m.logp({'mu': 1e-5}) - m.logp({'mu': 0})) / 1e-5) 
print(m.dlogp_array(0))        
                                         
12.841238685723509  # Correct
[285.5950243]  # Wrong!
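For what it's worth, the ~12.84 value can also be cross-checked entirely outside PyMC (a sketch using scipy.stats.truncnorm with the same seed and bounds; the Normal(0, 5) prior on mu contributes zero gradient at mu = 0, so the likelihood-only gradient should match):

```python
import numpy as np
from scipy import stats

# PyMC-free cross-check: truncated-normal log-likelihood gradient wrt mu
# at mu = 0, via scipy.stats.truncnorm and a central finite difference.
np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]
assert obs.size == 2641          # same data as the models above

def loglik(mu, sigma=2.0, lower=-1.0, upper=2.0):
    a, b = (lower - mu) / sigma, (upper - mu) / sigma  # standardized bounds
    return stats.truncnorm.logpdf(obs, a, b, loc=mu, scale=sigma).sum()

h = 1e-5
grad = (loglik(h) - loglik(-h)) / (2 * h)
assert abs(grad - 12.8412) < 0.01   # agrees with the finite-difference value above
```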


ricardoV94 commented Jan 22, 2021

A bit more worrying: pm.Bound also seems to be doing something odd to the model gradient. As I understand it, the gradient shouldn't change as we change the edges:

Edit: I am more inclined to think this is the correct behavior, but if someone can confirm, that would be great :)

In [28]: np.random.seed(2021)
    ...: x = np.random.normal(0, 2, size=5000)
    ...: obs = x[(x >= -1) & (x <= 2)]

In [29]: for edge in (None, 2, 20, 200): 
    ...:     lower = None if edge is None else - edge 
    ...:     upper = edge 
    ...:     with pm.Model() as m: 
    ...:         mu = pm.Bound(pm.Normal, lower=lower, upper=upper)('mu', 0, 5) 
    ...:         x = pm.Normal('x', mu=mu, sigma=2, observed=obs) 
    ...:  
    ...:     print(f'\n[{lower=}, {upper=}]') 
    ...:     if edge is None: 
    ...:         print(m.logp({'mu': 0})) 
    ...:     else: 
    ...:         print(m.logp({'mu_interval__': 0})) 
    ...:     print(m.dlogp_array(0)) 
    ...:                                                                  

[lower=None, upper=None]
-4547.510374445674
[285.69833991]

[lower=-2, upper=2]
-4547.510374445674
[285.69833991]

[lower=-20, upper=20]
-4545.207789352679
[2856.98339912]

[lower=-200, upper=200]
-4542.905204259686
[28569.83399118]

In [30]: with pm.Model() as m: 
    ...:     mu = pm.Normal('mu', 0, 5) 
    ...:     x = pm.Normal('x', mu=mu, sigma=2, observed=obs) 
    ...:  
    ...: print(m.logp({'mu': 0})) 
    ...: print(m.dlogp_array(0))                                          
-4547.510374445674
[285.69833991]

@ricardoV94 ricardoV94 changed the title Issue implementing generic Truncated class pm.Bound and TruncatedNormal generate wrong gradients Jan 22, 2021

twiecki commented Jan 22, 2021

Don't even know where to start looking here. @brandonwillard any ideas?


ricardoV94 commented Jan 25, 2021

Surprisingly enough, there was no issue with this TruncatedNormal regression example in pymc-devs/pymc-examples#30.

Indeed, if I add zeros to mu in the shape of the observed data, the problem seems to go away entirely (as well as in my generic Truncated class):

with pm.Model() as m1:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs, shape=len(obs))
    trace1 = pm.sample()

print(az.summary(trace1))

     mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu -0.301  0.411  -0.826    0.221      0.204    0.156       4.0     4.0   

    ess_bulk  ess_tail  r_hat  
mu       4.0      12.0   3.45


with pm.Model() as m2:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu+np.zeros_like(obs), sigma=2, lower=-1, upper=2, observed=obs, shape=len(obs))
    trace2 = pm.sample()

print(az.summary(trace2))

    mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  \
mu  0.11  0.095  -0.067    0.287      0.002    0.002    1818.0  1818.0   

    ess_bulk  ess_tail  r_hat  
mu    1817.0    2965.0    1.0  

I don't know how to request the gradient anymore (I get an error if I try to use dlogp_array as before), but I would bet it is now correct.

What could be the cause? Some wild Theano optimization / broadcasting issue?
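For reference, the workaround only changes how the scalar parameter reaches the likelihood: adding a zeros array in the shape of the data forces an explicit vector instead of relying on scalar broadcasting. A minimal NumPy illustration (not PyMC code):

```python
import numpy as np

# Minimal illustration of the workaround: mu + np.zeros_like(obs) turns a
# scalar parameter into an explicit vector of the data's shape, with
# identical values, so only the broadcasting pattern changes.
obs = np.array([0.1, -0.3, 1.5])   # stand-in for the observed data
mu_scalar = 0.7
mu_vector = mu_scalar + np.zeros_like(obs)

assert mu_vector.shape == obs.shape
assert np.allclose(mu_vector, mu_scalar)
```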


ricardoV94 commented Jan 25, 2021

In addition, a FreeRV TruncatedNormal also works fine when conditioned on the same values as the observations:

with pm.Model() as m3:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, shape=(len(obs)))

m3.dlogp([m3.mu])({'mu': 0, 'x_interval__': m3.x.transformation.forward_val(obs)}) 

12.8412...  # The correct gradient

@ricardoV94

It seems to be fixed in Theano 1.1.2! 😮
I cannot figure out why. The debugprints for the logp and the model gradient are exactly the same, but the outputs are correct now.

np.random.seed(2021)
x = np.random.normal(0, 2, size=5000)
obs = x[(x >= -1) & (x <= 2)]

Theano 1.1.2:

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)

print(m.logp({'mu': 0}))
print(m.dlogp([mu])({'mu':0}))

-2884.7483443333817
[12.84180893]

Theano 1.1.0:

with pm.Model() as m:
    mu = pm.Normal('mu', 0, 5)
    x = pm.TruncatedNormal('x', mu=mu, sigma=2, lower=-1, upper=2, observed=obs)

print(m.logp({'mu': 0}))
print(m.dlogp([mu])({'mu':0}))

-2884.7483443333817
[285.5950243]

Should we do anything here? Add a unit test to check that the gradient does not regress in future releases?


twiecki commented Jan 29, 2021

That is bizarre. @michaelosthege any idea what we might have fixed here relating to gradients?

Anyway, great that it's fixed. I suppose this manual gradient check would make for a great unit test.

@ricardoV94

Maybe it was this one: pymc-devs/pytensor@2b06fa1?

@michaelosthege

@ricardoV94 did you compare with Theano-PyMC 1.1.2 or Theano-PyMC master (before the rename)?
Because the commit you mentioned was not part of the Theano-PyMC 1.1.2 release: pymc-devs/pytensor@rel-1.1.0...rel-1.1.2


ricardoV94 commented Jan 29, 2021

I compared with 1.1.2 (after rebasing pymc3 to include PR #4444). Might it have been the one related to tt.switch, then?

pymc-devs/pytensor@b379b0f (I thought that had been incorporated in 1.1.0 already)

Edit: That one would make more sense, actually. I see it was included in 1.1.1, which we didn't pair with any pymc3 release; that's why I wasn't considering it.

@michaelosthege

@ricardoV94 That's good, because then we don't need another Theano-PyMC release, and this issue was accidentally fixed by #4444.

If you can make a small PR that adds a regression test and a mention to the release notes that'd be great.

@ricardoV94

Will do
