Continual regression tutorial question: Laplace #107
Comments
Ok, there are a few things going on here 😅 hopefully I can help somewhat. You should not de-normalize the `log_posterior` for Laplace, since it is called internally per-sample and so will work as long as it is properly scaled when called with `batchsize=1`. (Note that in most definitions of the Laplace approximation only the likelihood appears, with no prior, although I think it's fine to also include the prior.) I am not sure that will solve the issue here, though. One thing is that for small datasets (like the episodes here), the full-batch gradient at the mode will often be very small, and therefore the empirical Fisher precision will also be very small and the empirical Fisher covariance very large, meaning the Laplace approximation breaks down. But it is curious that you get good performance in the latest episode and not in the previous ones.
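To illustrate the scaling point above, here is a minimal sketch of a per-sample-normalized `log_posterior` for a toy Gaussian model. The function name, the model, and the `num_data` scaling of the prior are illustrative assumptions based on the comment, not posteriors' exact convention:

```python
import torch

def log_posterior(w, batch, num_data):
    # Per-sample normalized: average the log-likelihood over the batch and
    # scale the log-prior down by the full dataset size, so the value stays
    # correctly scaled even when called with batchsize=1.
    x, y = batch
    log_lik = -0.5 * (y - w * x) ** 2   # unit-variance Gaussian likelihood
    log_prior = -0.5 * w ** 2           # standard normal prior on w (toy choice)
    return log_lik.mean() + log_prior / num_data

w = torch.tensor(0.5)
x = torch.randn(10)
y = 2.0 * x

# Full-batch value...
full = log_posterior(w, (x, y), num_data=10)

# ...matches the average of the batchsize=1 calls, so no de-normalization
# (multiplying by the number of data points) is needed.
per_sample = torch.stack(
    [log_posterior(w, (x[i : i + 1], y[i : i + 1]), num_data=10) for i in range(10)]
).mean()
print(torch.allclose(full, per_sample))
```

With this normalization, multiplying by the dataset size (as in the question below) would double-count the scaling when the transform calls the function per-sample.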
Hi, thanks for creating this library!

I was going through the continual regression tutorial notebook and wanted to try out `diag_fisher`, and I expected it to produce similar results to `diag_vi`. However, the predictive distribution did not look right except for the first episode. Here are my results for VI (by sampling parameters) and Fisher (linearized forward):

One thing I noticed is that for `diag_fisher`, the `prec_diag` of some parameters is close to zero, which I clipped to `1e-5` because otherwise it causes an error in `posteriors.laplace.diag_fisher.sample`. I also de-normalized `log_posterior` by multiplying it by the number of data points for the `diag_fisher` transform. I was wondering whether this is expected or there's something wrong with my implementation?

Here's my code, which can be subbed in after the VI training block in the notebook:
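(The questioner's code itself is not shown above, but the clipping workaround described in the question can be sketched standalone in plain PyTorch. The tensor values and the `eps` floor are illustrative, and this is not posteriors' internal implementation:)

```python
import torch

eps = 1e-5  # same floor the question uses

# Diagonal precision where some entries have underflowed to (near) zero.
prec_diag = torch.tensor([2.0, 0.0, 1e-12])

# Clamp the precision from below; equivalently, this caps the variance at
# 1 / eps so the Gaussian scale stays finite.
prec_clipped = prec_diag.clamp(min=eps)

# Sample from the diagonal Gaussian: sd = 1 / sqrt(precision).
mean = torch.zeros(3)
sd = prec_clipped.rsqrt()
sample = mean + sd * torch.randn(3)

# Without clamping, rsqrt(0) = inf and the sample would be invalid.
print(sd)
```

The resulting standard deviation for a clipped entry is `1 / sqrt(1e-5) ≈ 316`, i.e. an extremely wide marginal, which is consistent with the maintainer's remark that a tiny empirical Fisher precision implies a huge covariance.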