From acf6940ac282a41d9e9d36f9bcce015c3be4e4c6 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Tue, 28 Nov 2023 20:08:06 -0800 Subject: [PATCH] use generation_node_name on the GeneratorRun (#2003) Summary: Replace the use of generation_step_index with generation_node_name because GenerationStrategy should no longer access index Differential Revision: D51432075 --- ax/modelbridge/generation_strategy.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 5079b674384..5ca30e83d87 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -508,11 +508,15 @@ def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]: # split them per-model there. model_state_on_lgr = {} model_on_curr = self._curr.model_enum - if ( - lgr is not None - and lgr._generation_step_index == self._curr.index - and lgr._model_state_after_gen - ): + if lgr is None: + return model_state_on_lgr + + if all(isinstance(s, GenerationStep) for s in self._steps): + grs_equal = lgr._generation_step_index == self._curr.index + else: + grs_equal = lgr._generation_node_name == self._curr.node_name + + if grs_equal and lgr._model_state_after_gen: if self.model or isinstance(model_on_curr, ModelRegistryBase): # TODO[drfreund]: Consider moving this to `GenerationStep` or # `GenerationNode`.