More efficient sampling from KroneckerMultiTaskGP #2460
Conversation
Hi @slishak-PX! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:
@@ Coverage Diff @@
## main #2460 +/- ##
=======================================
Coverage 99.98% 99.98%
=======================================
Files 193 193
Lines 17062 17072 +10
=======================================
+ Hits 17059 17069 +10
Misses 3 3
☔ View full report in Codecov by Sentry.
Thanks for putting this up. Would love to see some benchmarks for this. Ultimately, this is something that ideally should be handled upstream in gpytorch; could you add a comment to that effect?
@slishak-PX have you had a chance to run some benchmarks on this?
@Balandat sorry, I will prioritise this as soon as we have the CLA signed, hopefully in the next week or two.
Co-authored-by: Max Balandat <[email protected]>
This reverts commit 3f6f697.
Added benchmarking results. Uses the following code:

benchmark.py (run on the code in this PR, and BoTorch 0.12.0):

import pickle as pkl
import botorch
import torch
from tqdm import tqdm
from botorch.models import KroneckerMultiTaskGP
device = torch.device("cuda:0")
def get_data(n_inputs=10, n_tasks=4, n_train=128, n_test=1, seed=50):
    torch.manual_seed(seed)
    train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device)
    train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device)
    test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device)
    return train_x, train_y, test_x


def instantiate_and_sample(train_x, train_y, test_x, n_samples=1):
    with torch.no_grad():
        gp = KroneckerMultiTaskGP(train_x, train_y)
        posterior = gp.posterior(test_x)
        posterior.rsample(torch.Size([n_samples]))


def profile(func, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats(device=device)
    m0 = torch.cuda.max_memory_allocated(device=device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    func(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    time = start.elapsed_time(end)
    m1 = torch.cuda.max_memory_allocated(device=device)
    torch.cuda.empty_cache()
    memory = (m1 - m0) / 1024**3
    return memory, time

if __name__ == "__main__":
    fname = botorch.__version__ + "_results.pkl"
    print(fname)

    n_tasks_list = [2, 4, 8]
    n_train_list = [32, 128, 512]
    n_test_list = [1, 8, 64]
    n_samples_list = [1, 4, 16, 64, 256]

    results = []
    for n_tasks in tqdm(n_tasks_list, desc="n_tasks"):
        for n_train in tqdm(n_train_list, leave=False, desc="n_train"):
            for n_test in tqdm(n_test_list, leave=False, desc="n_test"):
                train_x, train_y, test_x = get_data(
                    n_tasks=n_tasks, n_train=n_train, n_test=n_test
                )
                for n_samples in tqdm(n_samples_list, leave=False, desc="n_sample"):
                    memory = []
                    time = []
                    for i in range(10):
                        try:
                            m, t = profile(
                                instantiate_and_sample,
                                train_x,
                                train_y,
                                test_x,
                                n_samples,
                            )
                        except Exception:
                            print("Failed!")
                            print(
                                {
                                    "n_tasks": n_tasks,
                                    "n_train": n_train,
                                    "n_test": n_test,
                                    "n_samples": n_samples,
                                }
                            )
                            raise
                        # Skip the first run as a warm-up
                        if i > 0:
                            memory.append(m)
                            time.append(t)
                    results.append(
                        {
                            "n_tasks": n_tasks,
                            "n_train": n_train,
                            "n_test": n_test,
                            "n_samples": n_samples,
                            "memory": memory,
                            "time": time,
                        }
                    )

    with open(fname, "wb") as f:
        pkl.dump(results, f)

Analysis notebook:

import pickle as pkl
import numpy as np
import pandas as pd
import plotly.express as px
with open("Unknown_results.pkl", "rb") as f:
results = pkl.load(f)
df_new = pd.DataFrame(results)
df_new["version"] = "PR 2460"
with open("0.12.0_results.pkl", "rb") as f:
results = pkl.load(f)
df_old = pd.DataFrame(results)
df_old["version"] = "BoTorch 0.12.0"
df = pd.concat([df_new, df_old])
df["memory_mean"] = df["memory"].apply(np.mean)
df["memory_std"] = df["memory"].apply(np.std)
df["time_mean"] = df["time"].apply(np.mean)
df["time_std"] = df["time"].apply(np.std)
px.line(
df,
x="n_samples",
y="memory_mean",
error_y="memory_std",
facet_col="n_tasks",
facet_row="n_test",
color="n_train",
line_dash="version",
log_x=True,
log_y=True,
width=800,
height=800,
)
px.line(
df,
x="n_samples",
y="time_mean",
error_y="time_std",
facet_col="n_tasks",
facet_row="n_test",
color="n_train",
line_dash="version",
log_x=True,
log_y=True,
width=800,
height=800,
) |
@Balandat has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This is awesome, thanks a lot for contributing this change and the comprehensive benchmarks. Seems like there is no downside to using this from a perf perspective and the logic is also sufficiently straightforward so I'm not worried about tech debt.
The only ask I have before merging this in (besides fixing the flake8 lint) is to write a short unittest.
cc @jandylin, @SebastianAment re sampling from Kronecker structured GPs and interesting matrix solve efficiency gains...
Co-authored-by: Max Balandat <[email protected]>
You were absolutely right to ask for a unit test - the implementation was not entirely correct, although I think in the context of the use in
@Balandat has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Many thanks for the contribution, @slishak-PX!
Motivation
See #2310 (comment)
The final line requires allocation of 128 GB of GPU memory, because of the call to torch.cholesky_solve with B shaped (256, 1, 8192, 1) and L shaped (8192, 8192). By moving the largest batch dimension to the final position, we should achieve a more efficient operation.
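As a rough illustration of why the reordering helps (a minimal sketch with reduced sizes and made-up variable names, not the exact code path changed in this PR):

import torch

# Sizes reduced from the PR description so this runs anywhere; names are
# illustrative and do not mirror the BoTorch internals.
n, n_samples = 256, 64

A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite
L = torch.linalg.cholesky(A)  # (n, n)
B = torch.randn(n_samples, 1, n, 1, dtype=torch.float64)

# Batched solve: L is broadcast against the leading (n_samples, 1) batch
# dimensions of B, so intermediates of shape (n_samples, 1, n, n) can be
# materialised (~128 GB for the shapes quoted above).
x_batched = torch.cholesky_solve(B, L)

# Batch dimension moved to the final (column) position: one solve against a
# single (n, n) factor handles all samples at once.
B_cols = B.reshape(n_samples, n).T  # (n, n_samples)
x_cols = torch.cholesky_solve(B_cols, L)  # (n, n_samples)
x = x_cols.T.reshape(n_samples, 1, n, 1)

assert torch.allclose(x_batched, x)

The two calls solve the same linear systems, since each column of B_cols is one of the batched right-hand sides; only the layout of the batch dimension changes.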
Also fix docstring for MultitaskGPPosterior.
Have you read the Contributing Guidelines on pull requests?
Yes
Test Plan
Passes unit tests (specifically test_multitask.py).
Benchmarking results:
Related PRs
N/A