Unify interface and data handling of model training and generalization metrics (#2367)

Summary:

This commit unifies the interface of `_predict_on_cross_validation_data` and `_predict_on_training_data`, as well as how they load the model's training data. Previously, `_predict_on_training_data` would reload the data from the experiment, which could yield a different number of observations than `_predict_on_cross_validation_data` uses if the model was not fit on all existing data.

In addition, this commit introduces a `force_refit` option for `get_fitted_model_bridge`, which forces a reload of the data and a refit of the model to the reloaded data, even if a fitted, potentially outdated model is already on the scheduler.
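
A rough usage sketch of the two entry points this commit touches; it assumes an already-configured `scheduler`, and keyword arguments other than `force_refit` and `generalization` (notably `experiment`) are assumptions rather than something this diff confirms:

```python
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.service.scheduler import get_fitted_model_bridge

# Force a fresh data lookup and refit, even if a (possibly stale) fitted
# model is already attached to the scheduler's generation strategy.
model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)

# In-sample (training-data) fit metrics and cross-validation
# ("generalization") metrics now share one interface and one data source.
train_fit = compute_model_fit_metrics_from_modelbridge(
    model_bridge=model_bridge,
    experiment=scheduler.experiment,
    generalization=False,
)
cv_fit = compute_model_fit_metrics_from_modelbridge(
    model_bridge=model_bridge,
    experiment=scheduler.experiment,
    generalization=True,
)
```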

Differential Revision: D56105161
SebastianAment authored and facebook-github-bot committed Apr 14, 2024
1 parent 64c2733 commit 339afac
Showing 4 changed files with 37 additions and 19 deletions.
41 changes: 26 additions & 15 deletions ax/modelbridge/cross_validation.py
@@ -19,7 +19,7 @@

import numpy as np
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, observations_from_data
from ax.core.observation import Observation, ObservationData, recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge, unwrap_observation_data
from ax.utils.common.logger import get_logger
@@ -515,6 +515,7 @@ def compute_model_fit_metrics_from_modelbridge(
            before calculating the model fit metrics. False by default as models
are trained in transformed space and model fit should be
evaluated in transformed space.
Returns:
A nested dictionary mapping from the *model fit* metric names and the
*experimental metric* names to the values of the model fit metrics.
@@ -528,13 +529,12 @@ def compute_model_fit_metrics_from_modelbridge(
`coefficient of determination of the test error predictions`
```
"""
y_obs, y_pred, se_pred = (
_predict_on_cross_validation_data(
model_bridge=model_bridge, untransform=untransform
)
predict = (
_predict_on_cross_validation_data
if generalization
else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
else _predict_on_training_data
)
y_obs, y_pred, se_pred = predict(model_bridge=model_bridge, untransform=untransform)
if fit_metrics_dict is None:
fit_metrics_dict = {
"coefficient_of_determination": coefficient_of_determination,
@@ -552,7 +552,7 @@ def compute_model_fit_metrics_from_modelbridge(

def _predict_on_training_data(
model_bridge: ModelBridge,
experiment: Experiment,
untransform: bool = False,
) -> Tuple[
Dict[str, np.ndarray],
Dict[str, np.ndarray],
@@ -566,16 +566,17 @@ def _predict_on_training_data(
Args:
model_bridge: A ModelBridge object with which to make predictions.
experiment: The experiment with whose data to compute the model fit metrics.
untransform: Boolean indicating whether to untransform model predictions.
Returns:
A tuple containing three dictionaries for 1) observed metric values, and the
model's associated 2) predictive means and 3) predictive standard deviations.
"""
data = experiment.lookup_data()
observations = observations_from_data(
experiment=experiment, data=data
) # List[Observation]
observations = model_bridge.get_training_data() # List[Observation]

# NOTE: the following up to the end of the untransform block could be replaced
# with model_bridge's public predict / private _batch_predict method, if we are
# willing to introduce the boolean untransform flag.

# Transform observations -- this will transform both obs data and features
for t in model_bridge.transforms.values():
@@ -584,14 +585,24 @@ def _predict_on_training_data(
observation_features = [obs.features for obs in observations]

# Make predictions in transformed space
observation_data_pred = model_bridge._predict(observation_features)
observation_data = model_bridge._predict(observation_features)

if untransform:
# Apply reverse transforms, in reverse order
pred_observations = recombine_observations(
observation_features=observation_features, observation_data=observation_data
)
for t in reversed(list(model_bridge.transforms.values())):
pred_observations = t.untransform_observations(pred_observations)

observation_data = [obs.data for obs in pred_observations]

mean_predicted, cov_predicted = unwrap_observation_data(observation_data_pred)
mean_predicted, cov_predicted = unwrap_observation_data(observation_data)
mean_observed = [
obs.data.means_dict for obs in observations
] # List[Dict[str, float]]

metric_names = list(data.metric_names)
metric_names = observations[0].data.metric_names
mean_observed = _list_of_dicts_to_dict_of_lists(
list_of_dicts=mean_observed, keys=metric_names
)
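
To make the order of operations in the new `untransform` branch concrete, here is a minimal, self-contained sketch of the transform → predict → untransform round trip; the `Shift` class is a toy stand-in for illustration, not Ax's `Transform` API, and only mirrors the ordering used in the diff:

```python
from typing import List


class Shift:
    """Toy stand-in for a transform: adds an offset in 'transformed space'."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def transform(self, ys: List[float]) -> List[float]:
        return [y + self.offset for y in ys]

    def untransform(self, ys: List[float]) -> List[float]:
        return [y - self.offset for y in ys]


transforms = [Shift(1.0), Shift(10.0)]  # applied in insertion order
ys = [0.0, 2.0]

for t in transforms:  # forward pass: transform the training observations
    ys = t.transform(ys)

preds = list(ys)  # "predict" in transformed space (identity model in this sketch)

for t in reversed(transforms):  # untransform predictions in reverse order
    preds = t.untransform(preds)

assert preds == [0.0, 2.0]  # back in the original (untransformed) space
```
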
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
@@ -717,7 +717,6 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
# model state from last generator run and pass it to the model
# being instantiated in this function.
model_state_on_lgr = self._get_model_state_from_last_generator_run()

if not data.df.empty:
trial_indices_in_data = sorted(data.df["trial_index"].unique())
logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")
5 changes: 5 additions & 0 deletions ax/modelbridge/tests/test_model_fit_metrics.py
@@ -67,7 +67,12 @@ def test_model_fit_metrics(self) -> None:
)
# need to run some trials to initialize the ModelBridge
scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)

model_bridge = get_fitted_model_bridge(scheduler)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL)

model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL + 1)

# testing compute_model_fit_metrics_from_modelbridge with default metrics
fit_metrics = compute_model_fit_metrics_from_modelbridge(
9 changes: 6 additions & 3 deletions ax/service/scheduler.py
@@ -2112,19 +2112,22 @@ def _get_failure_rate_exceeded_error(
)


def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
def get_fitted_model_bridge(
scheduler: Scheduler, force_refit: bool = False
) -> ModelBridge:
"""Returns a fitted ModelBridge object. If the model is fit already, directly
returns the already fitted model. Otherwise, fits and returns a new one.
Args:
scheduler: The scheduler object from which to get the fitted model.
force_refit: If True, will force a data lookup and a refit of the model.
Returns:
A ModelBridge object fitted to the observations of the scheduler's experiment.
"""
gs = scheduler.standard_generation_strategy
model_bridge = gs.model # Optional[ModelBridge]
if model_bridge is None: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if it none is provided.
if model_bridge is None or force_refit: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if none is provided.
model_bridge = cast(ModelBridge, gs.model)
return model_bridge
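
Continuing the same hypothetical `scheduler`, the nested dictionary described in the `compute_model_fit_metrics_from_modelbridge` docstring could be read roughly as follows; the experimental metric name `"branin"` is made up for illustration, and the exact keyword arguments are assumptions:

```python
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.service.scheduler import get_fitted_model_bridge

model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
fit_metrics = compute_model_fit_metrics_from_modelbridge(
    model_bridge=model_bridge,
    experiment=scheduler.experiment,
)
# Outer key: model-fit metric name; inner key: experimental metric name.
r2 = fit_metrics["coefficient_of_determination"]["branin"]
print(f"Coefficient of determination for 'branin': {r2:.3f}")
```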
