[Bug] Inefficient posterior evaluation of SaasFullyBayesianSingleTaskGP when q=1 #2310
Comments
Thanks for raising this; this is a great catch. Since this call happens in GPyTorch, we'll have to make a change there and ensure that it is compatible with all kinds of other (non-fully-Bayesian) scenarios (I'm not sure exactly what shapes are encountered in this call), but we will definitely fix this. cc @dme65, @esantorella
Thank you! I did initially try to come up with something to contribute to GPyTorch and/or linear_operator, but it was harder than anticipated to make it compatible without introducing slowdowns in other situations, so I thought I'd report it here for now. For example, (n.b. I've just corrected a small mistake in the profiling code in the issue above)
Interesting. Do you happen to have some artifacts of those attempts that you could share? That would be very helpful.
What I've done at the moment is just replace the aforementioned Below is a script that benchmarks Those specific situations when
In both situations, I'd guess that the appropriate set of matrix transposes and

```python
import torch
import torch.utils.benchmark as benchmark
from tqdm import tqdm

device = "cuda:0"
f_out = "einsum.txt"


def wrap(f):
    def wrapped(a_sz, b_sz, device):
        a = torch.randn(a_sz, device=device)
        b = torch.randn(b_sz, device=device)
        return f(a, b)

    return wrapped


matmul = wrap(torch.matmul)


@wrap
def einsum(a, b):
    return torch.einsum("...ik,...kj", a, b)


if __name__ == "__main__":
    sizes = []
    batch_funcs = [
        ("Full", lambda i_batch, j_batch: ((i_batch, j_batch), (i_batch, j_batch))),
        ("No a0", lambda i_batch, j_batch: ((j_batch,), (i_batch, j_batch))),
        ("No b0", lambda i_batch, j_batch: ((i_batch, j_batch), (j_batch,))),
    ]
    for batch_type, batch_func in batch_funcs:
        for i_batch in (256, 32, 1):
            for j_batch in (32, 16, 1):
                a_batch, b_batch = batch_func(i_batch, j_batch)
                if None in (a_batch, b_batch):
                    continue
                for i_size in (128, 64, 4):
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-matrix product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Matrix-vector product",
                            a_batch + (i_size, i_size),
                            b_batch + (i_size, 1),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Transposed MVP",
                            a_batch + (1, i_size),
                            b_batch + (i_size, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector outer product",
                            a_batch + (i_size, 1),
                            b_batch + (1, i_size),
                        )
                    )
                    sizes.append(
                        (
                            batch_type,
                            "Vector inner product",
                            a_batch + (1, i_size),
                            b_batch + (i_size, 1),
                        )
                    )
    results = []
    with torch.no_grad():
        pbar = tqdm(sizes)
        for env, label, a_sz, b_sz in pbar:
            sub_label = f"{a_sz}x{b_sz}"
            pbar.set_description(sub_label)
            timers = [
                benchmark.Timer(
                    stmt="matmul(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="matmul",
                    setup="from __main__ import matmul",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
                benchmark.Timer(
                    stmt="einsum(a_sz, b_sz, device)",
                    globals={"device": device, "a_sz": a_sz, "b_sz": b_sz},
                    description="einsum",
                    setup="from __main__ import einsum",
                    label=label,
                    sub_label=sub_label,
                    env=env,
                ),
            ]
            for timer in timers:
                result = timer.adaptive_autorange(min_run_time=1)
                results.append(result)
                pbar.write(str(result))
    compare = benchmark.Compare(results)
    compare.colorize(rowwise=True)
    with open(f_out, "wt") as f:
        f.write(str(compare))
```

Timing results
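The benchmark above only measures time. As a companion sketch (not part of the original script), the snippet below checks on small CPU tensors that `torch.matmul` and the same `...ik,...kj` einsum contraction give numerically identical results for the "Full", "No a0" and "No b0" batch layouts; the shapes are hypothetical, scaled-down stand-ins for the benchmarked sizes.

```python
import torch


# Same contraction as the `einsum` helper in the benchmark script: batched a @ b.
def einsum_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.einsum("...ik,...kj->...ij", a, b)


def check_equivalence(a_shape, b_shape, seed: int = 0) -> bool:
    """Return True if torch.matmul and the einsum formulation agree numerically."""
    gen = torch.Generator().manual_seed(seed)
    a = torch.randn(a_shape, generator=gen, dtype=torch.float64)
    b = torch.randn(b_shape, generator=gen, dtype=torch.float64)
    return torch.allclose(torch.matmul(a, b), einsum_matmul(a, b))


# Small stand-ins for the three batch layouts used in the benchmark.
cases = [
    ((8, 4, 16, 16), (8, 4, 16, 1)),  # "Full": matrix-vector product
    ((4, 16, 16), (8, 4, 16, 1)),  # "No a0": a lacks the leading batch dim
    ((8, 4, 1, 16), (4, 16, 16)),  # "No b0": transposed MVP, as in the issue
]
results = [check_equivalence(a, b) for a, b in cases]
print(results)
```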
Along the same lines as above, I've found another (less significant) memory saving. If I attempt to optimise an acquisition function in the original issue with the setup below, after replacing the

```python
candidates, acq_values = optimize_acqf(
    ucb,
    bounds=torch.cat((torch.zeros(1, d), torch.ones(1, d))).to(**tkwargs),
    q=1,
    num_restarts=10,
    raw_samples=1024,
)
```

Traceback:
If I modify the call to

```python
res = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")
```

For ref: at this point,
The PyTorch documentation on
(This issue is also more of a GPyTorch issue than a BoTorch issue, but it's closely related to this, so I'm just adding it here for now; let me know if you want me to create a new issue.)
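For context, `torch.cdist` accepts a `compute_mode` argument. Per the PyTorch documentation, the default (`use_mm_for_euclid_dist_if_necessary`) switches to a matrix-multiplication formulation of the Euclidean distance when either input has more than 25 rows, while `donot_use_mm_for_euclid_dist` computes the distances directly. A minimal sketch with hypothetical shapes that checks the two modes agree numerically (it does not measure memory):

```python
import torch


def cdist_both_modes(x1: torch.Tensor, x2: torch.Tensor):
    """Pairwise Euclidean distances with the default mode and with the mm path disabled."""
    # Default mode: uses the matmul-based formulation for inputs with > 25 rows.
    d_mm = torch.cdist(x1, x2)
    # Direct computation, as in the modified call above.
    d_direct = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")
    return d_mm, d_direct


torch.manual_seed(0)
# Hypothetical shapes: 64 candidate points vs. 2048 training points in 10 dims,
# with a leading batch dimension of 4 (e.g. one entry per model in an ensemble).
x1 = torch.randn(4, 64, 10, dtype=torch.float64)
x2 = torch.randn(4, 2048, 10, dtype=torch.float64)
d_mm, d_direct = cdist_both_modes(x1, x2)
print(d_mm.shape, torch.allclose(d_mm, d_direct))
```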
There is a very similar issue with

```python
import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)
test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)
posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
```

Stack trace
The final line requires allocation of 128GB of GPU memory, because of the call to
The following line appears to be the root cause. Unsqueezing the last dimension sets up a batched matrix-vector solve, but if we instead transpose one of the batch dimensions to the end, we do a more efficient matrix-matrix solve.
botorch/botorch/posteriors/multitask.py Line 229 in 4497a5c
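To illustrate this equivalence in plain PyTorch, the sketch below uses `torch.cholesky_solve` with a random positive-definite matrix as a stand-in for `train_covar_plus_noise.solve` (an assumption: the BoTorch code goes through a linear_operator object). It solves the same residuals once as a batch of `(n, 1)` systems and once as a single `(n, s)` system after permuting the sample dimension to the end, mirroring the permutation proposed in this comment. Sizes are hypothetical and small enough for CPU.

```python
import torch

torch.manual_seed(0)
n, s = 32, 8  # hypothetical: n training points (times tasks), s posterior samples

# Random symmetric positive-definite matrix standing in for train_covar_plus_noise.
a = torch.randn(n, n, dtype=torch.float64)
K = a @ a.T + n * torch.eye(n, dtype=torch.float64)
L = torch.linalg.cholesky(K)  # lower-triangular factor

# Residuals with batch shape (s, 1), mirroring obs_minus_samples.
r = torch.randn(s, 1, n, dtype=torch.float64)

# Matrix-vector form: each right-hand side is its own (n, 1) system, so L is
# broadcast to batch shape (s, 1, n, n).
x_mv = torch.cholesky_solve(r.unsqueeze(-1), L).squeeze(-1)

# Matrix-matrix form: move the sample dimension to the end so that a single
# (n, s) right-hand side is solved against the un-broadcast factor L.
perm = list(range(1, r.ndim)) + [0]  # (s, 1, n) -> (1, n, s)
inverse_perm = torch.argsort(torch.tensor(perm)).tolist()  # [2, 0, 1]
x_mm = torch.cholesky_solve(r.permute(*perm), L).permute(*inverse_perm)

print(x_mv.shape, x_mm.shape, torch.allclose(x_mv, x_mm))
```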
For example, the following code requires less than 4GB of GPU memory, and I believe it is equivalent. By moving the first batch dimension to the final position, we at least stand a chance of having a more efficient operation if the first batch dimension is greater than 1. It would probably be even better to find the largest batch dimension and move that one to the end, or even flatten them all.

```python
perm = list(range(1, obs_minus_samples.ndim)) + [0]
inverse_perm = torch.argsort(torch.tensor(perm))
obs_minus_samples_p = obs_minus_samples.permute(*perm)
obs_solve = train_covar_plus_noise.solve(obs_minus_samples_p)
# and multiply the test-observed matrix against the result of the solve
updated_samples = self.test_train_covar.matmul(obs_solve).permute(*inverse_perm)
```

@Balandat @esantorella: should I submit a PR for this, or is there an obvious reason that it wouldn't work in other use cases?
A PR for this would be great; I don't see any obvious reason why this wouldn't work in other cases. Ideally, we could even do this at the level of
Yes, that makes sense. There are probably some nontrivial tradeoffs between flattening them all and keeping them around, depending on how exactly the underlying CUDA kernel parallelizes the evaluation in each case.
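The "flatten them all" alternative mentioned above can be sketched as follows. The helper name `solve_flattened` is hypothetical, and `torch.cholesky_solve` stands in for the linear_operator solve. The sketch only illustrates the reshaping; it makes no claim about which variant is faster on a given CUDA kernel, and the transposes may force a copy.

```python
import torch


def solve_flattened(L: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    """Solve K x = v for every vector v in rhs (shape (*batch, n)) in one (n, m) solve.

    Hypothetical helper: L is the (n, n) lower Cholesky factor of K, and all batch
    dimensions of rhs are flattened into the columns of a single right-hand side.
    """
    *batch, n = rhs.shape
    cols = rhs.reshape(-1, n).T  # (n, prod(batch))
    sol = torch.cholesky_solve(cols, L)  # (n, prod(batch))
    return sol.T.reshape(*batch, n)


torch.manual_seed(0)
n = 16
a = torch.randn(n, n, dtype=torch.float64)
L = torch.linalg.cholesky(a @ a.T + n * torch.eye(n, dtype=torch.float64))
rhs = torch.randn(6, 3, n, dtype=torch.float64)  # hypothetical batch shape (6, 3)

# Reference: one (n, 1) matrix-vector solve per batch entry.
reference = torch.cholesky_solve(rhs.unsqueeze(-1), L).squeeze(-1)
print(torch.allclose(solve_flattened(L, rhs), reference))
```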
cc @sdaulton, @jandylin, @SebastianAment re excessive memory usage in Kronecker MTGPs |
Summary:

## Motivation

See #2310 (comment)

```python
import torch
from botorch.models import KroneckerMultiTaskGP

n_inputs = 10
n_tasks = 4
n_train = 2048
n_test = 1
device = torch.device("cuda:0")

train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)
test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)

gp = KroneckerMultiTaskGP(train_x, train_y)
posterior = gp.posterior(test_x)
posterior.rsample(torch.Size([256, 1]))
```

The final line requires allocation of 128GB of GPU memory, because of the call to `torch.cholesky_solve` with B shaped `(256, 1, 8192, 1)` and L shaped `(8192, 8192)`.

By moving the largest batch dimension to the final position, we should achieve a more efficient operation.

Also fix docstring for `MultitaskGPPosterior`.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2460

Test Plan:

Passes unit tests (specifically `test_multitask.py`).

Benchmarking results:

![image](https://github.com/user-attachments/assets/1eca54be-1ed4-43c9-bb50-a18cf24d00f5)
![image](https://github.com/user-attachments/assets/016322f6-992a-45bf-b175-e76208c11b12)

## Related PRs

N/A

Reviewed By: saitcakmak

Differential Revision: D63678866

Pulled By: Balandat

fbshipit-source-id: 6675c66dadd62934f95fabafe7b3f0155a1c0c6f
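The PR description above moves the largest batch dimension to the final position. A minimal sketch of that idea, using the hypothetical helper `solve_with_largest_batch_last` and `torch.cholesky_solve` in place of BoTorch's linear_operator solve; this is an illustration, not the merged BoTorch code.

```python
import torch


def solve_with_largest_batch_last(L: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    """Solve K x = v for each vector v in rhs (shape (*batch, n)).

    Hypothetical helper: L is the (n, n) lower Cholesky factor of K. The largest
    batch dimension becomes the column dimension of a matrix-matrix solve.
    """
    if rhs.ndim == 1:
        return torch.cholesky_solve(rhs.unsqueeze(-1), L).squeeze(-1)
    # Index of the largest batch dimension (max() keeps the first one on ties).
    big = max(range(rhs.ndim - 1), key=lambda d: rhs.shape[d])
    # (..., big, ..., n) -> (..., ..., n, big): one (n, big) solve per remaining entry.
    sol = torch.cholesky_solve(rhs.movedim(big, -1), L)
    # Move the solved dimension back to its original position.
    return sol.movedim(-1, big)


torch.manual_seed(0)
n = 16
a = torch.randn(n, n, dtype=torch.float64)
L = torch.linalg.cholesky(a @ a.T + n * torch.eye(n, dtype=torch.float64))
rhs = torch.randn(32, 1, n, dtype=torch.float64)  # scaled-down analogue of (256, 1, n)

reference = torch.cholesky_solve(rhs.unsqueeze(-1), L).squeeze(-1)
out = solve_with_largest_batch_last(L, rhs)
print(out.shape, torch.allclose(out, reference))
```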
🐛 Bug
Evaluating an acquisition function with `q=1` on a `SaasFullyBayesianSingleTaskGP` requires an unnecessarily large amount of memory, due to an inefficient broadcasted `matmul` operation.

In the example below, the following line multiplies a tensor of size `[256, 16, 1, 2048]` with a tensor of size `[16, 2048, 2048]`, which requires the allocation of 128GB of memory:
https://github.com/cornellius-gp/gpytorch/blob/9551eba889adf835b69cfd86e9a5d584fb61cdcc/gpytorch/models/exact_prediction_strategies.py#L118
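The 128GB figure can be reproduced with a back-of-the-envelope size calculation. This sketch assumes that `torch.matmul` materializes the `[16, 2048, 2048]` operand at the full broadcast batch shape `[256, 16]`, and that the tensors are double precision (8 bytes per element).

```python
import math


def expanded_operand_bytes(batch_shape, matrix_shape, bytes_per_element=8):
    """Bytes needed to hold a tensor of shape (*batch_shape, *matrix_shape).

    bytes_per_element=8 assumes float64 tensors.
    """
    return math.prod(batch_shape) * math.prod(matrix_shape) * bytes_per_element


# [16, 2048, 2048] expanded to the broadcast batch shape [256, 16] of the matmul.
matmul_bytes = expanded_operand_bytes((256, 16), (2048, 2048))
print(matmul_bytes / 2**30, "GiB")  # 128.0 GiB
```

The same arithmetic with batch shape `(256, 1)` and an `8192 x 8192` matrix also gives 2**37 bytes (128 GiB), consistent with the `KroneckerMultiTaskGP` case discussed in the comments.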
To reproduce
**Code snippet to reproduce**
**Stack trace/error message**
Expected Behavior
The memory usage for this operation is very high because `torch.matmul` is inefficient for such batched matrix-vector multiplications. If the same operation is written as an `einsum`, or transposed so that it becomes a matrix-matrix multiplication, the memory usage and computation time are substantially reduced.

For example, below is a demonstration of two alternative operations which reduce the memory and computation time by orders of magnitude:
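As a sketch of the two alternatives (not the original demonstration), the snippet below applies both to small hypothetical tensors shaped like the ones in this issue, `[samples, models, 1, n]` times `[models, n, n]`, and checks that they match the broadcasted `matmul`. It verifies correctness only, not the memory savings.

```python
import torch

torch.manual_seed(0)
# Hypothetical small stand-ins for [256, 16, 1, 2048] and [16, 2048, 2048].
s, m, n = 32, 4, 64
a = torch.randn(s, m, 1, n, dtype=torch.float64)
b = torch.randn(m, n, n, dtype=torch.float64)

# 1) Broadcasted matmul, as in the original line: b is broadcast over the s samples.
out_matmul = torch.matmul(a, b)  # [s, m, 1, n]

# 2) einsum: contract over n, with s appearing only in the first operand.
out_einsum = torch.einsum("smin,mnk->smik", a, b)  # [s, m, 1, n]

# 3) Matrix-matrix product: fold the sample dimension into the rows so that each
#    of the m models does a single [s, n] @ [n, n] product, then restore the layout.
a_mm = a.squeeze(-2).transpose(0, 1)  # [m, s, n]
out_transposed = torch.matmul(a_mm, b).transpose(0, 1).unsqueeze(-2)  # [s, m, 1, n]

print(torch.allclose(out_matmul, out_einsum), torch.allclose(out_matmul, out_transposed))
```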
System information