diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index adeccd4c94b..20d1608aa4c 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -740,8 +740,24 @@ def generation_strategy_from_sqa( decoder_registry=self.config.json_decoder_registry, class_decoder_registry=self.config.json_class_decoder_registry, ) - gs = GenerationStrategy(name=gs_sqa.name, steps=steps) - gs._curr = gs._steps[gs_sqa.curr_index] + nodes = object_from_json( + gs_sqa.nodes, + decoder_registry=self.config.json_decoder_registry, + class_decoder_registry=self.config.json_class_decoder_registry, + ) + + # GenerationStrategies can ony be initialized with either steps or nodes. + # Determine which to use to initialize this GenerationStrategy. + if len(steps) > 0: + gs = GenerationStrategy(name=gs_sqa.name, steps=steps) + gs._curr = gs._steps[gs_sqa.curr_index] + else: + gs = GenerationStrategy(name=gs_sqa.name, nodes=nodes) + curr_node_name = gs_sqa.curr_node_name + for node in gs._nodes: + if node.node_name == curr_node_name: + gs._curr = node + break immutable_ss_and_oc = ( experiment.immutable_search_space_and_opt_config if experiment is not None diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 0b2029f5811..5462498af85 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -840,6 +840,7 @@ def generation_strategy_to_sqa( cast(Type[Base], GenerationStrategy) ] generator_runs_sqa = [] + node_based_strategy = generation_strategy.is_node_based for idx, gr in enumerate(generation_strategy._generator_runs): # Never reduce the state of the last generator run because that # generator run is needed to recreate the model when reloading the @@ -857,10 +858,22 @@ def generation_strategy_to_sqa( generation_strategy._steps, encoder_registry=self.config.json_encoder_registry, class_encoder_registry=self.config.json_class_encoder_registry, - ), - curr_index=generation_strategy.current_step_index, + ) + if not node_based_strategy + else [], + curr_index=generation_strategy.current_step_index + if not node_based_strategy + else -1, generator_runs=generator_runs_sqa, experiment_id=experiment_id, + nodes=object_to_json( + generation_strategy._nodes, + encoder_registry=self.config.json_encoder_registry, + class_encoder_registry=self.config.json_class_encoder_registry, + ) + if node_based_strategy + else [], + curr_node_name=generation_strategy.current_node_name, ) return gs_sqa diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 83a9a8a335f..69a3c97c8a8 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1308,6 +1308,61 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: not_none(new_generation_strategy._experiment)._name, experiment._name ) + def test_EncodeDecodeGenerationNodeBasedGenerationStrategy(self) -> None: + """Test to ensure that GenerationNode based GenerationStrategies are + able to be encoded/decoded correctly. + """ + # we don't support callable models for GenNode based strategies + generation_strategy = get_generation_strategy( + with_generation_nodes=True, with_callable_model_kwarg=False + ) + # Check that we can save a generation strategy without an experiment + # attached. + save_generation_strategy(generation_strategy=generation_strategy) + # Also try restoring this generation strategy by its ID in the DB. + new_generation_strategy = load_generation_strategy_by_id( + # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. + gs_id=generation_strategy._db_id + ) + # Some fields of the reloaded GS are not expected to be set (both will be + # set during next model fitting call), so we unset them on the original GS as + # well. + generation_strategy._unset_non_persistent_state_fields() + self.assertEqual(generation_strategy, new_generation_strategy) + self.assertIsNone(generation_strategy._experiment) + + # Cannot load generation strategy before it has been saved + experiment = get_branin_experiment() + save_experiment(experiment) + with self.assertRaises(ObjectNotFoundError): + load_generation_strategy_by_experiment_name(experiment_name=experiment.name) + + # Check that we can encode and decode the generation strategy *after* + # it has generated some trials and been updated with some data. + # Since we now need to `gen`, we remove the fake callable kwarg we added, + # since model does not expect it. + generation_strategy = get_generation_strategy(with_generation_nodes=True) + experiment.new_trial(generation_strategy.gen(experiment=experiment)) + generation_strategy.gen(experiment, data=get_branin_data()) + save_experiment(experiment) + # TODO @mgarrard passes up until this point + save_generation_strategy(generation_strategy=generation_strategy) + # Try restoring the generation strategy using the experiment its + # attached to. + new_generation_strategy = load_generation_strategy_by_experiment_name( + experiment_name=experiment.name + ) + # Some fields of the reloaded GS are not expected to be set (both will be + # set during next model fitting call), so we unset them on the original GS as + # well. + generation_strategy._unset_non_persistent_state_fields() + self.assertEqual(generation_strategy, new_generation_strategy) + self.assertIsInstance(new_generation_strategy._nodes[0].model_enum, Models) + self.assertEqual(len(new_generation_strategy._generator_runs), 2) + self.assertEqual( + not_none(new_generation_strategy._experiment)._name, experiment._name + ) + def test_EncodeDecodeGenerationStrategyReducedState(self) -> None: """Try restoring the generation strategy using the experiment its attached to, passing the experiment object.