Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use @fast_botorch_optimize to fix test timeouts #2556

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ax/utils/sensitivity/tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@
SobolSensitivityGPSampling,
)
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import fast_botorch_optimize
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.utils.transforms import unnormalize
from gpytorch.distributions import MultivariateNormal
from torch import Tensor


@fast_botorch_optimize
def get_modelbridge(modular: bool = False, saasbo: bool = False) -> ModelBridge:
exp = get_branin_experiment(with_batch=True)
exp.trials[0].run()
Expand Down Expand Up @@ -152,6 +154,8 @@ def test_SobolGPMean(self) -> None:
self.assertEqual(total_order.shape, torch.Size([2]))
self.assertEqual(second_order.shape, torch.Size([1]))

def test_SobolGPMean_SAASBO(self) -> None:
bounds = torch.tensor([(0.0, 1.0) for _ in range(2)]).t()
sensitivity_mean_saas = SobolSensitivityGPMean(
self.saas_model, num_mc_samples=10, bounds=bounds, second_order=True
)
Expand Down Expand Up @@ -312,7 +316,7 @@ def test_SobolGPMean(self) -> None:
set_rng_seed(seed)
# Unsigned
ind_dict = ax_parameter_sens(
model_bridge=bridge, # pyre-ignore
model_bridge=bridge,
metrics=None,
order="total",
signed=False,
Expand All @@ -327,7 +331,7 @@ def test_SobolGPMean(self) -> None:
set_rng_seed(seed) # reset seed to keep discrete features the same
cat_indices = bridge.model.search_space_digest.categorical_features
ind_dict_signed = ax_parameter_sens(
model_bridge=bridge, # pyre-ignore
model_bridge=bridge,
metrics=None,
order="total",
signed=True,
Expand Down