Skip to content

Commit

Permalink
Refactor computation of optimization trace (#2747)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2747

* Add `BenchmarkProblem.get_oracle_experiment_from_params`, a method to compute an experiment where parameters are evaluated at oracle values. This will be useful once we enable inference regret.
* Add a helper `get_oracle_experiment_from_experiment` to replicate the old behavior of `get_oracle_experiment`.
* Remove `get_opt_trace` from `BenchmarkProblem` and absorbe that logic int `benchmark_replication`; once we have inference regret enabled, how we compute the trace should depend on the _method_, not the problem. The problem should only be responsible for computing oracle values given a parameterization.
* Arc lint

Differential Revision: D62250058
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 6, 2024
1 parent 746d3c9 commit 1ca80a2
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 28 deletions.
11 changes: 10 additions & 1 deletion ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.core.experiment import Experiment
from ax.core.utils import get_model_times
from ax.service.scheduler import Scheduler
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.logger import get_logger
from ax.utils.common.random import with_rng_seed

Expand Down Expand Up @@ -116,7 +117,15 @@ def benchmark_replication(
with with_rng_seed(seed=seed):
scheduler.run_n_trials(max_trials=problem.num_trials)

optimization_trace = problem.get_opt_trace(experiment=experiment)
oracle_experiment = problem.get_oracle_experiment_from_experiment(
experiment=experiment
)
optimization_trace = np.array(
BestPointMixin._get_trace(
experiment=oracle_experiment,
optimization_config=problem.optimization_config,
)
)

try:
# Catch any errors that may occur during score computation, such as errors
Expand Down
83 changes: 56 additions & 27 deletions ax/benchmark/benchmark_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

# pyre-strict

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Optional, Union

import numpy as np
import pandas as pd

from ax.benchmark.benchmark_metric import BenchmarkMetric
Expand All @@ -25,9 +25,8 @@
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import ComparisonOp
from ax.core.types import ComparisonOp, TParamValue
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.utils.common.base import Base
from botorch.test_functions.base import (
BaseTestProblem,
Expand Down Expand Up @@ -86,55 +85,85 @@ class BenchmarkProblem(Base):
search_space: SearchSpace = field(repr=False)
runner: BenchmarkRunner = field(repr=False)

def get_oracle_experiment(self, experiment: Experiment) -> Experiment:
def get_oracle_experiment_from_params(
self,
dict_of_dict_of_params: Mapping[int, Mapping[str, [Mapping[str, TParamValue]]]],
) -> Experiment:
"""
Get a new experiment with the same search space and optimization config
as those belonging to this problem, but with parameterizations evaluated
at oracle values.
Args:
dict_of_dict_of_params: Keys are trial indices, values are Mappings
(e.g. dicts) that map arm names to parameterizations.
Example:
>>> problem.get_oracle_experiment_from_params(
... {
... 0: {
... "0_0": {"x0": 0.0, "x1": 0.0},
... "0_1": {"x0": 0.3, "x1": 0.4},
... },
... 1: {"1_0": {"x0": 0.0, "x1": 0.0}},
... }
... )
"""
records = []

new_experiment = Experiment(
experiment = Experiment(
search_space=self.search_space, optimization_config=self.optimization_config
)
for trial_index, trial in experiment.trials.items():
for arm in trial.arms:
if len(dict_of_dict_of_params) == 0:
return experiment

for trial_index, dict_of_params in dict_of_dict_of_params.items():
if len(dict_of_params) == 0:
raise ValueError(
"Can't create a trial with no arms. Each sublist in "
"list_of_list_of_params must have at least one element."
)
for arm_name, params in dict_of_params.items():
for metric_name, metric_value in zip(
self.runner.outcome_names,
self.runner.evaluate_oracle(parameters=arm.parameters),
self.runner.evaluate_oracle(parameters=params),
):
records.append(
{
"arm_name": arm.name,
"arm_name": arm_name,
"metric_name": metric_name,
"mean": metric_value.item(),
"mean": metric_value,
"sem": 0.0,
"trial_index": trial_index,
}
)

new_experiment.attach_trial(
parameterizations=[arm.parameters for arm in trial.arms],
arm_names=[arm.name for arm in trial.arms],
experiment.attach_trial(
parameterizations=list(dict_of_params.values()),
arm_names=list(dict_of_params.keys()),
)
for trial in new_experiment.trials.values():
for trial in experiment.trials.values():
trial.mark_completed()

data = Data(df=pd.DataFrame.from_records(records))
new_experiment.attach_data(data=data, overwrite_existing_data=True)
return new_experiment
experiment.attach_data(data=data, overwrite_existing_data=True)
return experiment

def get_oracle_experiment_from_experiment(
self, experiment: Experiment
) -> Experiment:
return self.get_oracle_experiment_from_params(
dict_of_dict_of_params={
trial.index: {arm.name: arm.parameters for arm in trial.arms}
for trial in experiment.trials.values()
}
)

@property
def is_moo(self) -> bool:
"""Whether the problem is multi-objective."""
return isinstance(self.optimization_config, MultiObjectiveOptimizationConfig)

def get_opt_trace(self, experiment: Experiment) -> np.ndarray:
"""Evaluate the optimization trace of a list of Trials."""
oracle_experiment = self.get_oracle_experiment(experiment=experiment)

return np.array(
BestPointMixin._get_trace(
experiment=oracle_experiment,
optimization_config=self.optimization_config,
)
)


def _get_constraints(
num_constraints: int, observe_noise_sd: bool
Expand Down
71 changes: 71 additions & 0 deletions ax/benchmark/tests/test_benchmark_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict

import math
from math import pi
from typing import Optional, Union

import torch
Expand All @@ -22,6 +23,7 @@
from ax.core.types import ComparisonOp
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_branin_experiment
from botorch.test_functions.base import ConstrainedBaseTestProblem
from botorch.test_functions.multi_fidelity import AugmentedBranin
from botorch.test_functions.multi_objective import BraninCurrin, ConstrainedBraninCurrin
Expand Down Expand Up @@ -322,3 +324,72 @@ def test_maximization_problem(self) -> None:
test_problem_kwargs={},
)
self.assertFalse(test_problem.optimization_config.objective.minimize)

def test_get_oracle_experiment_from_params(self) -> None:
problem = create_problem_from_botorch(
test_problem_class=Branin,
test_problem_kwargs={},
num_trials=5,
)
# first is near optimum
near_opt_params = {"x0": -pi, "x1": 12.275}
other_params = {"x0": 0.5, "x1": 0.5}
unbatched_experiment = problem.get_oracle_experiment_from_params(
{0: {"0": near_opt_params}, 1: {"1": other_params}}
)
self.assertEqual(len(unbatched_experiment.trials), 2)
self.assertTrue(
all(t.status.is_completed for t in unbatched_experiment.trials.values())
)
self.assertTrue(
all(len(t.arms) == 1 for t in unbatched_experiment.trials.values())
)
df = unbatched_experiment.fetch_data().df
self.assertAlmostEqual(df["mean"].iloc[0], Branin._optimal_value, places=5)

batched_experiment = problem.get_oracle_experiment_from_params(
{0: {"0_0": near_opt_params, "0_1": other_params}}
)
self.assertEqual(len(batched_experiment.trials), 1)
self.assertEqual(len(batched_experiment.trials[0].arms), 2)
df = batched_experiment.fetch_data().df
self.assertAlmostEqual(df["mean"].iloc[0], Branin._optimal_value, places=5)

# Test empty inputs
experiment = problem.get_oracle_experiment_from_params({})
self.assertEqual(len(experiment.trials), 0)

with self.assertRaisesRegex(ValueError, "trial with no arms"):
problem.get_oracle_experiment_from_params({0: {}})

def test_get_oracle_experiment_from_experiment(self) -> None:
problem = create_problem_from_botorch(
test_problem_class=Branin,
test_problem_kwargs={"negate": True},
num_trials=5,
)

# empty experiment
empty_experiment = get_branin_experiment(with_trial=False)
oracle_experiment = problem.get_oracle_experiment_from_experiment(
empty_experiment
)
self.assertEqual(oracle_experiment.search_space, problem.search_space)
self.assertEqual(
oracle_experiment.optimization_config, problem.optimization_config
)
self.assertEqual(oracle_experiment.trials.keys(), set())

experiment = get_branin_experiment(
with_trial=True,
search_space=problem.search_space,
with_status_quo=False,
)
oracle_experiment = problem.get_oracle_experiment_from_experiment(
experiment=experiment
)
self.assertEqual(oracle_experiment.search_space, problem.search_space)
self.assertEqual(
oracle_experiment.optimization_config, problem.optimization_config
)
self.assertEqual(oracle_experiment.trials.keys(), experiment.trials.keys())

0 comments on commit 1ca80a2

Please sign in to comment.