Skip to content

Commit

Permalink
Unify interface and data handling of model training and generalizatio…
Browse files Browse the repository at this point in the history
…n metrics (#2367)

Summary:
Pull Request resolved: #2367

This commit unifies the interface of and the loading of the model's training data in `_predict_on_cross_validation_data` and `_predict_on_training_data`.  Previously, `_predict_on_training_data` would reload the data from the experiment, which could lead to differences in the number of observations to `_predict_on_cross_validation_data` if the model was not fit on all existing data.

In addition, this commit introduces a `force_refit` option for `get_fitted_model_bridge` which forces a reloading of the data and a refitting of the model to the reloaded data, even if a fitted, potentially out-dated model is on the scheduler.

Reviewed By: sunnyshen321

Differential Revision: D56105161

fbshipit-source-id: caa32d9cfa1bb57d4ad2206fe6017a92f4541477
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Apr 15, 2024
1 parent d2cd55f commit f8d5377
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 58 deletions.
40 changes: 27 additions & 13 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, observations_from_data
from ax.core.observation import Observation, ObservationData, recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge, unwrap_observation_data
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -515,6 +515,7 @@ def compute_model_fit_metrics_from_modelbridge(
before calcualting the model fit metrics. False by default as models
are trained in transformed space and model fit should be
evaluated in transformed space.
Returns:
A nested dictionary mapping from the *model fit* metric names and the
*experimental metric* names to the values of the model fit metrics.
Expand All @@ -528,12 +529,13 @@ def compute_model_fit_metrics_from_modelbridge(
`coefficient of determination of the test error predictions`
```
"""
y_obs, y_pred, se_pred = (
_predict_on_cross_validation_data(
model_bridge=model_bridge, untransform=untransform
)
predict_func = (
_predict_on_cross_validation_data
if generalization
else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
else _predict_on_training_data
)
y_obs, y_pred, se_pred = predict_func(
model_bridge=model_bridge, untransform=untransform
)
if fit_metrics_dict is None:
fit_metrics_dict = {
Expand All @@ -552,7 +554,7 @@ def compute_model_fit_metrics_from_modelbridge(

def _predict_on_training_data(
model_bridge: ModelBridge,
experiment: Experiment,
untransform: bool = False,
) -> Tuple[
Dict[str, np.ndarray],
Dict[str, np.ndarray],
Expand All @@ -566,16 +568,17 @@ def _predict_on_training_data(
Args:
model_bridge: A ModelBridge object with which to make predictions.
experiment: The experiment with whose data to compute the model fit metrics.
untransform: Boolean indicating whether to untransform model predictions.
Returns:
A tuple containing three dictionaries for 1) observed metric values, and the
model's associated 2) predictive means and 3) predictive standard deviations.
"""
data = experiment.lookup_data()
observations = observations_from_data(
experiment=experiment, data=data
) # List[Observation]
observations = model_bridge.get_training_data() # List[Observation]

# NOTE: the following up to the end of the untransform block could be replaced
# with model_bridge's public predict / private _batch_predict method, if the
# latter had a boolean untransform flag.

# Transform observations -- this will transform both obs data and features
for t in model_bridge.transforms.values():
Expand All @@ -586,12 +589,23 @@ def _predict_on_training_data(
# Make predictions in transformed space
observation_data_pred = model_bridge._predict(observation_features)

if untransform:
# Apply reverse transforms, in reverse order
pred_observations = recombine_observations(
observation_features=observation_features,
observation_data=observation_data_pred,
)
for t in reversed(list(model_bridge.transforms.values())):
pred_observations = t.untransform_observations(pred_observations)

observation_data_pred = [obs.data for obs in pred_observations]

mean_predicted, cov_predicted = unwrap_observation_data(observation_data_pred)
mean_observed = [
obs.data.means_dict for obs in observations
] # List[Dict[str, float]]

metric_names = list(data.metric_names)
metric_names = observations[0].data.metric_names
mean_observed = _list_of_dicts_to_dict_of_lists(
list_of_dicts=mean_observed, keys=metric_names
)
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,6 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
# model state from last generator run and pass it to the model
# being instantiated in this function.
model_state_on_lgr = self._get_model_state_from_last_generator_run()

if not data.df.empty:
trial_indices_in_data = sorted(data.df["trial_index"].unique())
logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")
Expand Down
95 changes: 54 additions & 41 deletions ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import warnings
from itertools import product
from typing import cast, Dict

import numpy as np
Expand All @@ -17,6 +18,7 @@
from ax.metrics.branin import BraninMetric
from ax.modelbridge.cross_validation import (
_predict_on_cross_validation_data,
_predict_on_training_data,
compute_model_fit_metrics_from_modelbridge,
)
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
Expand Down Expand Up @@ -67,7 +69,12 @@ def test_model_fit_metrics(self) -> None:
)
# need to run some trials to initialize the ModelBridge
scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)

model_bridge = get_fitted_model_bridge(scheduler)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL)

model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL + 1)

# testing compute_model_fit_metrics_from_modelbridge with default metrics
fit_metrics = compute_model_fit_metrics_from_modelbridge(
Expand All @@ -90,45 +97,51 @@ def test_model_fit_metrics(self) -> None:
self.assertIsInstance(std_branin, float)

# checking non-default model-fit-metric
untransform = False
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=True,
untransform=untransform,
fit_metrics_dict={"Entropy": entropy_of_observations},
)
entropy = fit_metrics.get("Entropy")
self.assertIsInstance(entropy, dict)
entropy = cast(Dict[str, float], entropy)
self.assertTrue("branin" in entropy)
entropy_branin = entropy["branin"]
self.assertIsInstance(entropy_branin, float)

y_obs, _, _ = _predict_on_cross_validation_data(
model_bridge=model_bridge, untransform=untransform
)
y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
entropy_truth = _entropy_via_kde(y_obs_branin)
self.assertAlmostEqual(entropy_branin, entropy_truth)
for untransform, generalization in product([True, False], [True, False]):
with self.subTest(untransform=untransform):
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=scheduler.experiment,
generalization=generalization,
untransform=untransform,
fit_metrics_dict={"Entropy": entropy_of_observations},
)
entropy = fit_metrics.get("Entropy")
self.assertIsInstance(entropy, dict)
entropy = cast(Dict[str, float], entropy)
self.assertTrue("branin" in entropy)
entropy_branin = entropy["branin"]
self.assertIsInstance(entropy_branin, float)

# testing with empty metrics
empty_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
fit_metrics_dict={},
)
self.assertIsInstance(empty_metrics, dict)
self.assertTrue(len(empty_metrics) == 0)

# testing log filtering
with warnings.catch_warnings(record=True) as ws:
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
untransform=False,
generalization=True,
)
self.assertFalse(
any("Input data is not standardized" in str(w.message) for w in ws)
)
predict = (
_predict_on_cross_validation_data
if generalization
else _predict_on_training_data
)
y_obs, _, _ = predict(
model_bridge=model_bridge, untransform=untransform
)
y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
entropy_truth = _entropy_via_kde(y_obs_branin)
self.assertAlmostEqual(entropy_branin, entropy_truth)

# testing with empty metrics
empty_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
fit_metrics_dict={},
)
self.assertIsInstance(empty_metrics, dict)
self.assertTrue(len(empty_metrics) == 0)

# testing log filtering
with warnings.catch_warnings(record=True) as ws:
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
untransform=untransform,
generalization=generalization,
)
self.assertFalse(
any("Input data is not standardized" in str(w.message) for w in ws)
)
9 changes: 6 additions & 3 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,19 +2112,22 @@ def _get_failure_rate_exceeded_error(
)


def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
def get_fitted_model_bridge(
scheduler: Scheduler, force_refit: bool = False
) -> ModelBridge:
"""Returns a fitted ModelBridge object. If the model is fit already, directly
returns the already fitted model. Otherwise, fits and returns a new one.
Args:
scheduler: The scheduler object from which to get the fitted model.
force_refit: If True, will force a data lookup and a refit of the model.
Returns:
A ModelBridge object fitted to the observations of the scheduler's experiment.
"""
gs = scheduler.standard_generation_strategy
model_bridge = gs.model # Optional[ModelBridge]
if model_bridge is None: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if it none is provided.
if model_bridge is None or force_refit: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if none is provided.
model_bridge = cast(ModelBridge, gs.model)
return model_bridge

0 comments on commit f8d5377

Please sign in to comment.