Skip to content

Commit

Permalink
Fix SAASBO in benchmarking (#932)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #932

SAASBO is not currently supported in MBO (David and Max are working on it), reverting to the standard setup here.

Reviewed By: lena-kashtelyan

Differential Revision: D35853310

fbshipit-source-id: d8293ab378032e6af5059b1b477f415f25afa55c
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Apr 25, 2022
1 parent ca46192 commit 776f5a3
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 39 deletions.
55 changes: 22 additions & 33 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from botorch.acquisition.multi_objective.monte_carlo import (
qNoisyExpectedHypervolumeImprovement,
)
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import FixedNoiseGP


Expand Down Expand Up @@ -51,78 +50,68 @@ def get_sobol_botorch_modular_fixed_noise_gp_qnei() -> BenchmarkMethod:
scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::FixedNoiseGP_qNoisyExpectedImprovement",
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)


def get_sobol_botorch_modular_default():
generation_strategy = GenerationStrategy(
name="SOBOL+BOTORCH_MODULAR::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::default",
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)

def get_sobol_botorch_modular_fixed_noise_gp_qnehvi() -> BenchmarkMethod:
model_gen_kwargs = {
"model_gen_options": {
Keys.OPTIMIZER_KWARGS: {
"num_restarts": 50,
"raw_samples": 1024,
},
Keys.ACQF_KWARGS: {
"prune_baseline": True,
"qmc": True,
"mc_samples": 512,
},
}
}

def get_sobol_botorch_modular_saas_fully_bayesian_gp_qnei():
generation_strategy = GenerationStrategy(
name="SOBOL+BOTORCH_MODULAR::default",
name="SOBOL+BOTORCH_MODULAR::FixedNoiseGP_qNoisyExpectedHypervolumeImprovement",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
"surrogate": Surrogate(SaasFullyBayesianSingleTaskGP),
"botorch_acqf_class": qNoisyExpectedImprovement,
"surrogate": Surrogate(FixedNoiseGP),
"botorch_acqf_class": qNoisyExpectedHypervolumeImprovement,
},
model_gen_kwargs=model_gen_kwargs,
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::SaasFullyBayesianSingleTaskGP_qNoisyExpectedImprovement", # noqa
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)


def get_sobol_botorch_modular_saas_fully_bayesian_gp_qnehvi():
def get_sobol_botorch_modular_default():
generation_strategy = GenerationStrategy(
name="SOBOL+BOTORCH_MODULAR::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
"surrogate": Surrogate(SaasFullyBayesianSingleTaskGP),
"botorch_acqf_class": qNoisyExpectedHypervolumeImprovement,
},
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::SaasFullyBayesianSingleTaskGP_qNoisyExpectedHypervolumeImprovement", # noqa
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)
51 changes: 51 additions & 0 deletions ax/benchmark/methods/saasbo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models
from ax.service.scheduler import SchedulerOptions


def get_saasbo_default() -> BenchmarkMethod:
generation_strategy = GenerationStrategy(
name="SOBOL+FULLYBAYESIAN::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.FULLYBAYESIAN,
num_trials=-1,
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)


def get_saasbo_moo_default() -> BenchmarkMethod:
generation_strategy = GenerationStrategy(
name="SOBOL+FULLYBAYESIANMOO::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.FULLYBAYESIANMOO,
num_trials=-1,
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)
8 changes: 2 additions & 6 deletions ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict

from ax.benchmark.problems.hpo.pytorch_cnn import (
Expand All @@ -12,7 +11,7 @@
)
from ax.core.runner import Runner
from ax.exceptions.core import UserInputError
from ax.utils.common.typeutils import not_none, checked_cast
from ax.utils.common.typeutils import checked_cast

try: # We don't require TorchVision by default.
from torchvision import transforms, datasets
Expand Down Expand Up @@ -72,10 +71,7 @@ class PyTorchCNNTorchvisionRunner(PyTorchCNNRunner):
def serialize_init_args(cls, runner: Runner) -> Dict[str, Any]:
pytorch_cnn_runner = checked_cast(PyTorchCNNRunner, runner)

pattern = re.compile("(?<=::).*") # Extract the dataset name
dataset_name = not_none(pattern.search(pytorch_cnn_runner.name)).group()

return {"name": dataset_name}
return {"name": pytorch_cnn_runner.name}

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
8 changes: 8 additions & 0 deletions sphinx/source/benchmark.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ Benchmark Methods Modular BoTorch
:undoc-members:
:show-inheritance:

Benchmark Methods SAASBO
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.benchmark.methods.saasbo
:members:
:undoc-members:
:show-inheritance:

Benchmark Methods Choose Generation Strategy
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 776f5a3

Please sign in to comment.