Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
More efficient sampling from KroneckerMultiTaskGP (#2460)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation See #2310 (comment) ```python import torch from botorch.models import KroneckerMultiTaskGP n_inputs = 10 n_tasks = 4 n_train = 2048 n_test = 1 device = torch.device("cuda:0") train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device) train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device) test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device) gp = KroneckerMultiTaskGP(train_x, train_y) posterior = gp.posterior(test_x) posterior.rsample(torch.Size([256, 1])) ``` The final line requires allocation of 128GB of GPU memory, because of the call to `torch.cholesky_solve` with B shaped `(256, 1, 8192, 1)` and L shaped `(8192, 8192)`. By moving the largest batch dimension to the final position, we should achieve a more efficient operation. Also fix docstring for `MultitaskGPPosterior`. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #2460 Test Plan: Passes unit tests (specifically `test_multitask.py`). Benchmarking results: ![image](https://github.com/user-attachments/assets/1eca54be-1ed4-43c9-bb50-a18cf24d00f5) ![image](https://github.com/user-attachments/assets/016322f6-992a-45bf-b175-e76208c11b12) ## Related PRs N/A Reviewed By: saitcakmak Differential Revision: D63678866 Pulled By: Balandat fbshipit-source-id: 6675c66dadd62934f95fabafe7b3f0155a1c0c6f
- Loading branch information