Skip to content

Commit

Permalink
More efficient sampling from KroneckerMultiTaskGP (#2460)
Browse files Browse the repository at this point in the history
Summary:
<!--
Thank you for sending the PR! We appreciate you spending the time to make BoTorch better.

Help us understand your motivation by explaining why you decided to make this change.

You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md
-->

## Motivation

See #2310 (comment)

```python
import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)

test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)

posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
```

The final line requires allocation of 128GB of GPU memory, because of the call to `torch.cholesky_solve` with B shaped `(256, 1, 8192, 1)` and L shaped `(8192, 8192)`.

By moving the largest batch dimension to the final position, we should achieve a more efficient operation.

Also fix docstring for `MultitaskGPPosterior`.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2460

Test Plan:
Passes unit tests (specifically `test_multitask.py`).

Benchmarking results:
![image](https://github.com/user-attachments/assets/1eca54be-1ed4-43c9-bb50-a18cf24d00f5)
![image](https://github.com/user-attachments/assets/016322f6-992a-45bf-b175-e76208c11b12)

## Related PRs

N/A

Reviewed By: saitcakmak

Differential Revision: D63678866

Pulled By: Balandat

fbshipit-source-id: 6675c66dadd62934f95fabafe7b3f0155a1c0c6f
  • Loading branch information
slishak-PX authored and facebook-github-bot committed Oct 1, 2024
1 parent e29e30a commit 8924d1b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
44 changes: 41 additions & 3 deletions botorch/posteriors/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def __init__(
distribution: Posterior multivariate normal distribution.
joint_covariance_matrix: Joint test train covariance matrix over the entire
tensor.
train_train_covar: Covariance matrix of train points in the data space.
test_obs_covar: Covariance matrix of test x train points in the data space.
test_train_covar: Covariance matrix of test x train points in the data
space.
train_diff: Difference between train mean and train responses.
test_mean: Test mean response.
train_train_covar: Covariance matrix of train points in the data space.
train_noise: Training noise covariance.
test_noise: Only used if posterior should contain observation noise.
Testing noise covariance.
Expand Down Expand Up @@ -226,7 +228,9 @@ def rsample_from_base_samples(
train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
)
train_covar_plus_noise = self.train_train_covar + self.train_noise
obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
obs_solve = _permute_solve(
train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
)

# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
Expand Down Expand Up @@ -286,3 +290,37 @@ def _draw_from_base_covar(
res = covar_root.matmul(base_samples)

return res.squeeze(-1)


def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
r"""Solve the batched linear system AX = b, where b is a batched column
vector. The solve is carried out after permuting the largest batch
dimension of b to the final position, which results in a more efficient
matrix-matrix solve.
This ideally should be handled upstream (in GPyTorch, linear_operator or
PyTorch), after which any uses of this method can be replaced with
`A.solve(b)`.
Args:
A: LinearOperator of shape (n, n)
b: Tensor of shape (..., n, 1)
Returns:
LinearOperator of shape (..., n, 1)
"""
# permute dimensions to move largest batch dimension to the end (more efficient
# than unsqueezing)
perm = list(range(b.ndim))
if b.ndim > 2:
largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
b_p = b.permute(*perm)

x_p = A.solve(b_p)

# Undo permutation
inverse_perm = torch.argsort(torch.tensor(perm))
x = x_p.permute(*inverse_perm)

return x
24 changes: 22 additions & 2 deletions test/posteriors/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.multitask import KroneckerMultiTaskGP
from botorch.posteriors.multitask import MultitaskGPPosterior
from botorch.posteriors.multitask import _permute_solve, MultitaskGPPosterior
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase
from linear_operator.operators import to_linear_operator


def get_posterior_test_cases(
Expand Down Expand Up @@ -41,7 +42,6 @@ def get_posterior_test_cases(


class TestMultitaskGPPosterior(BotorchTestCase):

def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
sample_shaping = torch.Size([5, 3])
Expand Down Expand Up @@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
base_samples = torch.randn(4, 10, 1, device=self.device)
with self.assertRaises(RuntimeError):
res = posterior._draw_from_base_covar(sym_mat, base_samples)


class TestPermuteSolve(BotorchTestCase):
def test_permute_solve_tensor(self):
# Random PSD matrix
a = torch.randn(32, 32, device=self.device, dtype=torch.float64)
A = torch.mm(a, a.t())

# Random batched column vector
b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64)

# Compare results of permuted and standard solve
x_1 = _permute_solve(to_linear_operator(A), b)
x_2 = torch.linalg.solve(A, b)
self.assertAllClose(x_1, x_2)

# Ensure also works if b is not batched
x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :])
x_2 = torch.linalg.solve(A, b[0, 0, :, :])
self.assertAllClose(x_1, x_2)

0 comments on commit 8924d1b

Please sign in to comment.