allow passing of random seed in benchmark
Summary: Adds a way to fix the random seed in Miles's benchmark.

Reviewed By: lena-kashtelyan

Differential Revision: D35538759

fbshipit-source-id: cdcf9c0a7ade827f4f55c1c4a3e264085fa75094
danielrjiang authored and facebook-github-bot committed Apr 19, 2022
1 parent ade023b commit c26fbf4
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions ax/benchmark/benchmark.py
@@ -17,7 +17,7 @@
 """
 from time import time
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 from ax.benchmark.benchmark_method import BenchmarkMethod
@@ -38,17 +38,21 @@
 from ax.modelbridge.registry import Models
 from ax.service.scheduler import Scheduler, SchedulerOptions
 from ax.utils.common.typeutils import not_none
+from botorch.utils.sampling import manual_seed
 
 
 def benchmark_replication(
     problem: BenchmarkProblem,
     method: BenchmarkMethod,
+    replication_seed: Optional[int] = None,
 ) -> BenchmarkResult:
     """Runs one benchmarking replication (equivalent to one optimization loop).
 
     Args:
         problem: The BenchmarkProblem to test against (can be synthetic or real)
         method: The BenchmarkMethod to test
+        replication_seed: The seed to use for this replication, set using `manual_seed`
+            from `botorch.utils.sampling`.
     """
 
     experiment = Experiment(
@@ -63,20 +67,27 @@ def benchmark_replication(
         generation_strategy=method.generation_strategy.clone_reset(),
         options=method.scheduler_options,
     )
 
-    scheduler.run_all_trials()
+    with manual_seed(seed=replication_seed):
+        scheduler.run_all_trials()
 
     return _result_from_scheduler(scheduler=scheduler)
 
 
 def benchmark_test(
-    problem: BenchmarkProblem, method: BenchmarkMethod, num_replications: int = 10
+    problem: BenchmarkProblem,
+    method: BenchmarkMethod,
+    num_replications: int = 10,
+    seed: Optional[int] = None,
 ) -> AggregatedBenchmarkResult:
 
     return AggregatedBenchmarkResult.from_benchmark_results(
         results=[
-            benchmark_replication(problem=problem, method=method)
-            for _ in range(num_replications)
+            benchmark_replication(
+                problem=problem,
+                method=method,
+                replication_seed=seed + i if seed is not None else None,
+            )
+            for i in range(num_replications)
         ]
     )
 
@@ -85,11 +96,12 @@ def benchmark_full_run(
     problems: Iterable[BenchmarkProblem],
     methods: Iterable[BenchmarkMethod],
     num_replications: int = 10,
+    seed: Optional[int] = None,
 ) -> List[AggregatedBenchmarkResult]:
 
     return [
         benchmark_test(
-            problem=problem, method=method, num_replications=num_replications
+            problem=problem, method=method, num_replications=num_replications, seed=seed
         )
         for problem in problems
         for method in methods
@@ -101,6 +113,7 @@ def benchmark_scored_test(
     method: BenchmarkMethod,
     baseline_result: AggregatedBenchmarkResult,
     num_replications: int = 10,
+    seed: Optional[int] = None,
 ) -> ScoredBenchmarkResult:
     if isinstance(problem, SingleObjectiveBenchmarkProblem):
         optimum = problem.optimal_value
@@ -113,7 +126,7 @@
         )
 
     aggregated_result = benchmark_test(
-        problem=problem, method=method, num_replications=num_replications
+        problem=problem, method=method, num_replications=num_replications, seed=seed
     )
 
     return ScoredBenchmarkResult.from_result_and_baseline(
@@ -129,6 +142,7 @@ def benchmark_scored_full_run(
     ],
     methods: Iterable[BenchmarkMethod],
     num_replications: int = 10,
+    seed: Optional[int] = None,
 ) -> List[ScoredBenchmarkResult]:
 
     return [
@@ -137,14 +151,18 @@
             method=method,
             baseline_result=baseline_result,
             num_replications=num_replications,
+            seed=seed,
         )
         for problem, baseline_result in problems_baseline_results
         for method in methods
     ]
 
 
 def get_sobol_baseline(
-    problem: BenchmarkProblem, num_replications: int = 100, total_trials: int = 100
+    problem: BenchmarkProblem,
+    num_replications: int = 100,
+    total_trials: int = 100,
+    seed: Optional[int] = None,
 ) -> AggregatedBenchmarkResult:
     return benchmark_test(
         problem=problem,
@@ -159,6 +177,7 @@ def get_sobol_baseline(
             ),
         ),
         num_replications=num_replications,
+        seed=seed,
    )


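A note on `manual_seed` itself: it is a context manager in `botorch.utils.sampling` that seeds torch's global generator and restores the previous RNG state on exit. Roughly along these lines (a paraphrase of the idea, not necessarily the exact botorch source):

from contextlib import contextmanager
from typing import Generator, Optional

import torch

@contextmanager
def manual_seed(seed: Optional[int] = None) -> Generator[None, None, None]:
    """Seed torch's global RNG inside the block, restoring state on exit."""
    old_state = torch.random.get_rng_state()
    try:
        if seed is not None:
            torch.manual_seed(seed)
        yield
    finally:
        if seed is not None:
            # Restore so seeding one replication cannot leak into the next.
            torch.random.set_rng_state(old_state)

This also makes `replication_seed=None` a true no-op: the `with` block then neither seeds nor restores anything, matching the previous unseeded behavior.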
