Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More efficient sampling from KroneckerMultiTaskGP #2460

Closed
wants to merge 11 commits into from
44 changes: 41 additions & 3 deletions botorch/posteriors/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def __init__(
distribution: Posterior multivariate normal distribution.
joint_covariance_matrix: Joint test train covariance matrix over the entire
tensor.
train_train_covar: Covariance matrix of train points in the data space.
test_obs_covar: Covariance matrix of test x train points in the data space.
test_train_covar: Covariance matrix of test x train points in the data
space.
train_diff: Difference between train mean and train responses.
test_mean: Test mean response.
train_train_covar: Covariance matrix of train points in the data space.
train_noise: Training noise covariance.
test_noise: Only used if posterior should contain observation noise.
Testing noise covariance.
Expand Down Expand Up @@ -226,7 +228,9 @@ def rsample_from_base_samples(
train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
)
train_covar_plus_noise = self.train_train_covar + self.train_noise
obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
obs_solve = _permute_solve(
train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
)

# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
Expand Down Expand Up @@ -286,3 +290,37 @@ def _draw_from_base_covar(
res = covar_root.matmul(base_samples)

return res.squeeze(-1)


def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
r"""Solve the batched linear system AX = b, where b is a batched column
vector. The solve is carried out after permuting the largest batch
dimension of b to the final position, which results in a more efficient
matrix-matrix solve.

This ideally should be handled upstream (in GPyTorch, linear_operator or
PyTorch), after which any uses of this method can be replaced with
`A.solve(b)`.

Args:
A: LinearOperator of shape (n, n)
b: Tensor of shape (..., n, 1)

Returns:
LinearOperator of shape (..., n, 1)
"""
# permute dimensions to move largest batch dimension to the end (more efficient
# than unsqueezing)
perm = list(range(b.ndim))
if b.ndim > 2:
largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
b_p = b.permute(*perm)

x_p = A.solve(b_p)

# Undo permutation
inverse_perm = torch.argsort(torch.tensor(perm))
x = x_p.permute(*inverse_perm)

return x
24 changes: 22 additions & 2 deletions test/posteriors/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.multitask import KroneckerMultiTaskGP
from botorch.posteriors.multitask import MultitaskGPPosterior
from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve
Balandat marked this conversation as resolved.
Show resolved Hide resolved
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase
from linear_operator.operators import to_linear_operator


def get_posterior_test_cases(
Expand Down Expand Up @@ -41,7 +42,6 @@ def get_posterior_test_cases(


class TestMultitaskGPPosterior(BotorchTestCase):

def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
sample_shaping = torch.Size([5, 3])
Expand Down Expand Up @@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
base_samples = torch.randn(4, 10, 1, device=self.device)
with self.assertRaises(RuntimeError):
res = posterior._draw_from_base_covar(sym_mat, base_samples)


class TestPermuteSolve(BotorchTestCase):
    """Checks `_permute_solve` against a direct `torch.linalg.solve`."""

    def test_permute_solve_tensor(self):
        """The permuted solve matches the standard solve for batched and
        unbatched right-hand sides."""
        tkwargs = {"device": self.device, "dtype": torch.float64}

        # Gram matrix of a random square factor, i.e. a random PSD matrix
        root = torch.randn(32, 32, **tkwargs)
        A = torch.mm(root, root.t())

        # Random batched column vector
        rhs_batched = torch.randn(4, 1, 32, 1, **tkwargs)

        # First the batched right-hand side, then a single column with the
        # batch dimensions indexed away, to ensure the unbatched path works.
        for rhs in (rhs_batched, rhs_batched[0, 0, :, :]):
            permuted = _permute_solve(to_linear_operator(A), rhs)
            standard = torch.linalg.solve(A, rhs)
            self.assertAllClose(permuted, standard)
Loading