Skip to content

Commit

Permalink
Merge 339afac into 4579469
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianAment authored Apr 14, 2024
2 parents 4579469 + 339afac commit 3ce51b0
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 21 deletions.
41 changes: 26 additions & 15 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, observations_from_data
from ax.core.observation import Observation, ObservationData, recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge, unwrap_observation_data
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -515,6 +515,7 @@ def compute_model_fit_metrics_from_modelbridge(
before calcualting the model fit metrics. False by default as models
are trained in transformed space and model fit should be
evaluated in transformed space.
Returns:
A nested dictionary mapping from the *model fit* metric names and the
*experimental metric* names to the values of the model fit metrics.
Expand All @@ -528,13 +529,12 @@ def compute_model_fit_metrics_from_modelbridge(
`coefficient of determination of the test error predictions`
```
"""
y_obs, y_pred, se_pred = (
_predict_on_cross_validation_data(
model_bridge=model_bridge, untransform=untransform
)
predict = (
_predict_on_cross_validation_data
if generalization
else _predict_on_training_data(model_bridge=model_bridge, experiment=experiment)
else _predict_on_training_data
)
y_obs, y_pred, se_pred = predict(model_bridge=model_bridge, untransform=untransform)
if fit_metrics_dict is None:
fit_metrics_dict = {
"coefficient_of_determination": coefficient_of_determination,
Expand All @@ -552,7 +552,7 @@ def compute_model_fit_metrics_from_modelbridge(

def _predict_on_training_data(
model_bridge: ModelBridge,
experiment: Experiment,
untransform: bool = False,
) -> Tuple[
Dict[str, np.ndarray],
Dict[str, np.ndarray],
Expand All @@ -566,16 +566,17 @@ def _predict_on_training_data(
Args:
model_bridge: A ModelBridge object with which to make predictions.
experiment: The experiment with whose data to compute the model fit metrics.
untransform: Boolean indicating whether to untransform model predictions.
Returns:
A tuple containing three dictionaries for 1) observed metric values, and the
model's associated 2) predictive means and 3) predictive standard deviations.
"""
data = experiment.lookup_data()
observations = observations_from_data(
experiment=experiment, data=data
) # List[Observation]
observations = model_bridge.get_training_data() # List[Observation]

# NOTE: the following up to the end of the untransform block could be replaced
# with model_bridge's public predict / private _batch_predict method, if we are
# willing to introduce the boolean untransform flag.

# Transform observations -- this will transform both obs data and features
for t in model_bridge.transforms.values():
Expand All @@ -584,14 +585,24 @@ def _predict_on_training_data(
observation_features = [obs.features for obs in observations]

# Make predictions in transformed space
observation_data_pred = model_bridge._predict(observation_features)
observation_data = model_bridge._predict(observation_features)

if untransform:
# Apply reverse transforms, in reverse order
pred_observations = recombine_observations(
observation_features=observation_features, observation_data=observation_data
)
for t in reversed(list(model_bridge.transforms.values())):
pred_observations = t.untransform_observations(pred_observations)

observation_data = [obs.data for obs in pred_observations]

mean_predicted, cov_predicted = unwrap_observation_data(observation_data_pred)
mean_predicted, cov_predicted = unwrap_observation_data(observation_data)
mean_observed = [
obs.data.means_dict for obs in observations
] # List[Dict[str, float]]

metric_names = list(data.metric_names)
metric_names = observations[0].data.metric_names
mean_observed = _list_of_dicts_to_dict_of_lists(
list_of_dicts=mean_observed, keys=metric_names
)
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,6 @@ def _fit_current_model(self, data: Optional[Data]) -> None:
# model state from last generator run and pass it to the model
# being instantiated in this function.
model_state_on_lgr = self._get_model_state_from_last_generator_run()

if not data.df.empty:
trial_indices_in_data = sorted(data.df["trial_index"].unique())
logger.debug(f"Fitting model with data for trials: {trial_indices_in_data}")
Expand Down
5 changes: 5 additions & 0 deletions ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ def test_model_fit_metrics(self) -> None:
)
# need to run some trials to initialize the ModelBridge
scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)

model_bridge = get_fitted_model_bridge(scheduler)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL)

model_bridge = get_fitted_model_bridge(scheduler, force_refit=True)
self.assertEqual(len(model_bridge.get_training_data()), NUM_SOBOL + 1)

# testing compute_model_fit_metrics_from_modelbridge with default metrics
fit_metrics = compute_model_fit_metrics_from_modelbridge(
Expand Down
9 changes: 6 additions & 3 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,19 +2112,22 @@ def _get_failure_rate_exceeded_error(
)


def get_fitted_model_bridge(scheduler: Scheduler) -> ModelBridge:
def get_fitted_model_bridge(
scheduler: Scheduler, force_refit: bool = False
) -> ModelBridge:
"""Returns a fitted ModelBridge object. If the model is fit already, directly
returns the already fitted model. Otherwise, fits and returns a new one.
Args:
scheduler: The scheduler object from which to get the fitted model.
force_refit: If True, will force a data lookup and a refit of the model.
Returns:
A ModelBridge object fitted to the observations of the scheduler's experiment.
"""
gs = scheduler.standard_generation_strategy
model_bridge = gs.model # Optional[ModelBridge]
if model_bridge is None: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if it none is provided.
if model_bridge is None or force_refit: # Need to re-fit the model.
gs._fit_current_model(data=None) # Will lookup_data if none is provided.
model_bridge = cast(ModelBridge, gs.model)
return model_bridge
15 changes: 13 additions & 2 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@

# pyre-strict

from logging import Logger
from typing import Dict, Mapping, Optional, Protocol

import numpy as np

from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity


logger: Logger = get_logger(__name__)

"""
################################ Model Fit Metrics ###############################
"""
Expand Down Expand Up @@ -140,8 +146,8 @@ def entropy_of_observations(
Args:
y_obs: An array of observations for a single metric.
y_pred: An array of the predicted values corresponding to y_obs.
se_pred: An array of the standard errors of the predicted values.
y_pred: Unused.
se_pred: Unused.
bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value
for standardized outcomes y_obs. The rank ordering of the results on a set
of y_obs data sets is not generally sensitive to the bandwidth, if it is
Expand All @@ -153,6 +159,11 @@ def entropy_of_observations(
"""
if y_obs.ndim == 1:
y_obs = y_obs[:, np.newaxis]

# Check if standardization was applied to the observations.
y_std = np.std(y_obs, axis=0, ddof=1)
if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2.
logger.warning("Standardization of observations was not applied.")
return _entropy_via_kde(y_obs, bandwidth=bandwidth)


Expand Down
16 changes: 16 additions & 0 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ def test_entropy_of_observations(self) -> None:
# ordering of entropies stays the same, though the difference is smaller
self.assertTrue(er2 - ec2 > 3)

# test warning if y is not standardized
module_name = "ax.utils.stats.model_fit_stats"
expected_warning = (
"WARNING:ax.utils.stats.model_fit_stats:Standardization"
" of observations was not applied."
)
with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

def test_contingency_table_construction(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1, 3, 2, 5, 7, 3])
Expand Down

0 comments on commit 3ce51b0

Please sign in to comment.