Marginal Laplace approximation #344
Closed
Builds on top of the QMC implementation in #353.
WIP: Requires pymc-devs/pytensor#887. Need to limit this to three-tier models so we can access:
```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor.gradient import grad, hessian

# For an MvNormal latent field:
#   y      are the observations
#   x      is the latent field
#   params are the hyperparameters
mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
y = ...  # TODO: observed values, marginalized_rv_op.owner...?

def f(x, y):
    # TODO: elementwise log p(y | x, params); takes a latent value and a measurement
    ...

def df(x, y):
    # gradient of log p(y | x, params) with respect to x
    return grad(f(x, y).sum(), x)

def d2f(x, y):
    # Hessian of log p(y | x, params) with respect to x
    return hessian(f(x, y).sum(), x)

# Precision matrix; ideally a rewrite lets the user supply tau instead of inverting cov
Q = pt.linalg.inv(cov)

def newton_step(x, y, gaussian_mean, gaussian_prec):
    # Solve A x_new = b with
    #   A = gaussian_prec - d2f(x, y)
    #   b = gaussian_prec @ gaussian_mean + df(x, y) - d2f(x, y) @ x
    # If d2f is diagonal, only the diagonal of the precision needs updating, which may be
    # cheaper than a dense subtraction (draft idea: pt.fill_diagonal instead of multiplying by eye)
    hess = d2f(x, y)  # evaluate once per step
    A = gaussian_prec - hess
    b = gaussian_prec @ gaussian_mean + df(x, y) - hess @ x
    return pt.linalg.solve(A, b)

x_mode = ...  # TODO: iterate newton_step to convergence with gaussian_mean=mean, gaussian_prec=Q

# Laplace approximation q(x | y, params) = N(x_mode, (Q - d2f(x_mode, y))^-1).
# At its own mean its log density is the log-determinant term plus the Gaussian constant.
hess_mode = d2f(x_mode, y)  # evaluate d2f(x_mode, y) once
dim = x_mode.shape[0]
_, logdet_post_prec = pt.linalg.slogdet(Q - hess_mode)
log_laplace_approx = 0.5 * logdet_post_prec - 0.5 * dim * np.log(2 * np.pi)

# This is the full log marginal likelihood, but we only need to add the marginal part.
# Open question: which of these terms do we need to replace,
# or should we replace the entire model log-likelihood?
likelihood = (  # log p(y | params) ≈
    pm.logp(pm.MvNormal.dist(mu=mean, tau=Q), x_mode)  # + log p(x_mode | params); PrecisionMvNormal?
    + f(x_mode, y).sum()  # + log p(y | x_mode, params), probably available from p(y | x)
    - log_laplace_approx  # - log q(x_mode | y, params)
)
```
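The Newton update and the three log-density terms in the sketch above can be sanity-checked numerically outside PyTensor. Below is a minimal NumPy sketch, assuming the caller supplies log p(y | x), its gradient and its Hessian as plain functions. `laplace_marginal_loglik` and the example likelihoods are hypothetical names for illustration, not PyMC or PyTensor API. Each Newton step solves (Q − H) x_new = Qμ + g − Hx, where g and H are the gradient and Hessian of the log-likelihood. For a Gaussian likelihood the Laplace approximation is exact, which gives a closed-form check; the Poisson example exercises the non-Gaussian case.

```python
import numpy as np


def laplace_marginal_loglik(y, mu, Q, loglik, grad, hess, max_iter=50, tol=1e-10):
    """Laplace approximation of log p(y | params) for a latent field x ~ N(mu, Q^-1).

    loglik(x, y) -> float  : log p(y | x, params), summed over observations
    grad(x, y)   -> (d,)   : gradient of loglik with respect to x
    hess(x, y)   -> (d, d) : Hessian of loglik with respect to x
    Returns (approximate log marginal likelihood, mode of x).
    """
    d = mu.shape[0]
    x = mu.copy()  # start the Newton iteration at the prior mean
    for _ in range(max_iter):
        H = hess(x, y)
        # Newton step: solve (Q - H) x_new = Q mu + grad - H x
        x_new = np.linalg.solve(Q - H, Q @ mu + grad(x, y) - H @ x)
        converged = np.max(np.abs(x_new - x)) < tol
        x = x_new
        if converged:
            break

    log_2pi = np.log(2 * np.pi)
    diff = x - mu
    _, logdet_Q = np.linalg.slogdet(Q)
    # log p(x_mode | params): Gaussian prior in precision form
    log_prior = 0.5 * logdet_Q - 0.5 * d * log_2pi - 0.5 * diff @ Q @ diff
    # log q(x_mode | y, params) for q = N(x_mode, (Q - H)^-1), evaluated at its own mean
    _, logdet_post = np.linalg.slogdet(Q - hess(x, y))
    log_q_mode = 0.5 * logdet_post - 0.5 * d * log_2pi
    return log_prior + loglik(x, y) - log_q_mode, x


# Check: with y | x ~ N(x, s2 I) the approximation is exact, so it must match
# the closed-form marginal y ~ N(mu, Q^-1 + s2 I).
rng = np.random.default_rng(0)
d, s2 = 3, 0.5
L = rng.normal(size=(d, d))
cov = L @ L.T + d * np.eye(d)
Q = np.linalg.inv(cov)
mu = rng.normal(size=d)
y = rng.normal(size=d)


def gauss_loglik(x, y):
    return -0.5 * np.sum((y - x) ** 2) / s2 - 0.5 * d * np.log(2 * np.pi * s2)


def gauss_grad(x, y):
    return (y - x) / s2


def gauss_hess(x, y):
    return -np.eye(d) / s2


approx, _ = laplace_marginal_loglik(y, mu, Q, gauss_loglik, gauss_grad, gauss_hess)
C = cov + s2 * np.eye(d)
r = y - mu
_, logdet_C = np.linalg.slogdet(C)
exact = -0.5 * (d * np.log(2 * np.pi) + logdet_C + r @ np.linalg.solve(C, r))
print(approx, exact)

# Non-Gaussian example: Poisson counts with a log link, y_i ~ Poisson(exp(x_i)).
# The log(y_i!) constant is dropped because it depends on neither x nor the hyperparameters.
counts = np.array([0.0, 2.0, 5.0])


def pois_loglik(x, y):
    return np.sum(y * x - np.exp(x))


def pois_grad(x, y):
    return y - np.exp(x)


def pois_hess(x, y):
    return -np.diag(np.exp(x))


pois_approx, pois_mode = laplace_marginal_loglik(counts, mu, Q, pois_loglik, pois_grad, pois_hess)
```

Starting the Newton iteration at the prior mean is a simple choice; since the log posterior is concave for these likelihoods, a plain Newton iteration converges here without line search, but a production version would likely want damping.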
This is part of #340.
We have a Laplace approximation, but we only want to use it on a subset of variables (the latent field). We want to use some other inference method on the other variables (hyperparameters).
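Written out, the split looks roughly like this, assuming a Gaussian latent field x | θ ~ N(μ, Q⁻¹) with hyperparameters θ, f(x) = log p(y | x, θ), and a mode x* found by Newton iteration. This is the standard Laplace approximation in our notation, not a settled design. Newton iteration for the mode:

$$
\left(Q - \nabla^2 f(x^{(k)})\right) x^{(k+1)} = Q\mu + \nabla f(x^{(k)}) - \nabla^2 f(x^{(k)})\, x^{(k)}
$$

Marginal likelihood with the Gaussian approximation $q = \mathcal{N}\!\left(x^\ast, \left(Q - \nabla^2 f(x^\ast)\right)^{-1}\right)$:

$$
\log p(y \mid \theta) \approx \log p(x^\ast \mid \theta) + \log p(y \mid x^\ast, \theta) - \log q(x^\ast \mid y, \theta),
\qquad
\log q(x^\ast \mid y, \theta) = \tfrac{1}{2}\log\det\left(Q - \nabla^2 f(x^\ast)\right) - \tfrac{d}{2}\log 2\pi
$$

The hyperparameters are then handled by the other inference method through $p(\theta \mid y) \propto p(\theta)\, \hat{p}(y \mid \theta)$.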
This could perhaps use a step sampler, or maybe even something like the marginal model.
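To make the split concrete, here is a minimal sketch under loudly hypothetical assumptions: an i.i.d. latent field x_i ~ N(0, σ²), Poisson counts y_i ~ Poisson(exp(x_i)), a N(0, 1) prior on log σ, and a hand-written random-walk Metropolis standing in for whatever PyMC step sampler would actually be used. The Laplace approximation integrates out the latent field; the sampler only ever sees log σ. None of these function names are PyMC API.

```python
import numpy as np


def laplace_log_marginal(y, sigma, n_iter=50):
    """Laplace approximation of log p(y | sigma) for x_i ~ N(0, sigma^2), y_i ~ Poisson(exp(x_i)).

    Prior precision and likelihood Hessian are both diagonal here, so each latent
    coordinate gets its own scalar Newton iteration (vectorised with NumPy).
    Uses a fixed number of iterations (no convergence check, for brevity).
    The log(y_i!) term is dropped because it does not depend on sigma.
    """
    q = 1.0 / sigma**2  # prior precision of each x_i
    x = np.zeros_like(y, dtype=float)
    for _ in range(n_iter):
        g = y - np.exp(x)  # gradient of y*x - exp(x)
        h = -np.exp(x)  # second derivative of y*x - exp(x)
        x = (g - h * x) / (q - h)  # Newton step with prior mean 0
    h = -np.exp(x)
    log_prior = 0.5 * np.log(q / (2 * np.pi)) - 0.5 * q * x**2
    log_lik = y * x - np.exp(x)
    log_q_mode = 0.5 * np.log((q - h) / (2 * np.pi))
    return np.sum(log_prior + log_lik - log_q_mode)


def log_post_hyper(log_sigma, y):
    # Hypothetical N(0, 1) prior on log(sigma) plus the Laplace marginal likelihood
    return -0.5 * log_sigma**2 + laplace_log_marginal(y, np.exp(log_sigma))


def rw_metropolis(y, n_samples=1000, step=0.5, seed=0):
    # Random-walk Metropolis on log(sigma) only; the latent field never gets sampled
    rng = np.random.default_rng(seed)
    cur = 0.0
    cur_lp = log_post_hyper(cur, y)
    samples, n_accept = np.empty(n_samples), 0
    for i in range(n_samples):
        prop = cur + step * rng.normal()
        prop_lp = log_post_hyper(prop, y)
        if np.log(rng.uniform()) < prop_lp - cur_lp:
            cur, cur_lp = prop, prop_lp
            n_accept += 1
        samples[i] = cur
    return samples, n_accept / n_samples


y = np.array([0, 1, 3, 7, 2, 0, 12, 4])
samples, accept_rate = rw_metropolis(y)
print(np.exp(samples).mean(), accept_rate)
```

If the latent field itself is needed afterwards, it could be recovered per hyperparameter draw from the Gaussian approximation at the mode.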