From fc7dca6f3fee051788f585824dd2a154265e7299 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Tue, 28 Nov 2023 20:34:55 -0800 Subject: [PATCH] Replace self._curr.model with self._curr.fitted_model (#2019) Summary: We want to use the GenerationNode equivalent of model which is fitted_model instead of the GenerationStep field model Reviewed By: lena-kashtelyan Differential Revision: D51431892 --- ax/modelbridge/generation_node.py | 9 +++++++++ ax/modelbridge/generation_strategy.py | 6 +++--- ax/storage/json_store/tests/test_json_store.py | 9 +++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 3344aa94535..9909e3f59aa 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -170,6 +170,15 @@ def fitted_model(self) -> ModelBridge: """fitted_model from self.model_spec_to_gen_from for convenience""" return self.model_spec_to_gen_from.fitted_model + @property + def _fitted_model(self) -> Optional[ModelBridge]: + """Private property to return optional fitted_model from + self.model_spec_to_gen_from for convenience. If no model is fit, + will return None. If using the non-private `fitted_model` property, + and no model is fit, a UserInput error will be raised. + """ + return self.model_spec_to_gen_from._fitted_model + @property def fixed_features(self) -> Optional[ObservationFeatures]: """fixed_features from self.model_spec_to_gen_from for convenience""" diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 0cb61849132..5079b674384 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -159,7 +159,7 @@ def model(self) -> Optional[ModelBridge]: """Current model in this strategy. Returns None if no model has been set yet (i.e., if no generator runs have been produced from this GS). """ - return self._curr.model_spec._fitted_model + return self._curr._fitted_model @property def experiment(self) -> Experiment: @@ -456,7 +456,7 @@ def _fit_current_model(self, data: Optional[Data]) -> None: logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}") self._curr.fit(experiment=self.experiment, data=data, **model_state_on_lgr) - self._model = self._curr.model_spec.fitted_model + self._model = self._curr._fitted_model def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bool: """Moves this generation strategy to next step if the current step is completed, @@ -507,7 +507,7 @@ def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]: # Potential solution: store generator runs on `GenerationNode`-s and # split them per-model there. model_state_on_lgr = {} - model_on_curr = self._curr.model + model_on_curr = self._curr.model_enum if ( lgr is not None and lgr._generation_step_index == self._curr.index diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index a90d708a26f..3160556dcc8 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -317,6 +317,11 @@ def test_EncodeDecode(self) -> None: original_object.evaluation_function = None converted_object.evaluation_function = None + if class_ == "BenchmarkMethod": + # Some fields of the reloaded GS are not expected to be set (both will + # be set during next model fitting call), so we unset them on the + # original GS as well. + original_object.generation_strategy._unset_non_persistent_state_fields() if isinstance(original_object, torch.nn.Module): self.assertIsInstance( converted_object, @@ -407,6 +412,10 @@ def test_DecodeGenerationStrategy(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) + # Some fields of the reloaded GS are not expected to be set (both will be + # set during next model fitting call), so we unset them on the original GS as + # well. + generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertGreater(len(new_generation_strategy._steps), 0) self.assertIsInstance(new_generation_strategy._steps[0].model, Models)