Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create utility for trying to load GS #2694

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion ax/service/tests/test_with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.service.utils.with_db_settings_base import WithDBSettingsBase
from ax.service.utils.with_db_settings_base import (
try_load_generation_strategy,
WithDBSettingsBase,
)
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.load import (
_load_experiment,
Expand Down Expand Up @@ -369,3 +372,40 @@ def test_update_experiment_properties_in_db(self) -> None:
experiment.name, decoder=self.with_db_settings.db_settings.decoder
)
self.assertEqual(loaded_experiment._properties, {"test_property": True})

def test_try_load_generation_strategy(self) -> None:
experiment, generation_strategy = self.init_experiment_and_generation_strategy(
save_generation_strategy=False
)
# test logging with no experiment/gs saved
with self.assertLogs(logger="ax.service.utils.with_db_settings_base") as lg:
output = try_load_generation_strategy(
experiment_name=experiment.name,
decoder=self.with_db_settings.db_settings.decoder,
experiment=experiment,
)
self.assertIn(
"There is no generation strategy associated with experiment",
lg.output[0],
)
self.assertIsNone(output)
# test with saved experiment/gs
(
exp_saved,
gs_saved,
) = self.with_db_settings._maybe_save_experiment_and_generation_strategy(
experiment, generation_strategy
)
self.assertTrue(exp_saved)
self.assertTrue(gs_saved)
with self.assertLogs(logger="ax.service.utils.with_db_settings_base") as lg:
output = try_load_generation_strategy(
experiment_name=experiment.name,
decoder=self.with_db_settings.db_settings.decoder,
experiment=experiment,
)
self.assertIn(
"Loaded generation strategy for experiment",
lg.output[0],
)
self.assertEqual(output, generation_strategy)
54 changes: 34 additions & 20 deletions ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,12 @@ def _load_experiment_and_generation_strategy(
f"Loaded experiment {experiment_name} & {num_trials} trials in "
f"{_round_floats_for_logging(time.time() - start_time)} seconds."
)

try:
start_time = time.time()
generation_strategy = _load_generation_strategy_by_experiment_name(
experiment_name=experiment_name,
decoder=self.db_settings.decoder,
experiment=experiment,
reduced_state=reduced_state,
)
logger.info(
f"Loaded generation strategy for experiment {experiment_name} in "
f"{_round_floats_for_logging(time.time() - start_time)} seconds."
)
except ObjectNotFoundError:
logger.info(
"There is no generation strategy associated with experiment "
f"{experiment_name}."
)

return experiment, None
generation_strategy = try_load_generation_strategy(
experiment_name=experiment_name,
decoder=self.db_settings.decoder,
experiment=experiment,
reduced_state=reduced_state,
)

return experiment, generation_strategy

Expand Down Expand Up @@ -626,3 +612,31 @@ def _save_analysis_cards_to_db_if_possible(
analysis_cards=[*analysis_cards],
config=sqa_config,
)


def try_load_generation_strategy(
experiment_name: str,
decoder: Decoder,
experiment: Optional[Experiment] = None,
reduced_state: bool = False,
) -> Optional[GenerationStrategy]:
"""Load generation strategy by experiment name, if it exists."""
try:
start_time = time.time()
generation_strategy = _load_generation_strategy_by_experiment_name(
experiment_name=experiment_name,
decoder=decoder,
experiment=experiment,
reduced_state=reduced_state,
)
logger.info(
f"Loaded generation strategy for experiment {experiment_name} in "
f"{_round_floats_for_logging(time.time() - start_time)} seconds."
)
except ObjectNotFoundError:
logger.info(
"There is no generation strategy associated with experiment "
f"{experiment_name}."
)
return None
return generation_strategy