Skip to content

Commit

Permalink
Standardization check for entropy of observations (#2366)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2366

Ads a standardization check to the entropy of observations metric, which assumes standardization implicitly to choose a sensible default bandwidth for the kernel density estimator.

Reviewed By: Balandat

Differential Revision: D56105162

fbshipit-source-id: 33f0290498e187bdb577c49ac57e1c68df3c4edf
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Apr 14, 2024
1 parent 4579469 commit a446aa4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
27 changes: 23 additions & 4 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@

# pyre-strict

from logging import Logger
from typing import Dict, Mapping, Optional, Protocol

import numpy as np

from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity


logger: Logger = get_logger(__name__)


DEFAULT_KDE_BANDWIDTH = 0.1 # default bandwidth for kernel density estimators

"""
################################ Model Fit Metrics ###############################
"""
Expand Down Expand Up @@ -132,16 +141,16 @@ def entropy_of_observations(
y_obs: np.ndarray,
y_pred: np.ndarray,
se_pred: np.ndarray,
bandwidth: float = 0.1,
bandwidth: float = DEFAULT_KDE_BANDWIDTH,
) -> float:
"""Computes the entropy of the observations y_obs using a kernel density estimator.
This can be used to quantify how "clustered" the outcomes are. NOTE: y_pred and
se_pred are not used, but are required for the API.
Args:
y_obs: An array of observations for a single metric.
y_pred: An array of the predicted values corresponding to y_obs.
se_pred: An array of the standard errors of the predicted values.
y_pred: Unused.
se_pred: Unused.
bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value
for standardized outcomes y_obs. The rank ordering of the results on a set
of y_obs data sets is not generally sensitive to the bandwidth, if it is
Expand All @@ -153,10 +162,20 @@ def entropy_of_observations(
"""
if y_obs.ndim == 1:
y_obs = y_obs[:, np.newaxis]

# Check if standardization was applied to the observations.
if bandwidth == DEFAULT_KDE_BANDWIDTH:
y_std = np.std(y_obs, axis=0, ddof=1)
if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2.
logger.warning(
"Standardization of observations was not applied. "
f"The default bandwidth of {DEFAULT_KDE_BANDWIDTH} is a reasonable "
"choice if observations are standardize, but may not be otherwise."
)
return _entropy_via_kde(y_obs, bandwidth=bandwidth)


def _entropy_via_kde(y: np.ndarray, bandwidth: float = 0.1) -> float:
def _entropy_via_kde(y: np.ndarray, bandwidth: float = DEFAULT_KDE_BANDWIDTH) -> float:
"""Computes the entropy of the kernel density estimate of the input data.
Args:
Expand Down
17 changes: 17 additions & 0 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,23 @@ def test_entropy_of_observations(self) -> None:
# ordering of entropies stays the same, though the difference is smaller
self.assertTrue(er2 - ec2 > 3)

# test warning if y is not standardized
module_name = "ax.utils.stats.model_fit_stats"
expected_warning = (
"WARNING:ax.utils.stats.model_fit_stats:Standardization of observations "
"was not applied. The default bandwidth of 0.1 is a reasonable "
"choice if observations are standardize, but may not be otherwise."
)
with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

def test_contingency_table_construction(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1, 3, 2, 5, 7, 3])
Expand Down

0 comments on commit a446aa4

Please sign in to comment.