Refactor BestModelSelector to operate on ModelSpecs (#2557)
Summary:
Pull Request resolved: #2557

`BestModelSelector` was previously limited to selecting the best of a given dictionary of CV diagnostics computed in `ModelSpec.cross_validate`. This setup limited extensibility, since any change to the computed diagnostics required updating the `ModelSpec` code.

This diff refactors `BestModelSelector` to operate directly on the `ModelSpec`s. This more modular design lets each `BestModelSelector` subclass compute the necessary diagnostics internally, without locking us into any pre-specified list. A usage sketch follows below.
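
A minimal usage sketch of the new API (`model_specs` is a hypothetical `List[ModelSpec]` built elsewhere; the class and enum names come from `ax/modelbridge/best_model_selector.py` below):

    from ax.modelbridge.best_model_selector import (
        ReductionCriterion,
        SingleDiagnosticBestModelSelector,
    )

    selector = SingleDiagnosticBestModelSelector(
        diagnostic="Fisher exact test p",
        metric_aggregation=ReductionCriterion.MEAN,  # mean over per-metric values
        criterion=ReductionCriterion.MIN,  # best = smallest aggregated value
    )
    # best_model() cross-validates each spec internally and returns the index
    # of the best one.
    best_spec = model_specs[selector.best_model(model_specs=model_specs)]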

Other minor changes:
- Removed `CallableEnum` and its subclasses, replacing them with a single `ReductionCriterion` enum.
- Split off `BestModelSelector` into a separate file to avoid circular imports.

Reviewed By: mgarrard

Differential Revision: D59249657

fbshipit-source-id: b484e6e859231e99c8e079f6b82038661b44fd68
saitcakmak authored and facebook-github-bot committed Jul 9, 2024
1 parent f6bf1c6 commit c03d98a
Showing 7 changed files with 212 additions and 147 deletions.
109 changes: 109 additions & 0 deletions ax/modelbridge/best_model_selector.py
@@ -0,0 +1,109 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from functools import partial
from typing import Callable, List, Union

import numpy as np
from ax.exceptions.core import UserInputError
from ax.modelbridge.model_spec import ModelSpec
from ax.utils.common.typeutils import not_none

ARRAYLIKE = Union[np.ndarray, List[float], List[np.ndarray]]


class BestModelSelector(ABC):
@abstractmethod
def best_model(self, model_specs: List[ModelSpec]) -> int:
"""
Return the index of the best ``ModelSpec``.
"""


class ReductionCriterion(Enum):
"""An enum for callables that are used for aggregating diagnostics over metrics
and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``.
NOTE: This is used to ensure serializability of the callables.
"""

# NOTE: Callables need to be wrapped in `partial` to be registered as members.
MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean)
MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min)
MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max)
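    # Without ``partial``, a bare function assigned in an Enum body would be
    # treated as a method rather than a member. Each member is callable via
    # ``__call__`` below, e.g. ``ReductionCriterion.MEAN([0.2, 0.4])`` -> 0.3.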

def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
return self.value(array_like)


class SingleDiagnosticBestModelSelector(BestModelSelector):
"""Choose the best model using a single cross-validation diagnostic.
The input is a list of ``ModelSpec``, each corresponding to one model.
The specified diagnostic is extracted from each of the models,
its values (each of which corresponds to a separate metric) are
aggregated with the aggregation function, the best one is determined
with the criterion, and the index of the best diagnostic result is returned.
Example:
::
        s = SingleDiagnosticBestModelSelector(
            diagnostic="Fisher exact test p",
            metric_aggregation=ReductionCriterion.MEAN,
            criterion=ReductionCriterion.MIN,
        )
        best_model_index = s.best_model(model_specs=model_specs)
Args:
        diagnostic: The name of the diagnostic to use, which should be
            a key in ``CVDiagnostics``.
metric_aggregation: ``ReductionCriterion`` applied to the values of the
diagnostic for a single model to produce a single number.
criterion: ``ReductionCriterion`` used to determine which of the
(aggregated) diagnostics is the best.
    Returns:
        int: index of the ``ModelSpec`` with the best diagnostic value.
"""

def __init__(
self,
diagnostic: str,
metric_aggregation: ReductionCriterion,
criterion: ReductionCriterion,
) -> None:
self.diagnostic = diagnostic
if not isinstance(metric_aggregation, ReductionCriterion) or not isinstance(
criterion, ReductionCriterion
):
raise UserInputError(
"Both `metric_aggregation` and `criterion` must be "
f"`ReductionCriterion`. Got {metric_aggregation=}, {criterion=}."
)
if criterion == ReductionCriterion.MEAN:
raise UserInputError(
f"{criterion=} is not supported. Please use MIN or MAX."
)
self.metric_aggregation = metric_aggregation
self.criterion = criterion

def best_model(self, model_specs: List[ModelSpec]) -> int:
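        # Ensure CV results and diagnostics are computed and cached on each spec.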
for model_spec in model_specs:
model_spec.cross_validate()
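        # Aggregate each spec's per-metric diagnostic values into a single number.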
aggregated_diagnostic_values = [
self.metric_aggregation(
list(not_none(model_spec.diagnostics)[self.diagnostic].values())
)
for model_spec in model_specs
]
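        # Reduce to the best aggregated value; ``.item()`` converts the numpy
        # scalar to a plain float before looking up its index.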
best_diagnostic = self.criterion(aggregated_diagnostic_values).item()
return aggregated_diagnostic_values.index(best_diagnostic)
86 changes: 2 additions & 84 deletions ax/modelbridge/cross_validation.py
@@ -5,26 +5,21 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from functools import partial

from logging import Logger
from numbers import Number
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
from warnings import warn

import numpy as np
from ax.core.observation import Observation, ObservationData, recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.modelbridge.base import ModelBridge, unwrap_observation_data
from ax.utils.common.logger import get_logger

from ax.utils.stats.model_fit_stats import (
_correlation_coefficient,
_fisher_exact_test_p,
@@ -429,83 +424,6 @@ def _gen_train_test_split(
yield set(arm_names[:-n_test]), set(arm_names[-n_test:])


class BestModelSelector(ABC):
@abstractmethod
def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
"""
Return the index of the best diagnostic.
"""
pass


class CallableEnum(Enum):
# pyre-fixme[3]: Return annotation cannot be `Any`.
def __call__(self, *args: Optional[Any], **kwargs: Optional[Any]) -> Any:
return self.value(*args, **kwargs)


class MetricAggregation(CallableEnum):
MEAN: Callable[[Iterable[Number]], Number] = partial(np.mean)


class DiagnosticCriterion(CallableEnum):
MIN: Callable[[Iterable[Number]], Number] = partial(np.amin)


class SingleDiagnosticBestModelSelector(BestModelSelector):
"""Choose the best model using a single cross-validation diagnostic.
The input is a list of CVDiagnostics, each corresponding to one model.
The specified diagnostic is extracted from each of the CVDiagnostics,
its values (each of which corresponds to a separate metric) are
aggregated with the aggregation function, the best one is determined
with the criterion, and the index of the best diagnostic result is returned.
Example:
::
s = SingleDiagnosticBestModelSelector(
diagnostic = 'Fisher exact test p',
criterion = DiagnosticCriterion.MIN,
metric_aggregation = MetricAggregation.MEAN,
)
best_diagnostic_index = s.best_diagnostic(diagnostics)
Args:
diagnostic (str): The name of the diagnostic to use, which should be
a key in CVDiagnostic.
metric_aggregation (MetricAggregation): Callable
applied to the values of the diagnostic for a single model to
produce a single number.
criterion (DiagnosticCriterion): Callable used
to determine which of the (aggregated) diagnostics is the best.
Returns:
int: index of the selected best diagnostic.
"""

def __init__(
self,
diagnostic: str,
metric_aggregation: MetricAggregation,
criterion: DiagnosticCriterion,
) -> None:
self.diagnostic = diagnostic
self.metric_aggregation = metric_aggregation
self.criterion = criterion

def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
aggregated_diagnostic_values = [
self.metric_aggregation(list(d[self.diagnostic].values()))
for d in diagnostics
]
best_diagnostic = self.criterion(aggregated_diagnostic_values)
return aggregated_diagnostic_values.index(best_diagnostic)


"""
############################## Model Fit Metrics Utils ##############################
"""
10 changes: 3 additions & 7 deletions ax/modelbridge/generation_node.py
@@ -9,7 +9,6 @@
from __future__ import annotations

from collections import defaultdict

from dataclasses import dataclass, field
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -28,7 +27,7 @@
from ax.exceptions.core import UserInputError
from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import BestModelSelector
from ax.modelbridge.best_model_selector import BestModelSelector
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.transition_criterion import (
@@ -360,11 +359,8 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
raise NotImplementedError(CANNOT_SELECT_ONE_MODEL_MSG)
return self.model_specs[0]

for model_spec in self.model_specs:
model_spec.cross_validate()

best_model_index = not_none(self.best_model_selector).best_diagnostic(
diagnostics=[not_none(m.diagnostics) for m in self.model_specs],
best_model_index = not_none(self.best_model_selector).best_model(
model_specs=self.model_specs,
)
return self.model_specs[best_model_index]

71 changes: 71 additions & 0 deletions ax/modelbridge/tests/test_best_model_selector.py
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest.mock import Mock

from ax.exceptions.core import UserInputError
from ax.modelbridge.best_model_selector import (
ReductionCriterion,
SingleDiagnosticBestModelSelector,
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase


class TestBestModelSelector(TestCase):
def setUp(self) -> None:
super().setUp()

# Construct a series of model specs with dummy CV diagnostics.
self.model_specs = []
for diagnostics in [
{"Fisher exact test p": {"y_a": 0.0, "y_b": 0.4}},
{"Fisher exact test p": {"y_a": 0.1, "y_b": 0.1}},
{"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}},
]:
ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
ms._cv_results = Mock()
ms._diagnostics = diagnostics
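            # With both caches pre-populated, ModelSpec.cross_validate should be a
            # no-op (assumed caching behavior), so best_model sees the dummy
            # diagnostics above.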
self.model_specs.append(ms)

def test_user_input_error(self) -> None:
with self.assertRaisesRegex(UserInputError, "ReductionCriterion"):
SingleDiagnosticBestModelSelector(
"Fisher exact test p", metric_aggregation=min, criterion=max
)
with self.assertRaisesRegex(UserInputError, "use MIN or MAX"):
SingleDiagnosticBestModelSelector(
"Fisher exact test p",
metric_aggregation=ReductionCriterion.MEAN,
criterion=ReductionCriterion.MEAN,
)

def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=ReductionCriterion.MIN,
metric_aggregation=ReductionCriterion.MEAN,
)
self.assertEqual(s.best_model(model_specs=self.model_specs), 1)

def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=ReductionCriterion.MIN,
metric_aggregation=ReductionCriterion.MIN,
)
self.assertEqual(s.best_model(model_specs=self.model_specs), 0)

def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=ReductionCriterion.MAX,
metric_aggregation=ReductionCriterion.MEAN,
)
self.assertEqual(s.best_model(model_specs=self.model_specs), 2)
25 changes: 0 additions & 25 deletions ax/modelbridge/tests/test_cross_validation.py
@@ -30,7 +30,6 @@
CVDiagnostics,
CVResult,
has_good_opt_config_model_fit,
SingleDiagnosticBestModelSelector,
)
from ax.modelbridge.registry import Models
from ax.utils.common.testutils import TestCase
@@ -419,27 +418,3 @@ def test_HasGoodOptConfigModelFit(self) -> None:
assess_model_fit_result=assess_model_fit_result,
)
self.assertFalse(has_good_fit)

def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=min,
metric_aggregation=np.mean,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 1)

def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=min,
metric_aggregation=min,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 0)

def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
s = SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
criterion=max,
metric_aggregation=np.mean,
)
self.assertEqual(s.best_diagnostic(self.diagnostics), 2)