Skip to content

Commit

Permalink
Standardization check for entropy of observations (#2366)
Browse files Browse the repository at this point in the history
Summary:

Ads a standardization check to the entropy of observations metric, which assumes standardization implicitly to choose a sensible default bandwidth for the kernel density estimator.

Differential Revision: D56105162
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Apr 14, 2024
1 parent 4579469 commit 64c2733
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
15 changes: 13 additions & 2 deletions ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@

# pyre-strict

from logging import Logger
from typing import Dict, Mapping, Optional, Protocol

import numpy as np

from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity


logger: Logger = get_logger(__name__)

"""
################################ Model Fit Metrics ###############################
"""
Expand Down Expand Up @@ -140,8 +146,8 @@ def entropy_of_observations(
Args:
y_obs: An array of observations for a single metric.
y_pred: An array of the predicted values corresponding to y_obs.
se_pred: An array of the standard errors of the predicted values.
y_pred: Unused.
se_pred: Unused.
bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value
for standardized outcomes y_obs. The rank ordering of the results on a set
of y_obs data sets is not generally sensitive to the bandwidth, if it is
Expand All @@ -153,6 +159,11 @@ def entropy_of_observations(
"""
if y_obs.ndim == 1:
y_obs = y_obs[:, np.newaxis]

# Check if standardization was applied to the observations.
y_std = np.std(y_obs, axis=0, ddof=1)
if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2.
logger.warning("Standardization of observations was not applied.")
return _entropy_via_kde(y_obs, bandwidth=bandwidth)


Expand Down
16 changes: 16 additions & 0 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ def test_entropy_of_observations(self) -> None:
# ordering of entropies stays the same, though the difference is smaller
self.assertTrue(er2 - ec2 > 3)

# test warning if y is not standardized
module_name = "ax.utils.stats.model_fit_stats"
expected_warning = (
"WARNING:ax.utils.stats.model_fit_stats:Standardization"
" of observations was not applied."
)
with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

with self.assertLogs(module_name, level="WARNING") as logger:
ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones)
self.assertEqual(len(logger.output), 1)
self.assertEqual(logger.output[0], expected_warning)

def test_contingency_table_construction(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1, 3, 2, 5, 7, 3])
Expand Down

0 comments on commit 64c2733

Please sign in to comment.