Skip to content

Commit

Permalink
allow more candidate trials than max_trials in Scheduler (#2689)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2689

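See title: the Scheduler no longer raises a `UserInputError` when the number of pre-attached candidate trials exceeds `max_trials`. This is useful for queueing trials manually.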

Reviewed By: bernardbeckerman

Differential Revision: D61508135

fbshipit-source-id: 76b0c8f0671e7b367706edadf22314b781074458
  • Loading branch information
sdaulton authored and facebook-github-bot committed Aug 22, 2024
1 parent c8da4ba commit a163817
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
19 changes: 12 additions & 7 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,13 +920,6 @@ def run_trials_and_yield_results(
n_initial_candidate_trials = len(self.candidate_trials)
if n_initial_candidate_trials == 0 and max_trials < 0:
raise UserInputError(f"Expected `max_trials` >= 0, got {max_trials}.")
elif max_trials < n_initial_candidate_trials:
raise UserInputError(
"The number of pre-attached candidate trials "
f"({n_initial_candidate_trials}) is greater than `max_trials = "
f"{max_trials}`. Increase `max_trials` or reduce the number of "
"pre-attached candidate trials."
)

# trials are pre-existing only if they do not still require running
n_existing = len(self.experiment.trials) - n_initial_candidate_trials
Expand Down Expand Up @@ -1570,6 +1563,7 @@ def _complete_optimization(
num_preexisting_trials=num_preexisting_trials,
status=RunTrialsStatus.SUCCESS,
)
self.warn_if_non_terminal_trials()
return res

def _validate_options(self, options: SchedulerOptions) -> None:
Expand Down Expand Up @@ -2151,6 +2145,17 @@ def _get_failure_rate_exceeded_error(
)
)

def warn_if_non_terminal_trials(self) -> None:
    """Log a warning naming every trial on the experiment that is not terminal.

    Walks ``self.experiment.trials`` and collects the ``index`` of each trial
    whose ``status.is_terminal`` is falsy. If at least one such trial exists,
    a single warning with their count, the experiment name, and the list of
    indices is emitted through ``self.logger``; otherwise nothing is logged.
    """
    pending_indices = []
    for trial in self.experiment.trials.values():
        if trial.status.is_terminal:
            continue
        pending_indices.append(trial.index)

    # Nothing left in a non-terminal state: stay silent.
    if not pending_indices:
        return

    self.logger.warning(
        f"Found {len(pending_indices)} non-terminal trials on "
        f"{self.experiment.name}: {pending_indices}."
    )


def get_fitted_model_bridge(
scheduler: Scheduler, force_refit: bool = False
Expand Down
46 changes: 42 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,11 @@ def test_run_preattached_trials_only(self) -> None:
# pyre-fixme[6]: For 1st param expected `Dict[str, Union[None, bool, float,
# int, str]]` but got `Dict[str, int]`.
trial.add_arm(Arm(parameters=parameter_dict))
with self.assertRaisesRegex(
UserInputError, "number of pre-attached candidate trials .* is greater than"
):
scheduler.run_n_trials(max_trials=0)

# check no new trials are run, when max_trials = 0
scheduler.run_n_trials(max_trials=0)
self.assertEqual(trial.status, TrialStatus.CANDIDATE)
# check that candidate trial is run, when max_trials = 1
scheduler.run_n_trials(max_trials=1)
self.assertEqual(len(scheduler.experiment.trials), 1)
self.assertDictEqual(
Expand All @@ -813,6 +814,43 @@ def test_run_preattached_trials_only(self) -> None:
all(t.completed_successfully for t in scheduler.experiment.trials.values())
)

def test_run_multiple_preattached_trials_only(self) -> None:
    """Two pre-attached candidate trials are run one per `max_trials=1` call.

    With two candidates attached, the first `run_n_trials(max_trials=1)` call
    must leave the second trial as a candidate and log the non-terminal
    trial warning; a second call must then run the remaining trial.
    """
    generation_strategy = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )
    scheduler_options = SchedulerOptions(
        init_seconds_between_polls=0,  # No wait between polls keeps the test fast.
        trial_type=TrialType.BATCH_TRIAL,
    )
    scheduler = Scheduler(
        experiment=self.branin_experiment,  # Has runner and metrics.
        generation_strategy=generation_strategy,
        options=scheduler_options,
        db_settings=self.db_settings_if_always_needed,
    )

    # Manually queue two candidate trials, in order, before running anything.
    queued_trials = []
    for arm_parameters in ({"x1": 5, "x2": 5}, {"x1": 6, "x2": 3}):
        candidate = scheduler.experiment.new_trial()
        candidate.add_arm(Arm(parameters=arm_parameters))
        queued_trials.append(candidate)
    first_trial, second_trial = queued_trials

    # First call with max_trials = 1: only the first candidate is run, and
    # the still-queued second trial is reported in the last log line.
    with self.assertLogs(logger="ax.service.scheduler") as captured_logs:
        scheduler.run_n_trials(max_trials=1)
        last_log_line = captured_logs.output[-1]
        self.assertIn(
            "Found 1 non-terminal trials on branin_test_experiment: [1]",
            last_log_line,
        )
    started_statuses = [TrialStatus.RUNNING, TrialStatus.COMPLETED]
    self.assertIn(first_trial.status, started_statuses)
    self.assertEqual(second_trial.status, TrialStatus.CANDIDATE)

    # Second call with max_trials = 1: the remaining candidate is run and no
    # new trials are generated.
    scheduler.run_n_trials(max_trials=1)
    all_trials = scheduler.experiment.trials
    self.assertEqual(len(all_trials), 2)
    # Make sure all trials got to complete.
    everything_completed = all(
        trial.completed_successfully for trial in all_trials.values()
    )
    self.assertTrue(everything_completed)

def test_global_stopping(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
Expand Down

0 comments on commit a163817

Please sign in to comment.