Skip to content

Commit

Permalink
Update sqa storage to include generation nodes (#2076)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2076

This diff does the following:
update the sqa storage to include generationnodes

In coming diffs:
(2) delete now unused GenStep functions
(3) final pass on all the doc strings and variables -- lots to clean up here
(4) add transition criterion to the repr string + some of the other fields that havent made it yet on GeneratinoNode
(5) Do a final pass of the generationStrategy/GenerationNode files to see what else can be migrated/condensed
(6) rename transiton criterion to action criterion
(7) remove conditionals for legacy usecase
( clean up any lingering todos

Reviewed By: lena-kashtelyan

Differential Revision: D51970237

fbshipit-source-id: 107d763d7dfb4f929c3c8a4f47514f75a81e8c48
  • Loading branch information
mgarrard authored and facebook-github-bot committed Dec 15, 2023
1 parent a3f1844 commit c397f1f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 4 deletions.
20 changes: 18 additions & 2 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,24 @@ def generation_strategy_from_sqa(
decoder_registry=self.config.json_decoder_registry,
class_decoder_registry=self.config.json_class_decoder_registry,
)
gs = GenerationStrategy(name=gs_sqa.name, steps=steps)
gs._curr = gs._steps[gs_sqa.curr_index]
nodes = object_from_json(
gs_sqa.nodes,
decoder_registry=self.config.json_decoder_registry,
class_decoder_registry=self.config.json_class_decoder_registry,
)

# GenerationStrategies can ony be initialized with either steps or nodes.
# Determine which to use to initialize this GenerationStrategy.
if len(steps) > 0:
gs = GenerationStrategy(name=gs_sqa.name, steps=steps)
gs._curr = gs._steps[gs_sqa.curr_index]
else:
gs = GenerationStrategy(name=gs_sqa.name, nodes=nodes)
curr_node_name = gs_sqa.curr_node_name
for node in gs._nodes:
if node.node_name == curr_node_name:
gs._curr = node
break
immutable_ss_and_oc = (
experiment.immutable_search_space_and_opt_config
if experiment is not None
Expand Down
17 changes: 15 additions & 2 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ def generation_strategy_to_sqa(
cast(Type[Base], GenerationStrategy)
]
generator_runs_sqa = []
node_based_strategy = generation_strategy.is_node_based
for idx, gr in enumerate(generation_strategy._generator_runs):
# Never reduce the state of the last generator run because that
# generator run is needed to recreate the model when reloading the
Expand All @@ -857,10 +858,22 @@ def generation_strategy_to_sqa(
generation_strategy._steps,
encoder_registry=self.config.json_encoder_registry,
class_encoder_registry=self.config.json_class_encoder_registry,
),
curr_index=generation_strategy.current_step_index,
)
if not node_based_strategy
else [],
curr_index=generation_strategy.current_step_index
if not node_based_strategy
else -1,
generator_runs=generator_runs_sqa,
experiment_id=experiment_id,
nodes=object_to_json(
generation_strategy._nodes,
encoder_registry=self.config.json_encoder_registry,
class_encoder_registry=self.config.json_class_encoder_registry,
)
if node_based_strategy
else [],
curr_node_name=generation_strategy.current_node_name,
)
return gs_sqa

Expand Down
55 changes: 55 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,61 @@ def test_EncodeDecodeGenerationStrategy(self) -> None:
not_none(new_generation_strategy._experiment)._name, experiment._name
)

def test_EncodeDecodeGenerationNodeBasedGenerationStrategy(self) -> None:
"""Test to ensure that GenerationNode based GenerationStrategies are
able to be encoded/decoded correctly.
"""
# we don't support callable models for GenNode based strategies
generation_strategy = get_generation_strategy(
with_generation_nodes=True, with_callable_model_kwarg=False
)
# Check that we can save a generation strategy without an experiment
# attached.
save_generation_strategy(generation_strategy=generation_strategy)
# Also try restoring this generation strategy by its ID in the DB.
new_generation_strategy = load_generation_strategy_by_id(
# pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`.
gs_id=generation_strategy._db_id
)
# Some fields of the reloaded GS are not expected to be set (both will be
# set during next model fitting call), so we unset them on the original GS as
# well.
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsNone(generation_strategy._experiment)

# Cannot load generation strategy before it has been saved
experiment = get_branin_experiment()
save_experiment(experiment)
with self.assertRaises(ObjectNotFoundError):
load_generation_strategy_by_experiment_name(experiment_name=experiment.name)

# Check that we can encode and decode the generation strategy *after*
# it has generated some trials and been updated with some data.
# Since we now need to `gen`, we remove the fake callable kwarg we added,
# since model does not expect it.
generation_strategy = get_generation_strategy(with_generation_nodes=True)
experiment.new_trial(generation_strategy.gen(experiment=experiment))
generation_strategy.gen(experiment, data=get_branin_data())
save_experiment(experiment)
# TODO @mgarrard passes up until this point
save_generation_strategy(generation_strategy=generation_strategy)
# Try restoring the generation strategy using the experiment its
# attached to.
new_generation_strategy = load_generation_strategy_by_experiment_name(
experiment_name=experiment.name
)
# Some fields of the reloaded GS are not expected to be set (both will be
# set during next model fitting call), so we unset them on the original GS as
# well.
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsInstance(new_generation_strategy._nodes[0].model_enum, Models)
self.assertEqual(len(new_generation_strategy._generator_runs), 2)
self.assertEqual(
not_none(new_generation_strategy._experiment)._name, experiment._name
)

def test_EncodeDecodeGenerationStrategyReducedState(self) -> None:
"""Try restoring the generation strategy using the experiment its attached to,
passing the experiment object.
Expand Down

0 comments on commit c397f1f

Please sign in to comment.