
Implement variance reduction in SLQ logdet backward pass. #1836

Merged
gpleiss merged 7 commits into master from piv_chol_func3 on Dec 21, 2021

Conversation

gpleiss
Member

@gpleiss gpleiss commented Nov 22, 2021

Based on "Reducing the Variance of Gaussian Process Hyperparameter Optimization with Preconditioning" by Wenger et al., 2021.

When using iterative methods (i.e. CG/SLQ) to compute the log determinant, the forward pass currently computes:
logdet(K) \approx logdet(P) + SLQ(P^{-1/2} K P^{-1/2}),
where P is a preconditioner, and SLQ is a stochastic estimate of the log determinant. If the preconditioner is a good approximation of K, then this forward pass can be seen as a form of variance reduction.
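
For intuition, here is a minimal numerical sketch (not GPyTorch's implementation) of the identity behind this forward pass, using a simple diagonal stand-in for the preconditioner and exact logdets in place of the SLQ estimate:

```python
import torch

torch.manual_seed(0)
n = 50

# Small dense stand-ins: a PSD "kernel" matrix K and a diagonal preconditioner P.
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T + n * torch.eye(n, dtype=torch.float64)
P_diag = torch.diagonal(K)

# Identity behind the forward pass:
#   logdet(K) = logdet(P) + logdet(P^{-1/2} K P^{-1/2})
# In the actual code the second term is estimated with SLQ; here we compute it exactly.
P_inv_sqrt = torch.diag(P_diag.rsqrt())
whitened = P_inv_sqrt @ K @ P_inv_sqrt

lhs = torch.logdet(K)
rhs = P_diag.log().sum() + torch.logdet(whitened)
print(torch.allclose(lhs, rhs))  # True -- the split is exact; SLQ only adds estimator noise
```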

In this PR, we apply this same variance reduction strategy to the backward pass. We compute the backward pass as:
d logdet(K)/dtheta \approx d logdet(P)/dtheta + d SLQ/dtheta
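
As a sanity check on this split (again with exact logdets and a diagonal stand-in for P rather than SLQ and the pivoted-Cholesky preconditioner), autograd gives the same gradient for both forms:

```python
import torch

torch.manual_seed(0)
n = 30
A = torch.randn(n, n, dtype=torch.float64)
base = A @ A.T + n * torch.eye(n, dtype=torch.float64)

def logdet_direct(theta):
    K = theta * base + torch.eye(n, dtype=torch.float64)
    return torch.logdet(K)

def logdet_split(theta):
    # logdet(P) + logdet(P^{-1/2} K P^{-1/2}); autograd differentiates both terms,
    # mirroring d logdet(K)/dtheta = d logdet(P)/dtheta + d SLQ/dtheta.
    K = theta * base + torch.eye(n, dtype=torch.float64)
    d = torch.diagonal(K)
    P_inv_sqrt = torch.diag(d.rsqrt())
    return d.log().sum() + torch.logdet(P_inv_sqrt @ K @ P_inv_sqrt)

theta1 = torch.tensor(0.7, dtype=torch.float64, requires_grad=True)
theta2 = torch.tensor(0.7, dtype=torch.float64, requires_grad=True)
(g1,) = torch.autograd.grad(logdet_direct(theta1), theta1)
(g2,) = torch.autograd.grad(logdet_split(theta2), theta2)
print(torch.allclose(g1, g2))  # True -- the variance reduction does not bias the gradient
```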

TODOs:

  • Implement pivoted Cholesky as a torch.autograd.Function, so that we can compute backward passes through it (see the sketch after this list).
  • Redo the inv_quad_logdet function to apply variance reduction in both the forward and backward passes.
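
For the first TODO, the relevant mechanics are those of torch.autograd.Function. Below is a hedged sketch of that pattern on a simpler operation, a plain Cholesky-based log determinant with the closed-form gradient K^{-1}, rather than the pivoted Cholesky factorization this PR actually wraps:

```python
import torch

class CholeskyLogDet(torch.autograd.Function):
    # Illustrative only: forward computes logdet(K) from a Cholesky factor,
    # backward uses d logdet(K) = tr(K^{-1} dK), i.e. the gradient is K^{-1}
    # for symmetric K. The PR does the analogous thing for pivoted Cholesky.

    @staticmethod
    def forward(ctx, K):
        L = torch.linalg.cholesky(K)
        ctx.save_for_backward(L)
        return 2.0 * L.diagonal().log().sum()

    @staticmethod
    def backward(ctx, grad_output):
        (L,) = ctx.saved_tensors
        return grad_output * torch.cholesky_inverse(L)

# Quick check against the known gradient of logdet.
K0 = torch.randn(5, 5, dtype=torch.float64)
K0 = K0 @ K0.T + 5 * torch.eye(5, dtype=torch.float64)
K = K0.clone().requires_grad_(True)
CholeskyLogDet.apply(K).backward()
print(torch.allclose(K.grad, torch.inverse(K0)))  # True
```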

@gpleiss
Member Author

gpleiss commented Dec 14, 2021

@jacobrgardner @JonathanWenger ready for review

@gpleiss gpleiss changed the title [WIP] Implement variance reduction in SLQ logdet backward pass. Implement variance reduction in SLQ logdet backward pass. Dec 14, 2021
@gpleiss
Member Author

gpleiss commented Dec 14, 2021

(actually ready for review now. I just fixed broken tests.)

@jacobrgardner
Member

@gpleiss something seems off about how this is computing the preconditioner log determinant now. We're still computing it efficiently using the QR decomposition in the init_cache methods on AddedDiagLT, and then it looks like we discard that efficient computation and call logdet on precond_lt, which could be as bad as calling Cholesky on the full n x n preconditioner matrix, right? PsdSumLazyTensor doesn't override logdet?

Entirely possible I just missed the relevant code here...

@jacobrgardner
Member

Ideally, this would be the logdet value we'd return in the forward pass:

self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()
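
To make the caching idea concrete, here is a hypothetical sketch (illustrative names, not GPyTorch's actual classes; the matrix determinant lemma stands in for the QR route mentioned above) of computing the preconditioner's logdet cheaply once and then returning the cached scalar instead of calling logdet on the full n x n preconditioner:

```python
import torch

class LowRankPlusDiagPreconditioner:
    # Hypothetical P = L L^T + diag(d), with L of shape (n, k) and k << n.
    def __init__(self, L, d):
        self.L, self.d = L, d
        k = L.shape[-1]
        # Matrix determinant lemma: logdet(L L^T + D) = logdet(I_k + L^T D^{-1} L) + logdet(D).
        # Costs O(n k^2) instead of the O(n^3) Cholesky of the dense n x n preconditioner.
        inner = torch.eye(k, dtype=L.dtype) + (L / d.unsqueeze(-1)).T @ L
        self._precond_logdet_cache = torch.logdet(inner) + d.log().sum()

    def logdet(self):
        return self._precond_logdet_cache  # reuse the cached scalar in the forward pass

n, k = 500, 10
L = torch.randn(n, k, dtype=torch.float64)
d = torch.rand(n, dtype=torch.float64) + 1.0
precond = LowRankPlusDiagPreconditioner(L, d)
dense = L @ L.T + torch.diag(d)
print(torch.allclose(precond.logdet(), torch.logdet(dense)))  # True
```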

Member

@jacobrgardner jacobrgardner left a comment


Looks solid to me now 👍

@gpleiss
Member Author

gpleiss commented Dec 21, 2021

I just profiled this PR on the KeOps example notebook, just to double check. It is just as fast as what is on master.

@gpleiss gpleiss merged commit 0907c95 into master Dec 21, 2021
@gpleiss gpleiss deleted the piv_chol_func3 branch December 21, 2021 00:53