From 64c2733d4d10459dd46f03db7745dc956cca0e7d Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Sun, 14 Apr 2024 10:48:05 -0700 Subject: [PATCH] Standardization check for entropy of observations (#2366) Summary: Ads a standardization check to the entropy of observations metric, which assumes standardization implicitly to choose a sensible default bandwidth for the kernel density estimator. Differential Revision: D56105162 --- ax/utils/stats/model_fit_stats.py | 15 +++++++++++++-- ax/utils/stats/tests/test_model_fit_stats.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/ax/utils/stats/model_fit_stats.py b/ax/utils/stats/model_fit_stats.py index 34647e448d0..58a1a56adb7 100644 --- a/ax/utils/stats/model_fit_stats.py +++ b/ax/utils/stats/model_fit_stats.py @@ -5,12 +5,18 @@ # pyre-strict +from logging import Logger from typing import Dict, Mapping, Optional, Protocol import numpy as np + +from ax.utils.common.logger import get_logger from scipy.stats import fisher_exact, norm, pearsonr, spearmanr from sklearn.neighbors import KernelDensity + +logger: Logger = get_logger(__name__) + """ ################################ Model Fit Metrics ############################### """ @@ -140,8 +146,8 @@ def entropy_of_observations( Args: y_obs: An array of observations for a single metric. - y_pred: An array of the predicted values corresponding to y_obs. - se_pred: An array of the standard errors of the predicted values. + y_pred: Unused. + se_pred: Unused. bandwidth: The kernel bandwidth. Defaults to 0.1, which is a reasonable value for standardized outcomes y_obs. The rank ordering of the results on a set of y_obs data sets is not generally sensitive to the bandwidth, if it is @@ -153,6 +159,11 @@ def entropy_of_observations( """ if y_obs.ndim == 1: y_obs = y_obs[:, np.newaxis] + + # Check if standardization was applied to the observations. + y_std = np.std(y_obs, axis=0, ddof=1) + if np.any(y_std < 0.5) or np.any(2.0 < y_std): # allowing a fudge factor of 2. + logger.warning("Standardization of observations was not applied.") return _entropy_via_kde(y_obs, bandwidth=bandwidth) diff --git a/ax/utils/stats/tests/test_model_fit_stats.py b/ax/utils/stats/tests/test_model_fit_stats.py index 324af2ef9fc..cad041872a1 100644 --- a/ax/utils/stats/tests/test_model_fit_stats.py +++ b/ax/utils/stats/tests/test_model_fit_stats.py @@ -46,6 +46,22 @@ def test_entropy_of_observations(self) -> None: # ordering of entropies stays the same, though the difference is smaller self.assertTrue(er2 - ec2 > 3) + # test warning if y is not standardized + module_name = "ax.utils.stats.model_fit_stats" + expected_warning = ( + "WARNING:ax.utils.stats.model_fit_stats:Standardization" + " of observations was not applied." + ) + with self.assertLogs(module_name, level="WARNING") as logger: + ec = entropy_of_observations(y_obs=10 * yc, y_pred=ones, se_pred=ones) + self.assertEqual(len(logger.output), 1) + self.assertEqual(logger.output[0], expected_warning) + + with self.assertLogs(module_name, level="WARNING") as logger: + ec = entropy_of_observations(y_obs=yc / 10, y_pred=ones, se_pred=ones) + self.assertEqual(len(logger.output), 1) + self.assertEqual(logger.output[0], expected_warning) + def test_contingency_table_construction(self) -> None: # Create a dummy set of observations and predictions y_obs = np.array([1, 3, 2, 5, 7, 3])