Skip to content

Commit

Permalink
add option for using posterior predictive in cross-validation
Browse files Browse the repository at this point in the history
Summary:
See title. This change is particularly important for model selection using the negative log-likelihood (NLL) when observations are noisy. Evaluating against the posterior over the true (latent) function, rather than the posterior predictive over the noisy observations, gives quite misleading results about model calibration.

I also think that predicted-vs-actual plots from LOOCV are insightful when they use the posterior predictive and the observations are noisy. We may want to consider adding observation_noise to `predict`, but we can do that in a follow-up.

Differential Revision: D58227612
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jun 13, 2024
1 parent 0dce67c commit 45cc46c
Show file tree
Hide file tree
Showing 17 changed files with 142 additions and 22 deletions.
6 changes: 6 additions & 0 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,12 +907,16 @@ def cross_validate(
self,
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
use_posterior_predictive: bool = False,
) -> List[ObservationData]:
"""Make a set of cross-validation predictions.
Args:
cv_training_data: The training data to use for cross validation.
cv_test_points: The test points at which predictions will be made.
use_posterior_predictive: A boolean indicating if the predictions
should be from the posterior predictive (i.e. including
observation noise).
Returns:
A list of predictions at the test points.
Expand All @@ -936,6 +940,7 @@ def cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
use_posterior_predictive=use_posterior_predictive,
)
# Apply reverse transforms, in reverse order
cv_test_observations = [
Expand All @@ -952,6 +957,7 @@ def _cross_validate(
search_space: SearchSpace,
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
use_posterior_predictive: bool = False,
) -> List[ObservationData]:
"""Apply the terminal transform, make predictions on the test points,
and reverse terminal transform on the results.
Expand Down
23 changes: 20 additions & 3 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def cross_validate(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
test_selector: Optional[Callable] = None,
untransform: bool = True,
use_posterior_predictive: bool = False,
) -> List[CVResult]:
"""Cross validation for model predictions.
Expand Down Expand Up @@ -110,6 +111,12 @@ def cross_validate(
of the original data in regions where outliers have been removed,
we have found it to better reflect how good the model used
for candidate generation actually is.
use_posterior_predictive: A boolean indicating if the predictions
should be from the posterior predictive (i.e. including
observation noise). Note: we should reconsider how we compute
cross-validation and model fit metrics where there is non-
Gaussian noise.
Returns:
A CVResult for each observation in the training data.
"""
Expand Down Expand Up @@ -160,7 +167,9 @@ def cross_validate(
# Make the prediction
if untransform:
cv_test_predictions = model.cross_validate(
cv_training_data=cv_training_data, cv_test_points=cv_test_points
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
use_posterior_predictive=use_posterior_predictive,
)
else:
# Get test predictions in transformed space
Expand All @@ -184,6 +193,7 @@ def cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
use_posterior_predictive=use_posterior_predictive,
)
# Get test observations in transformed space
cv_test_data = deepcopy(cv_test_data)
Expand All @@ -195,7 +205,9 @@ def cross_validate(
return result


def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResult]:
def cross_validate_by_trial(
model: ModelBridge, trial: int = -1, use_posterior_predictive: bool = False
) -> List[CVResult]:
"""Cross validation for model predictions on a particular trial.
Uses all of the data up until the specified trial to predict each of the
Expand All @@ -204,6 +216,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul
Args:
model: Fitted model (ModelBridge) to cross validate.
trial: Trial for which predictions are evaluated.
use_posterior_predictive: A boolean indicating if the predictions
should be from the posterior predictive (i.e. including
observation noise).
Returns:
A CVResult for each observation in the training data.
Expand Down Expand Up @@ -239,7 +254,9 @@ def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResul
cv_test_data.append(obs)
# Make the prediction
cv_test_predictions = model.cross_validate(
cv_training_data=cv_training_data, cv_test_points=cv_test_points
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
use_posterior_predictive=use_posterior_predictive,
)
# Form CVResult objects
result = [
Expand Down
7 changes: 6 additions & 1 deletion ax/modelbridge/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _cross_validate(
search_space: SearchSpace,
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
use_posterior_predictive: bool = False,
) -> List[ObservationData]:
"""Make predictions at cv_test_points using only the data in obs_feats
and obs_data.
Expand All @@ -208,7 +209,11 @@ def _cross_validate(
]
# Use the model to do the cross validation
f_test, cov_test = self.model.cross_validate(
Xs_train=Xs_train, Ys_train=Ys_train, Yvars_train=Yvars_train, X_test=X_test
Xs_train=Xs_train,
Ys_train=Ys_train,
Yvars_train=Yvars_train,
X_test=X_test,
use_posterior_predictive=use_posterior_predictive,
)
# Convert array back to ObservationData
return array_to_observation_data(f=f_test, cov=cov_test, outcomes=self.outcomes)
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def _cross_validate(
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
parameters: Optional[List[str]] = None,
use_posterior_predictive: bool = False,
**kwargs: Any,
) -> List[ObservationData]:
"""Make predictions at cv_test_points using only the data in obs_feats
Expand All @@ -294,6 +295,7 @@ def _cross_validate(
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
parameters=parameters, # we pass the map_keys too by default
use_posterior_predictive=use_posterior_predictive,
**kwargs,
)
observation_features, observation_data = separate_observations(cv_training_data)
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _cross_validate(
search_space: SearchSpace,
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
use_posterior_predictive: bool = False,
) -> List[ObservationData]:
raise NotImplementedError

Expand Down
15 changes: 15 additions & 0 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,24 @@ def warn_and_return_mock_obs(
search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]),
cv_training_data=[get_observation2trans()],
cv_test_points=[get_observation1().features], # untransformed after
use_posterior_predictive=False,
)
self.assertTrue(cv_predictions == [get_observation1().data])

# Test use_posterior_predictive in CV
modelbridge.cross_validate(
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
use_posterior_predictive=True,
)

modelbridge._cross_validate.assert_called_with(
search_space=SearchSpace([FixedParameter("x", ParameterType.FLOAT, 8.0)]),
cv_training_data=[get_observation2trans()],
cv_test_points=[get_observation1().features], # untransformed after
use_posterior_predictive=True,
)

# Test stored training data
obs = modelbridge.get_training_data()
self.assertTrue(obs == [get_observation1(), get_observation2()])
Expand Down
28 changes: 27 additions & 1 deletion ax/modelbridge/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def test_CrossValidate(self) -> None:
# Test ModelBridge._cross_validate was called correctly.
z = ma._cross_validate.mock_calls
self.assertEqual(len(z), 3)
ma._cross_validate.assert_called_with(**self.transformed_cv_input_dict)
ma._cross_validate.assert_called_with(
**self.transformed_cv_input_dict, use_posterior_predictive=False
)

# Test selector

Expand All @@ -219,6 +221,21 @@ def test_selector(obs: Observation) -> bool:
)
self.assertTrue(np.array_equal(sorted(all_test), np.array([2.0, 2.0, 3.0])))

# test observation noise
for untransform in (True, False):
result = cross_validate(
model=ma,
folds=-1,
use_posterior_predictive=True,
untransform=untransform,
)
if untransform:
mock_cv = ma.cross_validate
else:
mock_cv = ma._cross_validate
call_kwargs = mock_cv.mock_calls[-1].kwargs
self.assertTrue(call_kwargs["use_posterior_predictive"])

def test_CrossValidateByTrial(self) -> None:
# With only 1 trial
ma = mock.MagicMock()
Expand Down Expand Up @@ -261,6 +278,15 @@ def test_CrossValidateByTrial(self) -> None:
self.assertEqual(len(result), 1)
self.assertEqual(result[0].observed.features.trial_index, 2)

mock_cv = ma.cross_validate
call_kwargs = mock_cv.mock_calls[-1].kwargs
self.assertFalse(call_kwargs["use_posterior_predictive"])

# test observation noise
result = cross_validate_by_trial(model=ma, use_posterior_predictive=True)
call_kwargs = mock_cv.mock_calls[-1].kwargs
self.assertTrue(call_kwargs["use_posterior_predictive"])

def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> None:
exp = get_branin_experiment()
sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space)
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def _cross_validate(
cv_training_data: List[Observation],
cv_test_points: List[ObservationFeatures],
parameters: Optional[List[str]] = None,
use_posterior_predictive: bool = False,
**kwargs: Any,
) -> List[ObservationData]:
"""Make predictions at cv_test_points using only the data in obs_feats
Expand All @@ -453,6 +454,7 @@ def _cross_validate(
datasets=datasets,
X_test=torch.as_tensor(X_test, dtype=self.dtype, device=self.device),
search_space_digest=search_space_digest,
use_posterior_predictive=use_posterior_predictive,
**kwargs,
)
# Convert array back to ObservationData
Expand Down
4 changes: 4 additions & 0 deletions ax/models/discrete_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def cross_validate(
Ys_train: List[List[float]],
Yvars_train: List[List[float]],
X_test: List[TParamValueList],
use_posterior_predictive: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""Do cross validation with the given training and test sets.
Expand All @@ -116,6 +117,9 @@ def cross_validate(
each outcome.
Yvars_train: The variances of each entry in Ys, same shape.
X_test: List of the j parameterizations at which to make predictions.
use_posterior_predictive: A boolean indicating if the predictions
should be from the posterior predictive (i.e. including
observation noise).
Returns:
2-element tuple containing
Expand Down
8 changes: 6 additions & 2 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
],
Model,
]
TModelPredictor = Callable[[Model, Tensor], Tuple[Tensor, Tensor]]
TModelPredictor = Callable[[Model, Tensor, bool], Tuple[Tensor, Tensor]]


