diff --git a/ax/service/tests/test_with_db_settings_base.py b/ax/service/tests/test_with_db_settings_base.py index e5fae392674..8c6034c9d0d 100644 --- a/ax/service/tests/test_with_db_settings_base.py +++ b/ax/service/tests/test_with_db_settings_base.py @@ -13,7 +13,10 @@ from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.service.utils.with_db_settings_base import WithDBSettingsBase +from ax.service.utils.with_db_settings_base import ( + try_load_generation_strategy, + WithDBSettingsBase, +) from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.storage.sqa_store.load import ( _load_experiment, @@ -369,3 +372,40 @@ def test_update_experiment_properties_in_db(self) -> None: experiment.name, decoder=self.with_db_settings.db_settings.decoder ) self.assertEqual(loaded_experiment._properties, {"test_property": True}) + + def test_try_load_generation_strategy(self) -> None: + experiment, generation_strategy = self.init_experiment_and_generation_strategy( + save_generation_strategy=False + ) + # test logging with no experiment/gs saved + with self.assertLogs(logger="ax.service.utils.with_db_settings_base") as lg: + output = try_load_generation_strategy( + experiment_name=experiment.name, + decoder=self.with_db_settings.db_settings.decoder, + experiment=experiment, + ) + self.assertIn( + "There is no generation strategy associated with experiment", + lg.output[0], + ) + self.assertIsNone(output) + # test with saved experiment/gs + ( + exp_saved, + gs_saved, + ) = self.with_db_settings._maybe_save_experiment_and_generation_strategy( + experiment, generation_strategy + ) + self.assertTrue(exp_saved) + self.assertTrue(gs_saved) + with self.assertLogs(logger="ax.service.utils.with_db_settings_base") as lg: + output = try_load_generation_strategy( + experiment_name=experiment.name, + decoder=self.with_db_settings.db_settings.decoder, + experiment=experiment, + ) + self.assertIn( + "Loaded generation strategy for experiment", + lg.output[0], + ) + self.assertEqual(output, generation_strategy) diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index 8e89c662f55..aa418469232 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -265,26 +265,12 @@ def _load_experiment_and_generation_strategy( f"Loaded experiment {experiment_name} & {num_trials} trials in " f"{_round_floats_for_logging(time.time() - start_time)} seconds." ) - - try: - start_time = time.time() - generation_strategy = _load_generation_strategy_by_experiment_name( - experiment_name=experiment_name, - decoder=self.db_settings.decoder, - experiment=experiment, - reduced_state=reduced_state, - ) - logger.info( - f"Loaded generation strategy for experiment {experiment_name} in " - f"{_round_floats_for_logging(time.time() - start_time)} seconds." - ) - except ObjectNotFoundError: - logger.info( - "There is no generation strategy associated with experiment " - f"{experiment_name}." - ) - - return experiment, None + generation_strategy = try_load_generation_strategy( + experiment_name=experiment_name, + decoder=self.db_settings.decoder, + experiment=experiment, + reduced_state=reduced_state, + ) return experiment, generation_strategy @@ -626,3 +612,31 @@ def _save_analysis_cards_to_db_if_possible( analysis_cards=[*analysis_cards], config=sqa_config, ) + + +def try_load_generation_strategy( + experiment_name: str, + decoder: Decoder, + experiment: Optional[Experiment] = None, + reduced_state: bool = False, +) -> Optional[GenerationStrategy]: + """Load generation strategy by experiment name, if it exists.""" + try: + start_time = time.time() + generation_strategy = _load_generation_strategy_by_experiment_name( + experiment_name=experiment_name, + decoder=decoder, + experiment=experiment, + reduced_state=reduced_state, + ) + logger.info( + f"Loaded generation strategy for experiment {experiment_name} in " + f"{_round_floats_for_logging(time.time() - start_time)} seconds." + ) + except ObjectNotFoundError: + logger.info( + "There is no generation strategy associated with experiment " + f"{experiment_name}." + ) + return None + return generation_strategy