diff --git a/ax/modelbridge/best_model_selector.py b/ax/modelbridge/best_model_selector.py
new file mode 100644
index 00000000000..88335e0fce2
--- /dev/null
+++ b/ax/modelbridge/best_model_selector.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Union
+
+import numpy as np
+from ax.exceptions.core import UserInputError
+from ax.modelbridge.model_spec import ModelSpec
+from ax.utils.common.typeutils import not_none
+
+ARRAYLIKE = Union[np.ndarray, List[float], List[np.ndarray]]
+
+
+class BestModelSelector(ABC):
+    @abstractmethod
+    def best_model(self, model_specs: List[ModelSpec]) -> int:
+        """
+        Return the index of the best ``ModelSpec``.
+        """
+
+
+class ReductionCriterion(Enum):
+    """An enum for callables that are used for aggregating diagnostics over metrics
+    and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``.
+
+    NOTE: This is used to ensure serializability of the callables.
+    """
+
+    # NOTE: Callables need to be wrapped in `partial` to be registered as members.
+    MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean)
+    MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min)
+    MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max)
+
+    def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
+        return self.value(array_like)
+
+
+class SingleDiagnosticBestModelSelector(BestModelSelector):
+    """Choose the best model using a single cross-validation diagnostic.
+
+    The input is a list of ``ModelSpec``, each corresponding to one model.
+    The specified diagnostic is extracted from each of the models, its
+    values (each of which corresponds to a separate metric) are aggregated
+    with the aggregation function, the best one is determined with the
+    criterion, and the index of the best model is returned.
+
+    Example:
+     ::
+        s = SingleDiagnosticBestModelSelector(
+            diagnostic="Fisher exact test p",
+            metric_aggregation=ReductionCriterion.MEAN,
+            criterion=ReductionCriterion.MIN,
+        )
+        best_model_index = s.best_model(model_specs=model_specs)
+
+    Args:
+        diagnostic: The name of the diagnostic to use, which should be
+            a key in ``CVDiagnostics``.
+        metric_aggregation: ``ReductionCriterion`` applied to the values of the
+            diagnostic for a single model to produce a single number.
+        criterion: ``ReductionCriterion`` used to determine which of the
+            (aggregated) diagnostics is the best.
+
+    Returns:
+        int: index of the best ``ModelSpec``.
+    """
+
+    def __init__(
+        self,
+        diagnostic: str,
+        metric_aggregation: ReductionCriterion,
+        criterion: ReductionCriterion,
+    ) -> None:
+        self.diagnostic = diagnostic
+        if not isinstance(metric_aggregation, ReductionCriterion) or not isinstance(
+            criterion, ReductionCriterion
+        ):
+            raise UserInputError(
+                "Both `metric_aggregation` and `criterion` must be "
+                f"`ReductionCriterion`. Got {metric_aggregation=}, {criterion=}."
+            )
+        if criterion == ReductionCriterion.MEAN:
+            raise UserInputError(
+                f"{criterion=} is not supported. Please use MIN or MAX."
+            )
+        self.metric_aggregation = metric_aggregation
+        self.criterion = criterion
+
+    def best_model(self, model_specs: List[ModelSpec]) -> int:
+        for model_spec in model_specs:
+            model_spec.cross_validate()
+        aggregated_diagnostic_values = [
+            self.metric_aggregation(
+                list(not_none(model_spec.diagnostics)[self.diagnostic].values())
+            )
+            for model_spec in model_specs
+        ]
+        best_diagnostic = self.criterion(aggregated_diagnostic_values).item()
+        return aggregated_diagnostic_values.index(best_diagnostic)
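
For orientation, a minimal usage sketch of the new selector API, mirroring the
docstring and tests above. It assumes two already-fit ``ModelSpec``s; the
``experiment`` and ``data`` names are placeholders, not part of this diff:

    from ax.modelbridge.best_model_selector import (
        ReductionCriterion,
        SingleDiagnosticBestModelSelector,
    )
    from ax.modelbridge.model_spec import ModelSpec
    from ax.modelbridge.registry import Models

    # Hypothetical specs; `experiment` and `data` are placeholders.
    specs = [
        ModelSpec(model_enum=Models.BO_MIXED),
        ModelSpec(model_enum=Models.BOTORCH_MODULAR),
    ]
    for spec in specs:
        spec.fit(experiment=experiment, data=data)

    selector = SingleDiagnosticBestModelSelector(
        diagnostic="Fisher exact test p",
        metric_aggregation=ReductionCriterion.MEAN,  # mean over metrics
        criterion=ReductionCriterion.MIN,  # MIN, per the docstring example
    )
    # `best_model` cross-validates each spec itself before comparing.
    best_spec = specs[selector.best_model(model_specs=specs)]
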
diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py
index 01f94a6ec0f..4e7f55a95f0 100644
--- a/ax/modelbridge/cross_validation.py
+++ b/ax/modelbridge/cross_validation.py
@@ -5,18 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+
 from __future__ import annotations
 
 import warnings
-from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import deepcopy
-from enum import Enum
-from functools import partial
-
 from logging import Logger
-from numbers import Number
-from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
 from warnings import warn
 
 import numpy as np
@@ -24,7 +20,6 @@
 from ax.core.optimization_config import OptimizationConfig
 from ax.modelbridge.base import ModelBridge, unwrap_observation_data
 from ax.utils.common.logger import get_logger
-
 from ax.utils.stats.model_fit_stats import (
     _correlation_coefficient,
     _fisher_exact_test_p,
@@ -429,83 +424,6 @@ def _gen_train_test_split(
         yield set(arm_names[:-n_test]), set(arm_names[-n_test:])
 
 
-class BestModelSelector(ABC):
-    @abstractmethod
-    def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
-        """
-        Return the index of the best diagnostic.
-        """
-        pass
-
-
-class CallableEnum(Enum):
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    def __call__(self, *args: Optional[Any], **kwargs: Optional[Any]) -> Any:
-        return self.value(*args, **kwargs)
-
-
-class MetricAggregation(CallableEnum):
-    MEAN: Callable[[Iterable[Number]], Number] = partial(np.mean)
-
-
-class DiagnosticCriterion(CallableEnum):
-    MIN: Callable[[Iterable[Number]], Number] = partial(np.amin)
-
-
-class SingleDiagnosticBestModelSelector(BestModelSelector):
-    """Choose the best model using a single cross-validation diagnostic.
-
-    The input is a list of CVDiagnostics, each corresponding to one model.
-    The specified diagnostic is extracted from each of the CVDiagnostics,
-    its values (each of which corresponds to a separate metric) are
-    aggregated with the aggregation function, the best one is determined
-    with the criterion, and the index of the best diagnostic result is returned.
-
-
-    Example:
-
-     ::
-        s = SingleDiagnosticBestModelSelector(
-            diagnostic = 'Fisher exact test p',
-            criterion = DiagnosticCriterion.MIN,
-            metric_aggregation = MetricAggregation.MEAN,
-        )
-        best_diagnostic_index = s.best_diagnostic(diagnostics)
-
-    Args:
-        diagnostic (str): The name of the diagnostic to use, which should be
-            a key in CVDiagnostic.
-        metric_aggregation (MetricAggregation): Callable
-            applied to the values of the diagnostic for a single model to
-            produce a single number.
-        criterion (DiagnosticCriterion): Callable used
-            to determine which of the (aggregated) diagnostics is the best.
-
-
-    Returns:
-        int: index of the selected best diagnostic.
- - """ - - def __init__( - self, - diagnostic: str, - metric_aggregation: MetricAggregation, - criterion: DiagnosticCriterion, - ) -> None: - self.diagnostic = diagnostic - self.metric_aggregation = metric_aggregation - self.criterion = criterion - - def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int: - aggregated_diagnostic_values = [ - self.metric_aggregation(list(d[self.diagnostic].values())) - for d in diagnostics - ] - best_diagnostic = self.criterion(aggregated_diagnostic_values) - return aggregated_diagnostic_values.index(best_diagnostic) - - """ ############################## Model Fit Metrics Utils ############################## """ diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 58c8530f634..73545146213 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -9,7 +9,6 @@ from __future__ import annotations from collections import defaultdict - from dataclasses import dataclass, field from logging import Logger from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -28,7 +27,7 @@ from ax.exceptions.core import UserInputError from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints from ax.modelbridge.base import ModelBridge -from ax.modelbridge.cross_validation import BestModelSelector +from ax.modelbridge.best_model_selector import BestModelSelector from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec from ax.modelbridge.registry import ModelRegistryBase from ax.modelbridge.transition_criterion import ( @@ -360,11 +359,8 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec: raise NotImplementedError(CANNOT_SELECT_ONE_MODEL_MSG) return self.model_specs[0] - for model_spec in self.model_specs: - model_spec.cross_validate() - - best_model_index = not_none(self.best_model_selector).best_diagnostic( - diagnostics=[not_none(m.diagnostics) for m in self.model_specs], + best_model_index = not_none(self.best_model_selector).best_model( + model_specs=self.model_specs, ) return self.model_specs[best_model_index] diff --git a/ax/modelbridge/tests/test_best_model_selector.py b/ax/modelbridge/tests/test_best_model_selector.py new file mode 100644 index 00000000000..f96c111add3 --- /dev/null +++ b/ax/modelbridge/tests/test_best_model_selector.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from unittest.mock import Mock + +from ax.exceptions.core import UserInputError +from ax.modelbridge.best_model_selector import ( + ReductionCriterion, + SingleDiagnosticBestModelSelector, +) +from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.registry import Models +from ax.utils.common.testutils import TestCase + + +class TestBestModelSelector(TestCase): + def setUp(self) -> None: + super().setUp() + + # Construct a series of model specs with dummy CV diagnostics. 
+        self.model_specs = []
+        for diagnostics in [
+            {"Fisher exact test p": {"y_a": 0.0, "y_b": 0.4}},
+            {"Fisher exact test p": {"y_a": 0.1, "y_b": 0.1}},
+            {"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}},
+        ]:
+            ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
+            ms._cv_results = Mock()
+            ms._diagnostics = diagnostics
+            self.model_specs.append(ms)
+
+    def test_user_input_error(self) -> None:
+        with self.assertRaisesRegex(UserInputError, "ReductionCriterion"):
+            SingleDiagnosticBestModelSelector(
+                "Fisher exact test p", metric_aggregation=min, criterion=max
+            )
+        with self.assertRaisesRegex(UserInputError, "use MIN or MAX"):
+            SingleDiagnosticBestModelSelector(
+                "Fisher exact test p",
+                metric_aggregation=ReductionCriterion.MEAN,
+                criterion=ReductionCriterion.MEAN,
+            )
+
+    def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
+        s = SingleDiagnosticBestModelSelector(
+            diagnostic="Fisher exact test p",
+            criterion=ReductionCriterion.MIN,
+            metric_aggregation=ReductionCriterion.MEAN,
+        )
+        self.assertEqual(s.best_model(model_specs=self.model_specs), 1)
+
+    def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
+        s = SingleDiagnosticBestModelSelector(
+            diagnostic="Fisher exact test p",
+            criterion=ReductionCriterion.MIN,
+            metric_aggregation=ReductionCriterion.MIN,
+        )
+        self.assertEqual(s.best_model(model_specs=self.model_specs), 0)
+
+    def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
+        s = SingleDiagnosticBestModelSelector(
+            diagnostic="Fisher exact test p",
+            criterion=ReductionCriterion.MAX,
+            metric_aggregation=ReductionCriterion.MEAN,
+        )
+        self.assertEqual(s.best_model(model_specs=self.model_specs), 2)
diff --git a/ax/modelbridge/tests/test_cross_validation.py b/ax/modelbridge/tests/test_cross_validation.py
index 5cce1383e36..29e278ec48b 100644
--- a/ax/modelbridge/tests/test_cross_validation.py
+++ b/ax/modelbridge/tests/test_cross_validation.py
@@ -30,7 +30,6 @@
     CVDiagnostics,
     CVResult,
     has_good_opt_config_model_fit,
-    SingleDiagnosticBestModelSelector,
 )
 from ax.modelbridge.registry import Models
 from ax.utils.common.testutils import TestCase
@@ -419,27 +418,3 @@ def test_HasGoodOptConfigModelFit(self) -> None:
             assess_model_fit_result=assess_model_fit_result,
         )
         self.assertFalse(has_good_fit)
-
-    def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
-        s = SingleDiagnosticBestModelSelector(
-            diagnostic="Fisher exact test p",
-            criterion=min,
-            metric_aggregation=np.mean,
-        )
-        self.assertEqual(s.best_diagnostic(self.diagnostics), 1)
-
-    def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
-        s = SingleDiagnosticBestModelSelector(
-            diagnostic="Fisher exact test p",
-            criterion=min,
-            metric_aggregation=min,
-        )
-        self.assertEqual(s.best_diagnostic(self.diagnostics), 0)
-
-    def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
-        s = SingleDiagnosticBestModelSelector(
-            diagnostic="Fisher exact test p",
-            criterion=max,
-            metric_aggregation=np.mean,
-        )
-        self.assertEqual(s.best_diagnostic(self.diagnostics), 2)
diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py
index 9639c82f867..d151a1f33ff 100644
--- a/ax/modelbridge/tests/test_generation_node.py
+++ b/ax/modelbridge/tests/test_generation_node.py
@@ -7,14 +7,13 @@
 # pyre-strict
 
 from logging import Logger
-from unittest.mock import patch, PropertyMock
+from unittest.mock import MagicMock, patch, PropertyMock
 
 from ax.core.base_trial import TrialStatus
 from ax.core.observation import ObservationFeatures
 from ax.exceptions.core import UserInputError
-from ax.modelbridge.cross_validation import (
-    DiagnosticCriterion,
-    MetricAggregation,
+from ax.modelbridge.best_model_selector import (
+    ReductionCriterion,
     SingleDiagnosticBestModelSelector,
 )
 from ax.modelbridge.factory import get_sobol
@@ -267,36 +266,24 @@ def test_properties(self) -> None:
 
 
 class TestGenerationNodeWithBestModelSelector(TestCase):
-    @fast_botorch_optimize
     def setUp(self) -> None:
         super().setUp()
-        self.branin_experiment = get_branin_experiment()
-        sobol = Models.SOBOL(search_space=self.branin_experiment.search_space)
-        sobol_run = sobol.gen(n=20)
-        self.branin_experiment.new_batch_trial().add_generator_run(
-            sobol_run
-        ).run().mark_completed()
-        data = self.branin_experiment.fetch_data()
-
-        ms_gpei = ModelSpec(model_enum=Models.GPEI)
-        ms_gpei.fit(experiment=self.branin_experiment, data=data)
-
+        self.branin_experiment = get_branin_experiment(
+            with_batch=True, with_completed_batch=True
+        )
+        ms_mixed = ModelSpec(model_enum=Models.BO_MIXED)
         ms_botorch = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
-        ms_botorch.fit(experiment=self.branin_experiment, data=data)
-
-        self.fitted_model_specs = [ms_gpei, ms_botorch]
+        self.mock_aggregation = MagicMock(
+            side_effect=ReductionCriterion.MEAN, spec=ReductionCriterion
+        )
 
         self.model_selection_node = GenerationNode(
             node_name="test",
-            model_specs=self.fitted_model_specs,
+            model_specs=[ms_mixed, ms_botorch],
             best_model_selector=SingleDiagnosticBestModelSelector(
                 diagnostic="Fisher exact test p",
-                # pyre-fixme[6]: For 2nd param expected `DiagnosticCriterion` but
-                # got `MetricAggregation`.
-                criterion=MetricAggregation.MEAN,
-                # pyre-fixme[6]: For 3rd param expected `MetricAggregation` but got
-                # `DiagnosticCriterion`.
-                metric_aggregation=DiagnosticCriterion.MIN,
+                metric_aggregation=self.mock_aggregation,
+                criterion=ReductionCriterion.MIN,
             ),
         )
@@ -308,10 +295,12 @@ def test_gen(self) -> None:
         # Check that with `ModelSelectionNode` generation from a node with
         # multiple model specs does not fail.
         gr = self.model_selection_node.gen(n=1, pending_observations={"branin": []})
-        # Currently, `ModelSelectionNode` should just pick the first model
-        # spec as the one to generate from.
-        # TODO[adamobeng]: Test correct behavior here when implemented.
-        self.assertEqual(gr._model_key, "GPEI")
+        # Check that the metric aggregation function is called twice, once for each
+        # model spec.
+        self.assertEqual(self.mock_aggregation.call_count, 2)
+        # The model specs are practically identical for this example.
+        # Should pick the first one.
+        self.assertEqual(gr._model_key, "BO_MIXED")
 
         # test model_to_gen_from_name property
-        self.assertEqual(self.model_selection_node.model_to_gen_from_name, "GPEI")
+        self.assertEqual(self.model_selection_node.model_to_gen_from_name, "BO_MIXED")
diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst
index c956557ac5b..d825f4f3989 100644
--- a/sphinx/source/modelbridge.rst
+++ b/sphinx/source/modelbridge.rst
@@ -135,6 +135,13 @@ Cross Validation
     :undoc-members:
     :show-inheritance:
 
+Model Selection
+~~~~~~~~~~~~~~~
+.. automodule:: ax.modelbridge.best_model_selector
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Dispatch Utilities
 ~~~~~~~~~~~~~~~~~~
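
To close the loop, a minimal sketch of how the pieces compose after this change,
mirroring the updated `TestGenerationNodeWithBestModelSelector`. The node name and
the choice of specs are illustrative only:

    from ax.modelbridge.best_model_selector import (
        ReductionCriterion,
        SingleDiagnosticBestModelSelector,
    )
    from ax.modelbridge.generation_node import GenerationNode
    from ax.modelbridge.model_spec import ModelSpec
    from ax.modelbridge.registry import Models

    node = GenerationNode(
        node_name="model_selection",  # illustrative name
        model_specs=[
            ModelSpec(model_enum=Models.BO_MIXED),
            ModelSpec(model_enum=Models.BOTORCH_MODULAR),
        ],
        best_model_selector=SingleDiagnosticBestModelSelector(
            diagnostic="Fisher exact test p",
            metric_aggregation=ReductionCriterion.MEAN,
            criterion=ReductionCriterion.MIN,
        ),
    )
    # With this diff, the node no longer cross-validates its specs itself:
    # `_pick_fitted_model_to_gen_from` delegates to
    # `best_model_selector.best_model(model_specs=...)`, which runs CV internally.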