From afcd13d088e5459fb6b6ab0b7451409f96b84871 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Fri, 2 Aug 2024 15:15:37 +0100 Subject: [PATCH 01/10] Improve memory usage of multitask posterior sampling --- botorch/posteriors/multitask.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 29245d7b6a..defc5876a7 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -226,10 +226,23 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1)) + + # permute dimensions to move largest batch dimension to the end (more efficient + # than unsqueezing) + largest_batch_dim = torch.argmax(torch.tensor(obs_minus_samples.shape)) + perm = list(range(obs_minus_samples.ndim)) + perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] + + # solve + obs_minus_samples_p = obs_minus_samples.permute(*perm) + obs_solve_p = train_covar_plus_noise.solve(obs_minus_samples_p) # and multiply the test-observed matrix against the result of the solve - updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) + updated_samples_p = self.test_train_covar.matmul(obs_solve_p) + + # Undo permutation + inverse_perm = torch.argsort(torch.tensor(perm)) + updated_samples = updated_samples_p.permute(*inverse_perm) # finally, we add the conditioned samples to the prior samples final_samples = test_samples + updated_samples From 7bd66d8841edf5acbebb47536f87c8190c53567a Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Mon, 5 Aug 2024 16:46:59 +0100 Subject: [PATCH 02/10] Update docs; fixes to pass unit tests --- botorch/posteriors/multitask.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index defc5876a7..7f7eb838e7 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -36,9 +36,10 @@ def __init__( distribution: Posterior multivariate normal distribution. joint_covariance_matrix: Joint test train covariance matrix over the entire tensor. - train_train_covar: Covariance matrix of train points in the data space. - test_obs_covar: Covariance matrix of test x train points in the data space. + test_train_covar: Covariance matrix of test x train points in the data space. train_diff: Difference between train mean and train responses. + test_mean: Test mean response. + train_train_covar: Covariance matrix of train points in the data space. train_noise: Training noise covariance. test_noise: Only used if posterior should contain observation noise. Testing noise covariance. @@ -229,20 +230,24 @@ def rsample_from_base_samples( # permute dimensions to move largest batch dimension to the end (more efficient # than unsqueezing) - largest_batch_dim = torch.argmax(torch.tensor(obs_minus_samples.shape)) + largest_batch_dim = torch.argmax(torch.tensor(obs_minus_samples.shape[:-1])).item() + # largest_batch_dim = torch.argmax(torch.tensor(sample_shape)) perm = list(range(obs_minus_samples.ndim)) - perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] - + perm.remove(largest_batch_dim) + perm.append(largest_batch_dim) + # perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] + inverse_perm = torch.argsort(torch.tensor(perm)) + # solve obs_minus_samples_p = obs_minus_samples.permute(*perm) obs_solve_p = train_covar_plus_noise.solve(obs_minus_samples_p) - # and multiply the test-observed matrix against the result of the solve - updated_samples_p = self.test_train_covar.matmul(obs_solve_p) - # Undo permutation - inverse_perm = torch.argsort(torch.tensor(perm)) - updated_samples = updated_samples_p.permute(*inverse_perm) + obs_solve = obs_solve_p.permute(*inverse_perm).unsqueeze(-1) + + # and multiply the test-observed matrix against the result of the solve + # TODO: this might be made more efficient with obs_solve_p (permuted) + updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) # finally, we add the conditioned samples to the prior samples final_samples = test_samples + updated_samples From b22b140de2d7cbd9e6b43019dd48a773885cb316 Mon Sep 17 00:00:00 2001 From: Sam Lishak <122301116+slishak-PX@users.noreply.github.com> Date: Mon, 30 Sep 2024 14:47:39 +0100 Subject: [PATCH 03/10] Apply suggestions from code review Co-authored-by: Max Balandat --- botorch/posteriors/multitask.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index e3ca956c07..7aeb545a93 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -230,8 +230,7 @@ def rsample_from_base_samples( # permute dimensions to move largest batch dimension to the end (more efficient # than unsqueezing) - largest_batch_dim = torch.argmax(torch.tensor(obs_minus_samples.shape[:-1])).item() - # largest_batch_dim = torch.argmax(torch.tensor(sample_shape)) + largest_batch_dim = max(enumerate(obs_minus_samples.shape[:-1]), key=lambda t: t[0]) perm = list(range(obs_minus_samples.ndim)) perm.remove(largest_batch_dim) perm.append(largest_batch_dim) From 3f6f6974b335d43efb3557691fd198d1089c4bf4 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Mon, 30 Sep 2024 15:16:19 +0100 Subject: [PATCH 04/10] Move permuted solve into its own helper function --- botorch/posteriors/multitask.py | 19 ++--------------- botorch/utils/linalg.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 17 deletions(-) create mode 100644 botorch/utils/linalg.py diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 7aeb545a93..934303b907 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -8,6 +8,7 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.posteriors.gpytorch import GPyTorchPosterior +from botorch.utils.linalg import permute_solve from gpytorch.distributions import MultivariateNormal from linear_operator.operators import LinearOperator, to_linear_operator from torch import Tensor @@ -227,25 +228,9 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - - # permute dimensions to move largest batch dimension to the end (more efficient - # than unsqueezing) - largest_batch_dim = max(enumerate(obs_minus_samples.shape[:-1]), key=lambda t: t[0]) - perm = list(range(obs_minus_samples.ndim)) - perm.remove(largest_batch_dim) - perm.append(largest_batch_dim) - # perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] - inverse_perm = torch.argsort(torch.tensor(perm)) - - # solve - obs_minus_samples_p = obs_minus_samples.permute(*perm) - obs_solve_p = train_covar_plus_noise.solve(obs_minus_samples_p) - - # Undo permutation - obs_solve = obs_solve_p.permute(*inverse_perm).unsqueeze(-1) + obs_solve = permute_solve(train_covar_plus_noise, obs_minus_samples) # and multiply the test-observed matrix against the result of the solve - # TODO: this might be made more efficient with obs_solve_p (permuted) updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) # finally, we add the conditioned samples to the prior samples diff --git a/botorch/utils/linalg.py b/botorch/utils/linalg.py new file mode 100644 index 0000000000..dc648a4a4b --- /dev/null +++ b/botorch/utils/linalg.py @@ -0,0 +1,37 @@ +import torch +from linear_operator.operators import LinearOperator + + +def permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: + r"""Solve the batched linear system AX = b, where b is a batched column + vector. The solve is carried out after permuting the largest batch + dimension of b to the final position, which results in a more efficient + matrix-matrix solve. + + This ideally should be handled upstream (in GPyTorch, linear_operator or + PyTorch), after which any uses of this method can be replaced with + `A.solve(b)`. + + Args: + A: LinearOperator of shape (n, n) + b: LinearOperator of shape (..., n, 1) + + Returns: + LinearOperator of shape (..., n, 1) + """ + # permute dimensions to move largest batch dimension to the end (more efficient + # than unsqueezing) + largest_batch_dim = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) + perm = list(range(b.ndim)) + perm.remove(largest_batch_dim) + perm.append(largest_batch_dim) + b_p = b.permute(*perm) + + # solve + x_p = A.solve(b_p) + + # Undo permutation + inverse_perm = torch.argsort(torch.tensor(perm)) + x = x_p.permute(*inverse_perm).unsqueeze(-1) + + return x From e0809cce2dd629d5350b005e805d476897fcf7ab Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Mon, 30 Sep 2024 16:08:59 +0100 Subject: [PATCH 05/10] Revert "Move permuted solve into its own helper function" This reverts commit 3f6f6974b335d43efb3557691fd198d1089c4bf4. --- botorch/posteriors/multitask.py | 19 +++++++++++++++-- botorch/utils/linalg.py | 37 --------------------------------- 2 files changed, 17 insertions(+), 39 deletions(-) delete mode 100644 botorch/utils/linalg.py diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 934303b907..7aeb545a93 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -8,7 +8,6 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.posteriors.gpytorch import GPyTorchPosterior -from botorch.utils.linalg import permute_solve from gpytorch.distributions import MultivariateNormal from linear_operator.operators import LinearOperator, to_linear_operator from torch import Tensor @@ -228,9 +227,25 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - obs_solve = permute_solve(train_covar_plus_noise, obs_minus_samples) + + # permute dimensions to move largest batch dimension to the end (more efficient + # than unsqueezing) + largest_batch_dim = max(enumerate(obs_minus_samples.shape[:-1]), key=lambda t: t[0]) + perm = list(range(obs_minus_samples.ndim)) + perm.remove(largest_batch_dim) + perm.append(largest_batch_dim) + # perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] + inverse_perm = torch.argsort(torch.tensor(perm)) + + # solve + obs_minus_samples_p = obs_minus_samples.permute(*perm) + obs_solve_p = train_covar_plus_noise.solve(obs_minus_samples_p) + + # Undo permutation + obs_solve = obs_solve_p.permute(*inverse_perm).unsqueeze(-1) # and multiply the test-observed matrix against the result of the solve + # TODO: this might be made more efficient with obs_solve_p (permuted) updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) # finally, we add the conditioned samples to the prior samples diff --git a/botorch/utils/linalg.py b/botorch/utils/linalg.py deleted file mode 100644 index dc648a4a4b..0000000000 --- a/botorch/utils/linalg.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch -from linear_operator.operators import LinearOperator - - -def permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: - r"""Solve the batched linear system AX = b, where b is a batched column - vector. The solve is carried out after permuting the largest batch - dimension of b to the final position, which results in a more efficient - matrix-matrix solve. - - This ideally should be handled upstream (in GPyTorch, linear_operator or - PyTorch), after which any uses of this method can be replaced with - `A.solve(b)`. - - Args: - A: LinearOperator of shape (n, n) - b: LinearOperator of shape (..., n, 1) - - Returns: - LinearOperator of shape (..., n, 1) - """ - # permute dimensions to move largest batch dimension to the end (more efficient - # than unsqueezing) - largest_batch_dim = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) - perm = list(range(b.ndim)) - perm.remove(largest_batch_dim) - perm.append(largest_batch_dim) - b_p = b.permute(*perm) - - # solve - x_p = A.solve(b_p) - - # Undo permutation - inverse_perm = torch.argsort(torch.tensor(perm)) - x = x_p.permute(*inverse_perm).unsqueeze(-1) - - return x From d63e7b0b00733a5a427302ddd7e3e74b22802be6 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Mon, 30 Sep 2024 16:10:51 +0100 Subject: [PATCH 06/10] Try again (avoid circular import with `utils` --- botorch/posteriors/multitask.py | 53 ++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 7aeb545a93..835945d2eb 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -227,25 +227,9 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - - # permute dimensions to move largest batch dimension to the end (more efficient - # than unsqueezing) - largest_batch_dim = max(enumerate(obs_minus_samples.shape[:-1]), key=lambda t: t[0]) - perm = list(range(obs_minus_samples.ndim)) - perm.remove(largest_batch_dim) - perm.append(largest_batch_dim) - # perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] - inverse_perm = torch.argsort(torch.tensor(perm)) - - # solve - obs_minus_samples_p = obs_minus_samples.permute(*perm) - obs_solve_p = train_covar_plus_noise.solve(obs_minus_samples_p) - - # Undo permutation - obs_solve = obs_solve_p.permute(*inverse_perm).unsqueeze(-1) + obs_solve = _permute_solve(train_covar_plus_noise, obs_minus_samples) # and multiply the test-observed matrix against the result of the solve - # TODO: this might be made more efficient with obs_solve_p (permuted) updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) # finally, we add the conditioned samples to the prior samples @@ -303,3 +287,38 @@ def _draw_from_base_covar( res = covar_root.matmul(base_samples) return res.squeeze(-1) + + +def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: + r"""Solve the batched linear system AX = b, where b is a batched column + vector. The solve is carried out after permuting the largest batch + dimension of b to the final position, which results in a more efficient + matrix-matrix solve. + + This ideally should be handled upstream (in GPyTorch, linear_operator or + PyTorch), after which any uses of this method can be replaced with + `A.solve(b)`. + + Args: + A: LinearOperator of shape (n, n) + b: LinearOperator of shape (..., n, 1) + + Returns: + LinearOperator of shape (..., n, 1) + """ + # permute dimensions to move largest batch dimension to the end (more efficient + # than unsqueezing) + largest_batch_dim = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) + perm = list(range(b.ndim)) + perm.remove(largest_batch_dim) + perm.append(largest_batch_dim) + b_p = b.permute(*perm) + + # solve + x_p = A.solve(b_p) + + # Undo permutation + inverse_perm = torch.argsort(torch.tensor(perm)) + x = x_p.permute(*inverse_perm).unsqueeze(-1) + + return x From 0a3d5867fb1d193a88ce8605739bbee6e7eae1d3 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Mon, 30 Sep 2024 16:14:00 +0100 Subject: [PATCH 07/10] Bug fix --- botorch/posteriors/multitask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 835945d2eb..f722e8f86a 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -308,7 +308,7 @@ def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: """ # permute dimensions to move largest batch dimension to the end (more efficient # than unsqueezing) - largest_batch_dim = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) + largest_batch_dim, _ = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) perm = list(range(b.ndim)) perm.remove(largest_batch_dim) perm.append(largest_batch_dim) From 68ba15ddb064335e13766a5a8ac189c85179055d Mon Sep 17 00:00:00 2001 From: Sam Lishak <122301116+slishak-PX@users.noreply.github.com> Date: Tue, 1 Oct 2024 08:56:27 +0100 Subject: [PATCH 08/10] Update botorch/posteriors/multitask.py Co-authored-by: Max Balandat --- botorch/posteriors/multitask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index f722e8f86a..0f5d9adce1 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -36,7 +36,8 @@ def __init__( distribution: Posterior multivariate normal distribution. joint_covariance_matrix: Joint test train covariance matrix over the entire tensor. - test_train_covar: Covariance matrix of test x train points in the data space. + test_train_covar: Covariance matrix of test x train points in the data + space. train_diff: Difference between train mean and train responses. test_mean: Test mean response. train_train_covar: Covariance matrix of train points in the data space. From f3baebc950b3a50202de2556b64c580a6ac3f0f0 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Tue, 1 Oct 2024 09:38:37 +0100 Subject: [PATCH 09/10] Add unit test, and correct bugs found --- botorch/posteriors/multitask.py | 17 +++++++++-------- test/posteriors/test_multitask.py | 24 ++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 0f5d9adce1..76a2df43d4 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -228,7 +228,9 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - obs_solve = _permute_solve(train_covar_plus_noise, obs_minus_samples) + obs_solve = _permute_solve( + train_covar_plus_noise, obs_minus_samples.unsqueeze(-1) + ) # and multiply the test-observed matrix against the result of the solve updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) @@ -290,7 +292,7 @@ def _draw_from_base_covar( return res.squeeze(-1) -def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: +def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator: r"""Solve the batched linear system AX = b, where b is a batched column vector. The solve is carried out after permuting the largest batch dimension of b to the final position, which results in a more efficient @@ -302,24 +304,23 @@ def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator: Args: A: LinearOperator of shape (n, n) - b: LinearOperator of shape (..., n, 1) + b: Tensor of shape (..., n, 1) Returns: LinearOperator of shape (..., n, 1) """ # permute dimensions to move largest batch dimension to the end (more efficient # than unsqueezing) - largest_batch_dim, _ = max(enumerate(b.shape[:-1]), key=lambda t: t[0]) perm = list(range(b.ndim)) - perm.remove(largest_batch_dim) - perm.append(largest_batch_dim) + if b.ndim > 2: + largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1]) + perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] b_p = b.permute(*perm) - # solve x_p = A.solve(b_p) # Undo permutation inverse_perm = torch.argsort(torch.tensor(perm)) - x = x_p.permute(*inverse_perm).unsqueeze(-1) + x = x_p.permute(*inverse_perm) return x diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py index 42913d3ef7..d5add4ca12 100644 --- a/test/posteriors/test_multitask.py +++ b/test/posteriors/test_multitask.py @@ -8,9 +8,10 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.multitask import KroneckerMultiTaskGP -from botorch.posteriors.multitask import MultitaskGPPosterior +from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import BotorchTestCase +from linear_operator.operators import to_linear_operator def get_posterior_test_cases( @@ -41,7 +42,6 @@ def get_posterior_test_cases( class TestMultitaskGPPosterior(BotorchTestCase): - def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None: post_list = get_posterior_test_cases(device=self.device, dtype=dtype) sample_shaping = torch.Size([5, 3]) @@ -189,3 +189,23 @@ def test_draw_from_base_covar(self): base_samples = torch.randn(4, 10, 1, device=self.device) with self.assertRaises(RuntimeError): res = posterior._draw_from_base_covar(sym_mat, base_samples) + + +class TestPermuteSolve(BotorchTestCase): + def test_permute_solve_tensor(self): + # Random PSD matrix + a = torch.randn(32, 32, device=self.device, dtype=torch.float64) + A = torch.mm(a, a.t()) + + # Random batched column vector + b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64) + + # Compare results of permuted and standard solve + x_1 = _permute_solve(to_linear_operator(A), b) + x_2 = torch.linalg.solve(A, b) + self.assertAllClose(x_1, x_2) + + # Ensure also works if b is not batched + x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :]) + x_2 = torch.linalg.solve(A, b[0, 0, :, :]) + self.assertAllClose(x_1, x_2) From 6417409c4c4e1a056ddd57337e20a8b91bb4cc64 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Tue, 1 Oct 2024 07:54:49 -0700 Subject: [PATCH 10/10] Fix import ordering --- test/posteriors/test_multitask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py index d5add4ca12..1b6c6b7dbc 100644 --- a/test/posteriors/test_multitask.py +++ b/test/posteriors/test_multitask.py @@ -8,7 +8,7 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.multitask import KroneckerMultiTaskGP -from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve +from botorch.posteriors.multitask import _permute_solve, MultitaskGPPosterior from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import BotorchTestCase from linear_operator.operators import to_linear_operator