Skip to content

Commit

Permalink
Replace self._curr.model with self._curr.fitted_model (facebook#2019)
Browse files Browse the repository at this point in the history
Summary:

We want to use `fitted_model`, the GenerationNode equivalent of `model`, instead of the GenerationStep field `model`.

Reviewed By: lena-kashtelyan

Differential Revision: D51431892
  • Loading branch information
Mia Garrard authored and facebook-github-bot committed Nov 29, 2023
1 parent 14a5615 commit fc7dca6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
9 changes: 9 additions & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ def fitted_model(self) -> ModelBridge:
"""fitted_model from self.model_spec_to_gen_from for convenience"""
return self.model_spec_to_gen_from.fitted_model

@property
def _fitted_model(self) -> Optional[ModelBridge]:
    """Optional access to the fitted model, for convenience.

    Unlike the public ``fitted_model`` property — which raises a
    ``UserInputError`` when no model has been fit yet — this private
    variant simply returns ``None`` in that case.
    """
    spec = self.model_spec_to_gen_from
    return spec._fitted_model

@property
def fixed_features(self) -> Optional[ObservationFeatures]:
"""fixed_features from self.model_spec_to_gen_from for convenience"""
Expand Down
6 changes: 3 additions & 3 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def model(self) -> Optional[ModelBridge]:
"""Current model in this strategy. Returns None if no model has been set
yet (i.e., if no generator runs have been produced from this GS).
"""
return self._curr.model_spec._fitted_model
return self._curr._fitted_model

@property
def experiment(self) -> Experiment:
Expand Down Expand Up @@ -456,7 +456,7 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")

self._curr.fit(experiment=self.experiment, data=data, **model_state_on_lgr)
self._model = self._curr.model_spec.fitted_model
self._model = self._curr._fitted_model

def _maybe_move_to_next_step(self, raise_data_required_error: bool = True) -> bool:
"""Moves this generation strategy to next step if the current step is completed,
Expand Down Expand Up @@ -507,7 +507,7 @@ def _get_model_state_from_last_generator_run(self) -> Dict[str, Any]:
# Potential solution: store generator runs on `GenerationNode`-s and
# split them per-model there.
model_state_on_lgr = {}
model_on_curr = self._curr.model
model_on_curr = self._curr.model_enum
if (
lgr is not None
and lgr._generation_step_index == self._curr.index
Expand Down
9 changes: 9 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def test_EncodeDecode(self) -> None:

original_object.evaluation_function = None
converted_object.evaluation_function = None
if class_ == "BenchmarkMethod":
# Some fields of the reloaded GS are not expected to be set (both will
# be set during next model fitting call), so we unset them on the
# original GS as well.
original_object.generation_strategy._unset_non_persistent_state_fields()
if isinstance(original_object, torch.nn.Module):
self.assertIsInstance(
converted_object,
Expand Down Expand Up @@ -407,6 +412,10 @@ def test_DecodeGenerationStrategy(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
# Some fields of the reloaded GS are not expected to be set (both will be
# set during next model fitting call), so we unset them on the original GS as
# well.
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertGreater(len(new_generation_strategy._steps), 0)
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down

0 comments on commit fc7dca6

Please sign in to comment.