From f8d5377aa5470d2cdd317c334218fe292d922397 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Mon, 15 Apr 2024 13:27:24 -0700
Subject: [PATCH] Unify interface and data handling of model training and
 generalization metrics (#2367)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2367

This commit unifies the interface of `_predict_on_cross_validation_data` and
`_predict_on_training_data`, as well as how the two functions load the model's
training data. Previously, `_predict_on_training_data` reloaded the data from
the experiment, which could yield a different number of observations than
`_predict_on_cross_validation_data` if the model was not fit on all existing
data.

In addition, this commit introduces a `force_refit` option for
`get_fitted_model_bridge` that forces a reload of the data and a refit of the
model to the reloaded data, even if a fitted, but potentially outdated, model
is already on the scheduler.

Reviewed By: sunnyshen321

Differential Revision: D56105161

fbshipit-source-id: caa32d9cfa1bb57d4ad2206fe6017a92f4541477
---
 ax/modelbridge/cross_validation.py    | 40 +++++---
 ax/modelbridge/generation_strategy.py |  1 -
 .../tests/test_model_fit_metrics.py   | 95 +++++++++++--------
 ax/service/scheduler.py               |  9 +-
 4 files changed, 87 insertions(+), 58 deletions(-)
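Note (illustrative, not part of the patch): with these changes, computing model-fit metrics from a scheduler becomes a two-call pattern. The sketch below assumes a configured `Scheduler`, as in the updated test; the `summarize_model_fit` helper is hypothetical, while the two imported functions are the ones this patch touches.

```python
from typing import Dict

from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge
from ax.service.scheduler import get_fitted_model_bridge, Scheduler


def summarize_model_fit(scheduler: Scheduler) -> Dict[str, Dict[str, float]]:
    # force_refit=True reloads the experiment data and refits the model, so
    # the metrics reflect all completed trials rather than a stale model.
    model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
    return compute_model_fit_metrics_from_modelbridge(
        model_bridge=model_bridge,
        experiment=scheduler.experiment,
        generalization=True,  # cross-validation (test) error, not training error
        untransform=False,  # models train in transformed space; evaluate there
    )
```

The returned value is a nested dictionary mapping fit-metric names (e.g. `coefficient_of_determination`) to per-experimental-metric values, as described in the docstring changes below.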
diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py
index fbdae071b33..99e96106419 100644
--- a/ax/modelbridge/cross_validation.py
+++ b/ax/modelbridge/cross_validation.py
@@ -19,7 +19,7 @@
 import numpy as np
 
 from ax.core.experiment import Experiment
-from ax.core.observation import Observation, ObservationData, observations_from_data
+from ax.core.observation import Observation, ObservationData, recombine_observations
 from ax.core.optimization_config import OptimizationConfig
 from ax.modelbridge.base import ModelBridge, unwrap_observation_data
 from ax.utils.common.logger import get_logger
@@ -515,6 +515,7 @@
             before calculating the model fit metrics. False by default as models
             are trained in transformed space and model fit should be evaluated in
             transformed space.
+
     Returns:
         A nested dictionary mapping from the *model fit* metric names and the
         *experimental metric* names to the values of the model fit metrics.
@@ -528,12 +529,13 @@
             `coefficient of determination of the test error predictions`
         ```
     """
-    y_obs, y_pred, se_pred = (
-        _predict_on_cross_validation_data(
-            model_bridge=model_bridge, untransform=untransform
-        )
+    predict_func = (
+        _predict_on_cross_validation_data
         if generalization
-        else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
+        else _predict_on_training_data
+    )
+    y_obs, y_pred, se_pred = predict_func(
+        model_bridge=model_bridge, untransform=untransform
     )
     if fit_metrics_dict is None:
         fit_metrics_dict = {
@@ -552,7 +554,7 @@
 
 def _predict_on_training_data(
     model_bridge: ModelBridge,
-    experiment: Experiment,
+    untransform: bool = False,
 ) -> Tuple[
     Dict[str, np.ndarray],
     Dict[str, np.ndarray],
@@ -566,16 +568,17 @@
 
     Args:
         model_bridge: A ModelBridge object with which to make predictions.
-        experiment: The experiment with whose data to compute the model fit metrics.
+        untransform: Boolean indicating whether to untransform model predictions.
 
     Returns:
         A tuple containing three dictionaries for 1) observed metric values, and the
         model's associated 2) predictive means and 3) predictive standard deviations.
     """
-    data = experiment.lookup_data()
-    observations = observations_from_data(
-        experiment=experiment, data=data
-    )  # List[Observation]
+    observations = model_bridge.get_training_data()  # List[Observation]
+
+    # NOTE: the code from here to the end of the untransform block below could be
+    # replaced with model_bridge's public predict / private _batch_predict method,
+    # if the latter had a boolean untransform flag.
 
     # Transform observations -- this will transform both obs data and features
     for t in model_bridge.transforms.values():
@@ -586,12 +589,23 @@
     # Make predictions in transformed space
     observation_data_pred = model_bridge._predict(observation_features)
 
+    if untransform:
+        # Apply reverse transforms, in reverse order
+        pred_observations = recombine_observations(
+            observation_features=observation_features,
+            observation_data=observation_data_pred,
+        )
+        for t in reversed(list(model_bridge.transforms.values())):
+            pred_observations = t.untransform_observations(pred_observations)
+
+        observation_data_pred = [obs.data for obs in pred_observations]
+
     mean_predicted, cov_predicted = unwrap_observation_data(observation_data_pred)
     mean_observed = [
         obs.data.means_dict for obs in observations
     ]  # List[Dict[str, float]]
 
-    metric_names = list(data.metric_names)
+    metric_names = observations[0].data.metric_names
     mean_observed = _list_of_dicts_to_dict_of_lists(
         list_of_dicts=mean_observed, keys=metric_names
    )
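Note (illustrative, not part of the patch): after this hunk, the two prediction helpers share the signature `(model_bridge, untransform)` and both draw observations from the model's own training data, so they are interchangeable behind a single call site, exactly as `compute_model_fit_metrics_from_modelbridge` now uses them. A minimal sketch of that symmetry; `check_prediction_alignment` is a hypothetical helper, and both imported functions are private, so this is not a stable API.

```python
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
    _predict_on_cross_validation_data,
    _predict_on_training_data,
)


def check_prediction_alignment(
    model_bridge: ModelBridge, generalization: bool, untransform: bool
) -> None:
    # Same dispatch as compute_model_fit_metrics_from_modelbridge above.
    predict_func = (
        _predict_on_cross_validation_data
        if generalization
        else _predict_on_training_data
    )
    y_obs, y_pred, se_pred = predict_func(
        model_bridge=model_bridge, untransform=untransform
    )
    # Observations now come from model_bridge.get_training_data() in the
    # training branch, so per-metric counts line up across the dictionaries.
    for name in y_obs:
        assert len(y_obs[name]) == len(y_pred[name]) == len(se_pred[name])
```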
diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py
index 9d292a5a575..bfe163ac635 100644
--- a/ax/modelbridge/generation_strategy.py
+++ b/ax/modelbridge/generation_strategy.py
@@ -717,7 +717,6 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
         # model state from last generator run and pass it to the model
         # being instantiated in this function.
         model_state_on_lgr = self._get_model_state_from_last_generator_run()
-
         if not data.df.empty:
             trial_indices_in_data = sorted(data.df["trial_index"].unique())
             logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")
diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py
index 96af249d946..58f55cc9c88 100644
--- a/ax/modelbridge/tests/test_model_fit_metrics.py
+++ b/ax/modelbridge/tests/test_model_fit_metrics.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import warnings
+from itertools import product
 from typing import cast, Dict
 
 import numpy as np
@@ -17,6 +18,7 @@
 from ax.metrics.branin import BraninMetric
 from ax.modelbridge.cross_validation import (
     _predict_on_cross_validation_data,
+    _predict_on_training_data,
     compute_model_fit_metrics_from_modelbridge,
 )
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
@@ -67,7 +69,12 @@ def test_model_fit_metrics(self) -> None:
         )
         # need to run some trials to initialize the ModelBridge
         scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)
+
         model_bridge = get_fitted_model_bridge(scheduler)
+        self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL)
+
+        model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
+        self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL + 1)
 
         # testing compute_model_fit_metrics_from_modelbridge with default metrics
         fit_metrics = compute_model_fit_metrics_from_modelbridge(
@@ -90,45 +97,51 @@
         self.assertIsInstance(std_branin, float)
 
         # checking non-default model-fit-metric
-        untransform = False
-        fit_metrics = compute_model_fit_metrics_from_modelbridge(
-            model_bridge=model_bridge,
-            experiment=scheduler.experiment,
-            generalization=True,
-            untransform=untransform,
-            fit_metrics_dict={"Entropy": entropy_of_observations},
-        )
-        entropy = fit_metrics.get("Entropy")
-        self.assertIsInstance(entropy, dict)
-        entropy = cast(Dict[str, float], entropy)
-        self.assertTrue("branin" in entropy)
-        entropy_branin = entropy["branin"]
-        self.assertIsInstance(entropy_branin, float)
-
-        y_obs, _, _ = _predict_on_cross_validation_data(
-            model_bridge=model_bridge, untransform=untransform
-        )
-        y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
-        entropy_truth = _entropy_via_kde(y_obs_branin)
-        self.assertAlmostEqual(entropy_branin, entropy_truth)
+        for untransform, generalization in product([True, False], [True, False]):
+            with self.subTest(untransform=untransform, generalization=generalization):
+                fit_metrics = compute_model_fit_metrics_from_modelbridge(
+                    model_bridge=model_bridge,
+                    experiment=scheduler.experiment,
+                    generalization=generalization,
+                    untransform=untransform,
+                    fit_metrics_dict={"Entropy": entropy_of_observations},
+                )
+                entropy = fit_metrics.get("Entropy")
+                self.assertIsInstance(entropy, dict)
+                entropy = cast(Dict[str, float], entropy)
+                self.assertTrue("branin" in entropy)
+                entropy_branin = entropy["branin"]
+                self.assertIsInstance(entropy_branin, float)
 
-        # testing with empty metrics
-        empty_metrics = compute_model_fit_metrics_from_modelbridge(
-            model_bridge=model_bridge,
-            experiment=self.branin_experiment,
-            fit_metrics_dict={},
-        )
-        self.assertIsInstance(empty_metrics, dict)
-        self.assertTrue(len(empty_metrics) == 0)
-
-        # testing log filtering
-        with warnings.catch_warnings(record=True) as ws:
-            fit_metrics = compute_model_fit_metrics_from_modelbridge(
-                model_bridge=model_bridge,
-                experiment=self.branin_experiment,
-                untransform=False,
-                generalization=True,
-            )
-        self.assertFalse(
-            any("Input data is not standardized" in str(w.message) for w in ws)
-        )
+                predict = (
+                    _predict_on_cross_validation_data
+                    if generalization
+                    else _predict_on_training_data
+                )
+                y_obs, _, _ = predict(
+                    model_bridge=model_bridge, untransform=untransform
+                )
+                y_obs_branin = np.array(y_obs["branin"])[:, np.newaxis]
+                entropy_truth = _entropy_via_kde(y_obs_branin)
+                self.assertAlmostEqual(entropy_branin, entropy_truth)
+
+                # testing with empty metrics
+                empty_metrics = compute_model_fit_metrics_from_modelbridge(
+                    model_bridge=model_bridge,
+                    experiment=self.branin_experiment,
+                    fit_metrics_dict={},
+                )
+                self.assertIsInstance(empty_metrics, dict)
+                self.assertTrue(len(empty_metrics) == 0)
+
+                # testing log filtering
+                with warnings.catch_warnings(record=True) as ws:
+                    fit_metrics = compute_model_fit_metrics_from_modelbridge(
+                        model_bridge=model_bridge,
+                        experiment=self.branin_experiment,
+                        untransform=untransform,
+                        generalization=generalization,
+                    )
+                self.assertFalse(
+                    any("Input data is not standardized" in str(w.message) for w in ws)
+                )
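Note (illustrative, not part of the patch): the test above plugs a custom metric (`entropy_of_observations`) into `fit_metrics_dict`. A sketch of supplying one's own metric, assuming the fit-metric callables accept arrays of observed values, predicted means, and predicted standard errors and return a float (the signature the test's metric appears to follow); `median_absolute_error` and `custom_fit_metrics` are hypothetical names, not Ax functions.

```python
from typing import Dict

import numpy as np
from ax.core.experiment import Experiment
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge


def median_absolute_error(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # se_pred is accepted but unused, mirroring metrics that consume only a
    # subset of the inputs (the test's entropy check depends only on y_obs).
    return float(np.median(np.abs(np.asarray(y_pred) - np.asarray(y_obs))))


def custom_fit_metrics(
    model_bridge: ModelBridge, experiment: Experiment
) -> Dict[str, Dict[str, float]]:
    return compute_model_fit_metrics_from_modelbridge(
        model_bridge=model_bridge,
        experiment=experiment,
        fit_metrics_dict={"MedAE": median_absolute_error},
    )
```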
diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py
index 3fcfb8639b2..aecd3c7e200 100644
--- a/ax/service/scheduler.py
+++ b/ax/service/scheduler.py
@@ -2112,19 +2112,22 @@ def _get_failure_rate_exceeded_error(
     )
 
 
-def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
+def get_fitted_model_bridge(
+    scheduler: Scheduler, force_refit: bool = False
+) -> ModelBridge:
     """Returns a fitted ModelBridge object. If the model is fit already, directly
     returns the already fitted model. Otherwise, fits and returns a new one.
 
     Args:
         scheduler: The scheduler object from which to get the fitted model.
+        force_refit: If True, will force a data lookup and a refit of the model.
 
     Returns:
         A ModelBridge object fitted to the observations of the scheduler's experiment.
     """
     gs = scheduler.standard_generation_strategy
     model_bridge = gs.model  # Optional[ModelBridge]
-    if model_bridge is None:  # Need to re-fit the model.
-        gs._fit_current_model(data=None)  # Will lookup_data if it none is provided.
+    if model_bridge is None or force_refit:  # Need to re-fit the model.
+        gs._fit_current_model(data=None)  # Will lookup_data if none is provided.
     model_bridge = cast(ModelBridge, gs.model)
     return model_bridge
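Note (illustrative, not part of the patch): the new `force_refit` flag matters whenever the cached model predates the most recent data, as in the updated test where the cached model has seen `NUM_SOBOL` trials while `NUM_SOBOL + 1` have completed. A sketch, assuming a configured `Scheduler`; `training_data_counts` is a hypothetical helper.

```python
from typing import Tuple

from ax.service.scheduler import get_fitted_model_bridge, Scheduler


def training_data_counts(scheduler: Scheduler) -> Tuple[int, int]:
    # The cached model is returned as-is and may lag behind the experiment...
    cached = get_fitted_model_bridge(scheduler)
    # ...whereas force_refit reloads the data and refits before returning.
    refit = get_fitted_model_bridge(scheduler, force_refit=True)
    return len(cached.get_training_data()), len(refit.get_training_data())
```

Per the test above, the second count can exceed the first whenever trials completed after the model was last fit.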