diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index b30f5720024..630c97f26ab 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -16,6 +16,7 @@ from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.metric import Metric from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig @@ -50,6 +51,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.common.timeutils import current_timestamp_in_millis +from ax.utils.common.typeutils import not_none from ax.utils.testing.core_stubs import ( DummyEarlyStoppingStrategy, DummyGlobalStoppingStrategy, @@ -63,8 +65,6 @@ SpecialGenerationStrategy, ) -from pyre_extensions import none_throws - from sqlalchemy.orm.exc import StaleDataError DUMMY_EXCEPTION = "test_exception" @@ -217,9 +217,18 @@ def run_multiple(self, trials: Iterable[BaseTrial]) -> Dict[int, Dict[str, Any]] class TestAxScheduler(TestCase): - """Tests base `Scheduler` functionality.""" + """Tests base `Scheduler` functionality. This test case is meant to + test Scheduler using `GenerationStrategy`, but be extensible so + it can be applied to any type of `GenerationStrategyInterface` + by overriding `GENERATION_STRATEGY_INTERFACE_CLASS` and + `_get_generation_strategy_strategy_for_test()`. You may also need + to subclass and change some specific tests that don't apply to + your specific `GenerationStrategyInterface`.""" + + GENERATION_STRATEGY_INTERFACE_CLASS = GenerationStrategy def setUp(self) -> None: + super().setUp() self.branin_experiment = get_branin_experiment() self.branin_timestamp_map_metric_experiment = ( get_branin_experiment_with_timestamp_map_metric() @@ -235,6 +244,7 @@ def setUp(self) -> None: optimization_config=OptimizationConfig( objective=Objective(metric=Metric(name="branin")) ), + name="branin_experiment_no_impl_runner_or_metrics", ) self.sobol_GPEI_GS = choose_generation_strategy( search_space=get_branin_search_space() @@ -255,29 +265,79 @@ def setUp(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, max_parallelism=1)] ) - def test_init(self) -> None: + def _get_generation_strategy_strategy_for_test( + self, + experiment: Experiment, + generation_strategy: Optional[GenerationStrategy] = None, + ) -> GenerationStrategyInterface: + return not_none(generation_strategy) + + @property + def db_config(self) -> SQAConfig: + encoder_registry = { + SyntheticRunnerWithStatusPolling: runner_to_dict, + **CORE_ENCODER_REGISTRY, + } + decoder_registry = { + SyntheticRunnerWithStatusPolling.__name__: SyntheticRunnerWithStatusPolling, + **CORE_DECODER_REGISTRY, + } + runner_registry = { + SyntheticRunnerWithStatusPolling: 1998, + InfinitePollRunner: 1999, + **CORE_RUNNER_REGISTRY, + } + + return SQAConfig( + json_encoder_registry=encoder_registry, + json_decoder_registry=decoder_registry, + runner_registry=runner_registry, + ) + + @property + def db_settings(self) -> DBSettings: + config = self.db_config + encoder = Encoder(config=config) + decoder = Decoder(config=config) + return DBSettings(encoder=encoder, decoder=decoder) + + def test_init_with_no_impl(self) -> None: with self.assertRaisesRegex( UnsupportedError, "`Scheduler` requires that experiment specifies a `Runner`.", ): - scheduler = Scheduler( + Scheduler( experiment=self.branin_experiment_no_impl_runner_or_metrics, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment_no_impl_runner_or_metrics, + generation_strategy=self.sobol_GPEI_GS, + ), options=SchedulerOptions(total_trials=10), ) + + def test_init_with_no_impl_with_runner(self) -> None: self.branin_experiment_no_impl_runner_or_metrics.runner = self.runner with self.assertRaisesRegex( UnsupportedError, ".*Metrics {'branin'} do not implement fetching logic.", ): - scheduler = Scheduler( + Scheduler( experiment=self.branin_experiment_no_impl_runner_or_metrics, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment_no_impl_runner_or_metrics, + generation_strategy=self.sobol_GPEI_GS, + ), options=SchedulerOptions(total_trials=10), ) - scheduler = Scheduler( + + def test_init_with_branin_experiment(self) -> None: + rgs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=rgs, options=SchedulerOptions( total_trials=0, tolerated_trial_failure_rate=0.2, @@ -285,7 +345,7 @@ def test_init(self) -> None: ), ) self.assertEqual(scheduler.experiment, self.branin_experiment) - self.assertEqual(scheduler.generation_strategy, self.sobol_GPEI_GS) + self.assertEqual(scheduler.generation_strategy, rgs) self.assertEqual(scheduler.options.total_trials, 0) self.assertEqual(scheduler.options.tolerated_trial_failure_rate, 0.2) self.assertEqual(scheduler.options.init_seconds_between_polls, 10) @@ -302,15 +362,20 @@ def test_init(self) -> None: ) def test_repr(self) -> None: - scheduler = Scheduler( + branin_gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=branin_gs, options=SchedulerOptions( total_trials=0, tolerated_trial_failure_rate=0.2, init_seconds_between_polls=10, ), ) + self.maxDiff = None self.assertEqual( f"{scheduler}", ( @@ -331,13 +396,17 @@ def test_repr(self) -> None: ) def test_validate_early_stopping_strategy(self) -> None: + branin_gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GPEI_GS, + ) with patch( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, ), self.assertRaises(ValueError): Scheduler( experiment=self.branin_experiment, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=branin_gs, options=SchedulerOptions( early_stopping_strategy=DummyEarlyStoppingStrategy() ), @@ -346,25 +415,31 @@ def test_validate_early_stopping_strategy(self) -> None: # should not error Scheduler( experiment=self.branin_experiment, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=branin_gs, options=SchedulerOptions( early_stopping_strategy=DummyEarlyStoppingStrategy() ), ) - @patch.object( - GenerationStrategy, - "gen_for_multiple_trials_with_multiple_models", - return_value=[[get_generator_run()]], - ) - def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None: - scheduler = Scheduler( + def test_run_multi_arm_generator_run_error(self) -> None: + branin_gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, - options=SchedulerOptions(total_trials=1), ) - with self.assertRaisesRegex(SchedulerInternalError, ".* only one was expected"): - scheduler.run_all_trials() + with patch.object( + type(branin_gs), + "gen_for_multiple_trials_with_multiple_models", + return_value=[[get_generator_run()]], + ): + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=branin_gs, + options=SchedulerOptions(total_trials=1), + ) + with self.assertRaisesRegex( + SchedulerInternalError, ".* only one was expected" + ): + scheduler.run_all_trials() @patch( # Record calls to function, but still execute it. @@ -377,10 +452,14 @@ def test_run_multi_arm_generator_run_error(self, mock_gen: Mock) -> None: def test_run_all_trials_using_runner_and_metrics( self, mock_get_pending: Mock ) -> None: + branin_gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # With runners & metrics, `Scheduler.run_all_trials` should run. scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=branin_gs, options=SchedulerOptions( total_trials=8, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -423,9 +502,13 @@ def test_run_all_trials_using_runner_and_metrics( def test_run_all_trials_callback(self) -> None: n_total_trials = 8 + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=n_total_trials, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -447,10 +530,14 @@ def base_run_n_trials( # pyre-fixme[2]: Parameter annotation cannot contain `Any`. idle_callback: Optional[Callable[[Scheduler], Any]], ) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # With runners & metrics, `Scheduler.run_all_trials` should run. scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, # Short between polls so test is fast. @@ -485,11 +572,15 @@ def _callback(scheduler: Scheduler) -> None: self.assertTrue(test_obj[1] == "apple") def test_run_preattached_trials_only(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # assert that pre-attached trials run when max_trials = number of # pre-attached trials scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, # Short between polls so test is fast. @@ -518,10 +609,14 @@ def test_run_preattached_trials_only(self) -> None: def test_inferring_reference_point(self) -> None: experiment = get_branin_experiment_with_multi_objective() experiment.runner = self.runner + gs = self._get_generation_strategy_strategy_for_test( + experiment=experiment, + generation_strategy=self.sobol_GS_no_parallelism, + ) scheduler = Scheduler( experiment=experiment, - generation_strategy=self.sobol_GS_no_parallelism, + generation_strategy=gs, options=SchedulerOptions( # Stops the optimization after 5 trials. global_stopping_strategy=DummyGlobalStoppingStrategy( @@ -537,9 +632,13 @@ def test_inferring_reference_point(self) -> None: mock_infer_rp.assert_called_once() def test_global_stopping(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GS_no_parallelism, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.sobol_GS_no_parallelism, + generation_strategy=gs, options=SchedulerOptions( # Stops the optimization after 5 trials. global_stopping_strategy=DummyGlobalStoppingStrategy( @@ -552,9 +651,13 @@ def test_global_stopping(self) -> None: self.assertEqual(scheduler.estimate_global_stopping_savings(), 0.5) def test_ignore_global_stopping(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GS_no_parallelism, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.sobol_GS_no_parallelism, + generation_strategy=gs, options=SchedulerOptions( # Stops the optimization after 5 trials. global_stopping_strategy=DummyGlobalStoppingStrategy( @@ -566,10 +669,14 @@ def test_ignore_global_stopping(self) -> None: self.assertEqual(len(scheduler.experiment.trials), 10) def test_stop_trial(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # With runners & metrics, `Scheduler.run_all_trials` should run. scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, # Short between polls so test is fast. @@ -585,9 +692,13 @@ def test_stop_trial(self) -> None: @patch(f"{Scheduler.__module__}.MAX_SECONDS_BETWEEN_REPORTS", 2) def test_stop_at_MAX_SECONDS_BETWEEN_REPORTS(self) -> None: self.branin_experiment.runner = InfinitePollRunner() + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GS_no_parallelism, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=8, init_seconds_between_polls=1, # No wait between polls so test is fast. @@ -606,9 +717,13 @@ def test_stop_at_MAX_SECONDS_BETWEEN_REPORTS(self) -> None: ) def test_timeout(self) -> None: - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions( total_trials=8, init_seconds_between_polls=0, # No wait between polls so test is fast. @@ -619,10 +734,14 @@ def test_timeout(self) -> None: self.assertIn("aborted", scheduler.experiment._properties["run_trials_success"]) def test_logging(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GPEI_GS, + ) with NamedTemporaryFile() as temp_file: Scheduler( experiment=self.branin_experiment, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=1, init_seconds_between_polls=0, # No wait bw polls so test is fast. @@ -634,12 +753,16 @@ def test_logging(self) -> None: temp_file.close() def test_logging_level(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GPEI_GS, + ) # We don't have any warnings yet, so warning level of logging shouldn't yield # any logs as of now. with NamedTemporaryFile() as temp_file: Scheduler( experiment=self.branin_experiment, - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=3, init_seconds_between_polls=0, # No wait bw polls so test is fast. @@ -652,11 +775,15 @@ def test_logging_level(self) -> None: temp_file.close() def test_retries(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # Check that retries will be performed for a retriable error. self.branin_experiment.runner = BrokenRunnerRuntimeError() scheduler = Scheduler( experiment=self.branin_experiment, - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions(total_trials=1), ) # Should raise after 3 retries. @@ -666,12 +793,16 @@ def test_retries(self) -> None: self.assertEqual(scheduler.run_trial_call_count, 3) def test_retries_nonretriable_error(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) # Check that no retries will be performed for `ValueError`, since we # exclude it from the retriable errors. self.branin_experiment.runner = BrokenRunnerValueError() scheduler = Scheduler( experiment=self.branin_experiment, - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions(total_trials=1), ) # Should raise right away since ValueError is non-retriable. @@ -681,9 +812,13 @@ def test_retries_nonretriable_error(self) -> None: self.assertEqual(scheduler.run_trial_call_count, 1) def test_set_ttl(self) -> None: - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions( total_trials=2, ttl_seconds_for_trials=1, @@ -705,9 +840,13 @@ def test_failure_rate(self) -> None: ) self.branin_experiment.runner = RunnerWithFrequentFailedTrials() - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GS_no_parallelism, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=options, ) with self.assertRaises(FailureRateExceededError): @@ -721,9 +860,13 @@ def test_failure_rate(self) -> None: # fail after only 2 trials. num_preexisting_trials = len(scheduler.experiment.trials) self.branin_experiment.runner = RunnerWithAllFailedTrials() - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GS_no_parallelism, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=options, ) self.assertEqual(scheduler._num_preexisting_trials, num_preexisting_trials) @@ -731,71 +874,63 @@ def test_failure_rate(self) -> None: scheduler.run_all_trials() self.assertEqual(len(scheduler.experiment.trials), num_preexisting_trials + 2) - def test_sqa_storage(self) -> None: + def test_sqa_storage_without_experiment_name(self) -> None: init_test_engine_and_session_factory(force_init=True) - encoder_registry = { - SyntheticRunnerWithStatusPolling: runner_to_dict, - **CORE_ENCODER_REGISTRY, - } - decoder_registry = { - SyntheticRunnerWithStatusPolling.__name__: SyntheticRunnerWithStatusPolling, - **CORE_DECODER_REGISTRY, - } - runner_registry = { - SyntheticRunnerWithStatusPolling: 1998, - **CORE_RUNNER_REGISTRY, - } - - config = SQAConfig( - json_encoder_registry=encoder_registry, - json_decoder_registry=decoder_registry, - runner_registry=runner_registry, + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, ) - encoder = Encoder(config=config) - decoder = Decoder(config=config) - db_settings = DBSettings(encoder=encoder, decoder=decoder) - experiment = self.branin_experiment # Scheduler currently requires that the experiment be pre-saved. with self.assertRaisesRegex(ValueError, ".* must specify a name"): - experiment._name = None - scheduler = Scheduler( - experiment=experiment, - generation_strategy=self.two_sobol_steps_GS, + self.branin_experiment._name = None + Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions(total_trials=1), - db_settings=db_settings, + db_settings=self.db_settings, ) - experiment._name = "test_experiment" + + def test_sqa_storage_with_experiment_name(self) -> None: + init_test_engine_and_session_factory(force_init=True) + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + self.assertIsNotNone(self.branin_experiment) NUM_TRIALS = 5 scheduler = Scheduler( - experiment=experiment, - generation_strategy=self.two_sobol_steps_GS, + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions( total_trials=NUM_TRIALS, init_seconds_between_polls=0, # No wait between polls so test is fast. ), - db_settings=db_settings, + db_settings=self.db_settings, ) # Check that experiment and GS were saved. - exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name) - self.assertEqual(exp, experiment) - self.assertEqual(gs, self.two_sobol_steps_GS) + exp, loaded_gs = scheduler._load_experiment_and_generation_strategy( + self.branin_experiment.name + ) + self.assertEqual(exp, self.branin_experiment) + self.assertEqual( + len(gs._generator_runs), len(not_none(loaded_gs)._generator_runs) + ) scheduler.run_all_trials() # Check that experiment and GS were saved and test reloading with reduced state. - exp, gs = scheduler._load_experiment_and_generation_strategy( - experiment.name, reduced_state=True + exp, loaded_gs = scheduler._load_experiment_and_generation_strategy( + self.branin_experiment.name, reduced_state=True ) # pyre-fixme[16]: `Optional` has no attribute `trials`. self.assertEqual(len(exp.trials), NUM_TRIALS) - # pyre-fixme[16]: `Optional` has no attribute `_generator_runs`. - self.assertEqual(len(gs._generator_runs), NUM_TRIALS) - # Test `from_stored_experiment`. + # Because of RGS, gs has queued additional unused candidates + self.assertGreaterEqual(len(gs._generator_runs), NUM_TRIALS) new_scheduler = Scheduler.from_stored_experiment( - experiment_name=experiment.name, + experiment_name=self.branin_experiment.name, options=SchedulerOptions( total_trials=NUM_TRIALS + 1, init_seconds_between_polls=0, # No wait between polls so test is fast. ), - db_settings=db_settings, + db_settings=self.db_settings, ) # Hack "resumed from storage timestamp" into `exp` to make sure all other fields # are equal, since difference in resumed from storage timestamps is expected. @@ -806,7 +941,10 @@ def test_sqa_storage(self) -> None: ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS ] self.assertEqual(new_scheduler.experiment, exp) - self.assertEqual(new_scheduler.generation_strategy, gs) + self.assertEqual( + len(gs._generator_runs), + len(new_scheduler.generation_strategy._generator_runs), + ) self.assertEqual( len( new_scheduler.experiment._properties[ @@ -818,9 +956,13 @@ def test_sqa_storage(self) -> None: def test_run_trials_and_yield_results(self) -> None: total_trials = 3 + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = TestScheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( init_seconds_between_polls=0, ), @@ -847,9 +989,13 @@ def test_run_trials_and_yield_results(self) -> None: def test_run_trials_and_yield_results_with_early_stopper(self) -> None: total_trials = 3 self.branin_experiment.runner = InfinitePollRunner() + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = EarlyStopsInsteadOfNormalCompletionScheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, @@ -903,9 +1049,13 @@ def should_stop_trials_early( self.branin_timestamp_map_metric_experiment.runner = ( RunnerWithEarlyStoppingStrategy() ) - scheduler = TestScheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_timestamp_map_metric_experiment, generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = TestScheduler( + experiment=self.branin_timestamp_map_metric_experiment, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, @@ -956,20 +1106,23 @@ def should_stop_trials_early( self.assertAlmostEqual(scheduler.estimate_early_stopping_savings(), 0.5) def test_run_trials_in_batches(self) -> None: - # TODO[drfreund]: Use `Runner` instead when `poll_available_capacity` - # is moved to `Runner` - class PollAvailableCapacityScheduler(Scheduler): - def poll_available_capacity(self) -> None: - return 2 - - scheduler = PollAvailableCapacityScheduler( - experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, - options=SchedulerOptions( - init_seconds_between_polls=0, - run_trials_in_batches=True, - ), - ) + with patch.object( + type(self.branin_experiment.runner), + "poll_available_capacity", + return_value=2, + ): + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, # Has runner and metrics. + generation_strategy=gs, + options=SchedulerOptions( + init_seconds_between_polls=0, + run_trials_in_batches=True, + ), + ) with patch.object( scheduler, "run_trials", side_effect=scheduler.run_trials @@ -981,9 +1134,13 @@ def poll_available_capacity(self) -> None: def test_base_report_results(self) -> None: self.branin_experiment.runner = NoReportResultsRunner() + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( init_seconds_between_polls=0, ), @@ -997,9 +1154,13 @@ def test_base_report_results(self) -> None: ) def test_optimization_complete(self, _) -> None: # With runners & metrics, `Scheduler.run_all_trials` should run. + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( max_pending_trials=100, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -1025,9 +1186,13 @@ def test_suppress_all_storage_errors(self, mock_save_exp: Mock, _) -> None: encoder = Encoder(config=config) decoder = Decoder(config=config) db_settings = DBSettings(encoder=encoder, decoder=decoder) + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( max_pending_trials=100, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -1040,9 +1205,13 @@ def test_suppress_all_storage_errors(self, mock_save_exp: Mock, _) -> None: def test_max_pending_trials(self) -> None: # With runners & metrics, `BareBonesTestScheduler.run_all_trials` should run. + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.sobol_GPEI_GS, + ) scheduler = TestScheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.sobol_GPEI_GS, + generation_strategy=gs, options=SchedulerOptions( max_pending_trials=1, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -1073,9 +1242,13 @@ def test_max_pending_trials(self) -> None: last_n_completed = curr_n_completed def test_get_best_trial(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, # Short between polls so test is fast. @@ -1086,9 +1259,9 @@ def test_get_best_trial(self) -> None: scheduler.run_n_trials(max_trials=1) - trial, params, _arm = none_throws(scheduler.get_best_trial()) - just_params, _just_arm = none_throws(scheduler.get_best_parameters()) - just_params_unmodeled, _just_arm_unmodled = none_throws( + trial, params, _arm = not_none(scheduler.get_best_trial()) + just_params, _just_arm = not_none(scheduler.get_best_parameters()) + just_params_unmodeled, _just_arm_unmodled = not_none( scheduler.get_best_parameters(use_model_predictions=False) ) with self.assertRaisesRegex( @@ -1123,9 +1296,14 @@ def test_get_best_trial_moo(self) -> None: experiment = get_branin_experiment_with_multi_objective() experiment.runner = self.runner - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=experiment, generation_strategy=self.sobol_GPEI_GS, + ) + + scheduler = Scheduler( + experiment=experiment, + generation_strategy=gs, # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. options=SchedulerOptions(init_seconds_between_polls=0.1), ) @@ -1145,9 +1323,13 @@ def test_get_best_trial_moo(self) -> None: self.assertIsNotNone(scheduler.get_pareto_optimal_parameters()) def test_batch_trial(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( # pyre-fixme[6]: For 1st param expected `Optional[int]` but got `float`. init_seconds_between_polls=0.1, # Short between polls so test is fast. @@ -1168,9 +1350,13 @@ def test_poll_and_process_results_with_reasons(self) -> None: ) self.branin_experiment.runner = RunnerWithFailedAndAbandonedTrials() - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GS_no_parallelism, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=options, ) @@ -1204,6 +1390,10 @@ def test_poll_and_process_results_with_reasons(self) -> None: def test_fetch_and_process_trials_data_results_failed_objective_available_while_running( # noqa self, ) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) with patch( f"{BraninTimestampMapMetric.__module__}.BraninTimestampMapMetric.f", side_effect=[Exception("yikes!"), {"mean": 0, "timestamp": 12345}], @@ -1221,7 +1411,7 @@ def test_fetch_and_process_trials_data_results_failed_objective_available_while_ ) as lg: scheduler = Scheduler( experiment=get_branin_experiment_with_timestamp_map_metric(), - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions(), ) scheduler.run_n_trials(max_trials=1) @@ -1239,12 +1429,16 @@ def test_fetch_and_process_trials_data_results_failed_objective_available_while_ def test_fetch_and_process_trials_data_results_failed_non_objective( self, ) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) with patch( f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!") ), self.assertLogs(logger="ax.service.scheduler") as lg: scheduler = Scheduler( - experiment=get_branin_experiment_with_timestamp_map_metric(), - generation_strategy=self.two_sobol_steps_GS, + experiment=self.branin_timestamp_map_metric_experiment, + generation_strategy=gs, options=SchedulerOptions(), ) scheduler.run_n_trials(max_trials=1) @@ -1260,6 +1454,10 @@ def test_fetch_and_process_trials_data_results_failed_non_objective( ) def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) with patch( f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!") ), patch( @@ -1269,8 +1467,8 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: logger="ax.service.scheduler" ) as lg: scheduler = Scheduler( - experiment=get_branin_experiment(), - generation_strategy=self.two_sobol_steps_GS, + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions(), ) @@ -1295,9 +1493,13 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: def test_completion_criterion(self) -> None: # Tests non-GSS parts of the completion criterion. - scheduler = Scheduler( + gs = self._get_generation_strategy_strategy_for_test( experiment=self.branin_experiment, generation_strategy=self.sobol_GPEI_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions( total_trials=None, ), @@ -1324,20 +1526,9 @@ def test_completion_criterion(self) -> None: self.assertEqual(message, "Exceeding the total number of trials.") def test_get_fitted_model_bridge(self) -> None: - # setting up experiment and generation strategy - branin_experiment = Experiment( - name="branin_test_experiment", - search_space=get_branin_search_space(), - runner=SyntheticRunner(), - optimization_config=OptimizationConfig( - objective=Objective( - metric=BraninMetric(name="branin", param_names=["x1", "x2"]), - minimize=True, - ), - ), - is_test=True, - ) - branin_experiment._properties[Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF] = True + self.branin_experiment._properties[ + Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF + ] = True # generation strategy NUM_SOBOL = 5 generation_strategy = GenerationStrategy( @@ -1348,15 +1539,25 @@ def test_get_fitted_model_bridge(self) -> None: GenerationStep(model=Models.GPEI, num_trials=-1), ] ) - scheduler = Scheduler( - experiment=branin_experiment, + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, generation_strategy=generation_strategy, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, options=SchedulerOptions(), ) # need to run some trials to initialize the ModelBridge scheduler.run_n_trials(max_trials=NUM_SOBOL + 1) + self._helper_path_that_refits_the_model_if_it_is_not_already_initialized( + scheduler=scheduler, + ) - # testing path that refits the _model, if it is not already initialized + def _helper_path_that_refits_the_model_if_it_is_not_already_initialized( + self, + scheduler: Scheduler, + ) -> None: with patch.object( GenerationStrategy, "model", @@ -1366,7 +1567,7 @@ def test_get_fitted_model_bridge(self) -> None: with patch.object( GenerationStrategy, "_fit_current_model", - wraps=generation_strategy._fit_current_model, + wraps=scheduler.standard_generation_strategy._fit_current_model, ) as fit_model: get_fitted_model_bridge(scheduler) fit_model.assert_called_once() @@ -1424,9 +1625,14 @@ def test_standard_generation_strategy(self) -> None: def test_get_improvement_over_baseline(self) -> None: n_total_trials = 8 + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=n_total_trials, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -1452,6 +1658,11 @@ def test_get_improvement_over_baseline_robustness(self) -> None: experiment = get_branin_experiment_with_multi_objective() experiment.runner = self.runner + gs = self._get_generation_strategy_strategy_for_test( + experiment=experiment, + generation_strategy=self.sobol_GPEI_GS, + ) + scheduler = Scheduler( experiment=experiment, generation_strategy=self.sobol_GPEI_GS, @@ -1464,9 +1675,13 @@ def test_get_improvement_over_baseline_robustness(self) -> None: baseline_arm_name=None, ) + gs = self._get_generation_strategy_strategy_for_test( + experiment=experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=2, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`. @@ -1496,10 +1711,14 @@ def test_get_improvement_over_baseline_no_baseline(self) -> None: """Test that get_improvement_over_baseline returns UserInputError when baseline is not found in data.""" n_total_trials = 8 + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) scheduler = Scheduler( experiment=self.branin_experiment, # Has runner and metrics. - generation_strategy=self.two_sobol_steps_GS, + generation_strategy=gs, options=SchedulerOptions( total_trials=n_total_trials, # pyre-fixme[6]: For 2nd param expected `Optional[int]` but got `float`.