# pyre-fixme[33]: Aliased annotation cannot contain `Any`.
Expand Down Expand Up @@ -466,6 +466,7 @@ def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed he
self,
datasets: List[SupervisedDataset],
X_test: Tensor,
use_posterior_predictive: bool = False,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
if self._model is None:
Expand All @@ -488,7 +489,10 @@ def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed he
use_loocv_pseudo_likelihood=self.use_loocv_pseudo_likelihood,
**self._kwargs,
)
return self.model_predictor(model=model, X=X_test) # pyre-ignore: [28]
# pyre-ignore: [28]
return self.model_predictor(
model=model, X=X_test, use_posterior_predictive=use_posterior_predictive
)

def feature_importances(self) -> np.ndarray:
return get_feature_importances_from_botorch_model(model=self._model)
Expand Down
14 changes: 11 additions & 3 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,15 @@ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
return f, cov

def predict_from_surrogate(
self, surrogate_label: str, X: Tensor
self,
surrogate_label: str,
X: Tensor,
use_posterior_predictive: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Predict from the Surrogate with the given label."""
return self.surrogates[surrogate_label].predict(X=X)
return self.surrogates[surrogate_label].predict(
X=X, use_posterior_predictive=use_posterior_predictive
)

@copy_doc(TorchModel.gen)
def gen(
Expand Down Expand Up @@ -504,6 +509,7 @@ def cross_validate(
datasets: Sequence[SupervisedDataset],
X_test: Tensor,
search_space_digest: SearchSpaceDigest,
use_posterior_predictive: bool = False,
**additional_model_inputs: Any,
) -> Tuple[Tensor, Tensor]:
# Will fail if metric_names exist across multiple models
Expand Down Expand Up @@ -561,7 +567,9 @@ def cross_validate(
**additional_model_inputs,
)
X_test_prediction = self.predict_from_surrogate(
surrogate_label=surrogate_label, X=X_test
surrogate_label=surrogate_label,
X=X_test,
use_posterior_predictive=use_posterior_predictive,
)
finally:
# Reset the surrogates back to this model's surrogate, make
Expand Down
11 changes: 9 additions & 2 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,17 +594,24 @@ def _discard_cached_model_and_data_if_search_space_digest_changed(
self._last_datasets = {}
self._last_search_space_digest = search_space_digest

def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
def predict(
self, X: Tensor, use_posterior_predictive: bool = False
) -> Tuple[Tensor, Tensor]:
"""Predicts outcomes given an input tensor.
Args:
X: A ``n x d`` tensor of input parameters.
use_posterior_predictive: A boolean indicating if the predictions
should be from the posterior predictive (i.e. including
observation noise).
Returns:
Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor.
"""
return predict_from_model(model=self.model, X=X)
return predict_from_model(
model=self.model, X=X, use_posterior_predictive=use_posterior_predictive
)

def best_in_sample_point(
self,
Expand Down
1 change: 1 addition & 0 deletions ax/models/torch/randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def cross_validate( # pyre-ignore [14]: not using metric_names or ssd
self,
datasets: List[SupervisedDataset],
X_test: Tensor,
use_posterior_predictive: bool = False,
) -> Tuple[Tensor, Tensor]:
Xs, Ys, Yvars = _datasets_to_legacy_inputs(datasets=datasets)
cv_models: List[RandomForestRegressor] = []
Expand Down
Loading

0 comments on commit 45cc46c

Please sign in to comment.