Skip to content

Commit

Permalink
Make GS inherit from GSInterface (#1991)
Browse files Browse the repository at this point in the history
Summary:

Mainly this means implementing `gen_multiple_with_ensembling()`

Reviewed By: lena-kashtelyan

Differential Revision: D51306927
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Nov 30, 2023
1 parent e44f621 commit 21e1573
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 15 deletions.
5 changes: 1 addition & 4 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ class GenerationStrategyInterface(ABC, Base):
_experiment: Optional[Experiment] = None

@abstractmethod
def gen_multiple_with_ensembling(
def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Optional[Data] = None,
n: int = 1,
extra_gen_metadata: Optional[Dict[str, Any]] = None,
) -> List[List[GeneratorRun]]:
"""Produce GeneratorRuns for multiple trials at once with the possibility of
ensembling, or using multiple models per trial, getting multiple
Expand All @@ -52,8 +51,6 @@ def gen_multiple_with_ensembling(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
extra_gen_metadata: A dictionary containing any additional metadata
to be attached to created GeneratorRuns.
Returns:
A list of lists of lists generator runs. Each outer list represents
Expand Down
64 changes: 53 additions & 11 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
import pandas as pd
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.utils import extend_pending_observations
from ax.core.utils import (
extend_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.exceptions.generation_strategy import GenerationStrategyCompleted

from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.utils.common.base import Base
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast, not_none

Expand All @@ -39,7 +42,7 @@
)


class GenerationStrategy(Base):
class GenerationStrategy(GenerationStrategyInterface):
"""GenerationStrategy describes which model should be used to generate new
points for which trials, enabling and automating use of different models
throughout the optimization process. For instance, it allows to use one
Expand Down Expand Up @@ -191,14 +194,6 @@ def uses_non_registered_models(self) -> bool:
registered and therefore cannot be stored."""
return not self._uses_registered_models

@property
def last_generator_run(self) -> Optional[GeneratorRun]:
"""Latest generator run produced by this generation strategy.
Returns None if no generator runs have been produced yet.
"""
# Used to restore current model when decoding a serialized GS.
return self._generator_runs[-1] if self._generator_runs else None

@property
def trials_as_df(self) -> Optional[pd.DataFrame]:
"""Puts information on individual trials into a data frame for easy
Expand Down Expand Up @@ -287,6 +282,53 @@ def gen(
**kwargs,
)[0]

def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Optional[Data] = None,
n: int = 1,
) -> List[List[GeneratorRun]]:
"""Produce GeneratorRuns for multiple trials at once with the possibility of
ensembling, or using multiple models per trial, getting multiple
GeneratorRuns per trial.
NOTE: This method is in development. Please do not use it yet.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how many trials should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
Returns:
A list of lists of lists generator runs. Each outer list represents
a trial being suggested and each inner list represents a generator
run for that trial.
"""
grs = self._gen_multiple(
experiment=experiment,
num_generator_runs=num_generator_runs,
data=data,
n=n,
pending_observations=get_pending_observation_features_based_on_trial_status(
experiment=experiment
),
)
return [[gr] for gr in grs]

def current_generator_run_limit(
self,
) -> Tuple[int, bool]:
Expand Down
88 changes: 88 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,94 @@ def test_gen_multiple(self) -> None:
for p in original_pending[m]:
self.assertIn(p, pending[m])

def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
exp = get_experiment_with_multi_objective()
sobol_GPEI_gs = GenerationStrategy(
steps=[
GenerationStep(
model=Models.SOBOL,
num_trials=5,
model_kwargs=self.step_model_kwargs,
),
GenerationStep(
model=Models.GPEI,
num_trials=-1,
model_kwargs=self.step_model_kwargs,
),
]
)
with mock_patch_method_original(
mock_path=f"{ModelSpec.__module__}.ModelSpec.gen",
original_method=ModelSpec.gen,
) as model_spec_gen_mock, mock_patch_method_original(
mock_path=f"{ModelSpec.__module__}.ModelSpec.fit",
original_method=ModelSpec.fit,
) as model_spec_fit_mock:
# Generate first four Sobol GRs (one more to gen after that if
# first four become trials.
grs = sobol_GPEI_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp, num_generator_runs=3
)
self.assertEqual(len(grs), 3)
for gr in grs:
self.assertEqual(len(gr), 1)
self.assertIsInstance(gr[0], GeneratorRun)

# We should only fit once; refitting for each `gen` would be
# wasteful as there is no new data.
model_spec_fit_mock.assert_called_once()
self.assertEqual(model_spec_gen_mock.call_count, 3)
pending_in_each_gen = enumerate(
args_and_kwargs.kwargs.get("pending_observations")
for args_and_kwargs in model_spec_gen_mock.call_args_list
)
for gr, (idx, pending) in zip(grs, pending_in_each_gen):
exp.new_trial(generator_run=gr[0]).mark_running(no_runner_required=True)
if idx > 0:
prev_grs = grs[idx - 1]
for arm in prev_grs[0].arms:
for m in pending:
self.assertIn(ObservationFeatures.from_arm(arm), pending[m])
model_spec_gen_mock.reset_mock()

# Check case with pending features initially specified; we should get two
# GRs now (remaining in Sobol step) even though we requested 3.
original_pending = not_none(get_pending(experiment=exp))
first_3_trials_obs_feats = [
ObservationFeatures.from_arm(arm=a, trial_index=np.int64(idx))
for idx, trial in exp.trials.items()
for a in trial.arms
]
for m in original_pending:
self.assertTrue(
same_elements(original_pending[m], first_3_trials_obs_feats)
)

grs = sobol_GPEI_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp,
num_generator_runs=3,
)
self.assertEqual(len(grs), 2)
for gr in grs:
self.assertEqual(len(gr), 1)
self.assertIsInstance(gr[0], GeneratorRun)

pending_in_each_gen = enumerate(
args_and_kwargs[1].get("pending_observations")
for args_and_kwargs in model_spec_gen_mock.call_args_list
)
for gr, (idx, pending) in zip(grs, pending_in_each_gen):
exp.new_trial(generator_run=gr[0]).mark_running(no_runner_required=True)
if idx > 0:
prev_grs = grs[idx - 1]
for arm in prev_grs[0].arms:
for m in pending:
# In this case, we should see both the originally-pending
# and the new arms as pending observation features.
self.assertIn(ObservationFeatures.from_arm(arm), pending[m])
for p in original_pending[m]:
self.assertIn(p, pending[m])

# ------------- Testing helpers (put tests above this line) -------------

def _run_GS_for_N_rounds(
Expand Down

0 comments on commit 21e1573

Please sign in to comment.