Skip to content

Commit

Permalink
Add Sobol benchmark method
Browse files Browse the repository at this point in the history
Summary: This makes it easier to add Sobol into the mix when running benchmarks.

Reviewed By: esantorella

Differential Revision: D54647122
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 7, 2024
1 parent d0df866 commit a99117b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
34 changes: 34 additions & 0 deletions ax/benchmark/methods/sobol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

from ax.benchmark.benchmark_method import (
BenchmarkMethod,
get_sequential_optimization_scheduler_options,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.scheduler import SchedulerOptions


def get_sobol_benchmark_method(
distribute_replications: bool,
scheduler_options: Optional[SchedulerOptions] = None,
) -> BenchmarkMethod:
generation_strategy = GenerationStrategy(
name="Sobol",
steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)],
)

return BenchmarkMethod(
name=generation_strategy.name,
generation_strategy=generation_strategy,
scheduler_options=scheduler_options
or get_sequential_optimization_scheduler_options(),
distribute_replications=distribute_replications,
)
14 changes: 14 additions & 0 deletions ax/benchmark/tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ax.benchmark.benchmark import benchmark_replication
from ax.benchmark.benchmark_method import get_sequential_optimization_scheduler_options
from ax.benchmark.methods.modular_botorch import get_sobol_botorch_modular_acquisition
from ax.benchmark.methods.sobol import get_sobol_benchmark_method
from ax.benchmark.problems.registry import get_problem
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -54,3 +55,16 @@ def test_benchmark_replication_runs(self) -> None:
problem = get_problem(problem_name="ackley4", num_trials=n_sobol_trials + 1)
result = benchmark_replication(problem=problem, method=method, seed=0)
self.assertTrue(np.isfinite(result.score_trace).all())

def test_sobol(self) -> None:
method = get_sobol_benchmark_method(
scheduler_options=get_sequential_optimization_scheduler_options(),
distribute_replications=False,
)
self.assertEqual(method.name, "Sobol")
gs = method.generation_strategy
self.assertEqual(len(gs._steps), 1)
self.assertEqual(gs._steps[0].model, Models.SOBOL)
problem = get_problem(problem_name="ackley4", num_trials=3)
result = benchmark_replication(problem=problem, method=method, seed=0)
self.assertTrue(np.isfinite(result.score_trace).all())

0 comments on commit a99117b

Please sign in to comment.