Fix a test that is failing flakily on GPU (pytorch#2521)
Summary:
Pull Request resolved: pytorch#2521

This PR:
* Updates the tolerance described below from 0.12 to 0.13
* Adds a double-precision case to this test (doesn't make a difference)
* Updates some checks to `assertAllClose`

Reviewed By: Balandat

Differential Revision: D62442905

fbshipit-source-id: 9377705aece3e4c340717b4e99652e619eb038a1
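As a minimal, self-contained illustration of the testing pattern this commit applies (running one test body under both float32 and float64 via subTest, and using tolerance-aware tensor comparisons instead of exact equality), consider the sketch below. The helper make_samples, the TestDtypeParametrized class, and the use of torch.testing.assert_close in place of BotorchTestCase.assertAllClose are illustrative assumptions, not code from this commit.

# Sketch only: toy stand-in for the dtype-parametrized test pattern in the diff below.
import unittest

import torch


def make_samples(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Stand-in for drawing posterior samples; seeded so repeated calls agree.
    torch.random.manual_seed(0)
    return torch.randn(8, 5, 3, dtype=dtype, device=device)


class TestDtypeParametrized(unittest.TestCase):
    def _run_check(self, dtype: torch.dtype) -> None:
        device = torch.device("cpu")  # illustrative; the real test uses self.device
        samples_1 = make_samples(dtype=dtype, device=device)
        samples_2 = make_samples(dtype=dtype, device=device)
        # Tolerance-aware comparison, analogous to assertAllClose in the diff.
        torch.testing.assert_close(samples_1, samples_2, rtol=1e-5, atol=1e-6)
        # Relative-error bound with a hand-picked threshold, like the 0.13 check.
        mean_1, mean_2 = samples_1.mean(0), samples_2.mean(0)
        rel_err = (mean_1 - mean_2).norm() / mean_1.norm()
        self.assertLess(rel_err.item(), 0.13)

    def test_dtypes(self) -> None:
        # subTest reports failures per dtype without duplicating the test body,
        # mirroring how the commit adds the double-precision case.
        for dtype in (torch.float, torch.double):
            with self.subTest(dtype=dtype):
                self._run_check(dtype=dtype)


if __name__ == "__main__":
    unittest.main()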
esantorella authored and facebook-github-bot committed Sep 10, 2024
1 parent ad4a93a commit 4d49bf7
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions test/posteriors/test_multitask.py
@@ -13,35 +13,40 @@
 from botorch.utils.testing import BotorchTestCase
 
 
-class TestMultitaskGPPosterior(BotorchTestCase):
-    def setUp(self):
-        super().setUp()
-        torch.random.manual_seed(0)
-
-        train_x = torch.rand(10, 1, device=self.device)
-        train_y = torch.randn(10, 3, device=self.device)
-
-        m2 = KroneckerMultiTaskGP(train_x, train_y)
-
-        torch.random.manual_seed(0)
-        test_x = torch.rand(2, 5, 1, device=self.device)
-
-        posterior0 = m2.posterior(test_x[0])
-        posterior1 = m2.posterior(test_x)
-        posterior2 = m2.posterior(test_x[0], observation_noise=True)
-        posterior3 = m2.posterior(test_x, observation_noise=True)
-
-        self.post_list = [
-            [m2, test_x[0], posterior0],
-            [m2, test_x, posterior1],
-            [m2, test_x[0], posterior2],
-            [m2, test_x, posterior3],
-        ]
-
-    def test_MultitaskGPPosterior(self):
+def get_posterior_test_cases(
+    device: torch.device, dtype: torch.dtype
+) -> list[tuple[KroneckerMultiTaskGP, torch.Tensor, MultitaskGPPosterior]]:
+    torch.random.manual_seed(0)
+
+    train_x = torch.rand(10, 1, device=device, dtype=dtype)
+    train_y = torch.randn(10, 3, device=device, dtype=dtype)
+
+    m2 = KroneckerMultiTaskGP(train_x, train_y)
+
+    torch.random.manual_seed(0)
+    test_x = torch.rand(2, 5, 1, device=device, dtype=dtype)
+
+    posterior0 = m2.posterior(test_x[0])
+    posterior1 = m2.posterior(test_x)
+    posterior2 = m2.posterior(test_x[0], observation_noise=True)
+    posterior3 = m2.posterior(test_x, observation_noise=True)
+
+    post_list = [
+        (m2, test_x[0], posterior0),
+        (m2, test_x, posterior1),
+        (m2, test_x[0], posterior2),
+        (m2, test_x, posterior3),
+    ]
+    return post_list
+
+
+class TestMultitaskGPPosterior(BotorchTestCase):
+
+    def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
+        post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
         sample_shaping = torch.Size([5, 3])
 
-        for post_collection in self.post_list:
+        for post_collection in post_list:
             model, test_x, posterior = post_collection
 
             self.assertIsInstance(posterior, MultitaskGPPosterior)
@@ -70,7 +75,10 @@ def test_MultitaskGPPosterior(self):
                 expected_base_sample_shape,
             )
             base_samples = torch.randn(
-                8, *expected_base_sample_shape, device=self.device
+                8,
+                *expected_base_sample_shape,
+                device=self.device,
+                dtype=dtype,
             )
 
             samples_1 = posterior.rsample_from_base_samples(
Expand All @@ -79,7 +87,7 @@ def test_MultitaskGPPosterior(self):
samples_2 = posterior.rsample_from_base_samples(
base_samples=base_samples, sample_shape=torch.Size((8,))
)
self.assertTrue(torch.allclose(samples_1, samples_2))
self.assertAllClose(samples_1, samples_2)

# test that botorch.sampler picks up the correct shapes
sampler = IIDNormalSampler(sample_shape=torch.Size([5]))
Expand All @@ -90,7 +98,10 @@ def test_MultitaskGPPosterior(self):

# test that providing only some base samples is okay
base_samples = torch.randn(
8, np.prod(expected_extended_shape), device=self.device
8,
np.prod(expected_extended_shape),
device=self.device,
dtype=dtype,
)
samples_3 = posterior.rsample_from_base_samples(
base_samples=base_samples, sample_shape=torch.Size((8,))
@@ -137,19 +148,29 @@ def test_MultitaskGPPosterior(self):
             posterior_variance = posterior_variance.view(-1)
 
             # slightly higher tolerance here because of the potential for low norms
+            # Note: This check appears to yield different results with CUDA
+            # depending on what machine it is run on, so you may see better
+            # convergence.
             self.assertLess(
                 (posterior_mean - sampled_mean).norm() / posterior_mean.norm(),
-                0.12,
+                0.13,
             )
             self.assertLess(
                 (posterior_variance - sampled_variance).norm()
                 / posterior_variance.norm(),
                 5e-2,
             )
 
+    def test_MultitaskGPPosterior(self) -> None:
+        for dtype in (torch.float, torch.double):
+            with self.subTest(dtype=dtype):
+                self._test_MultitaskGPPosterior(dtype=dtype)
+
     def test_draw_from_base_covar(self):
         # grab a posterior
-        posterior = self.post_list[0][2]
+        posterior = get_posterior_test_cases(device=self.device, dtype=torch.float32)[
+            0
+        ][2]
 
         base_samples = torch.randn(4, 30, 1, device=self.device)
         base_mat = torch.randn(30, 30, device=self.device)
