From b471c128f437b30abde92cf116e4b7bde862e818 Mon Sep 17 00:00:00 2001 From: cnellington Date: Fri, 19 Apr 2024 11:20:04 -0400 Subject: [PATCH] added select_good_bootstraps tests and reference, fixed model list indexing bug --- contextualized/analysis/__init__.py | 1 + contextualized/analysis/bootstraps.py | 25 ++++++++++---------- contextualized/analysis/tests.py | 34 +++++++++++++++++++++++++-- docs/source/analysis.rst | 2 ++ tests.py | 1 + 5 files changed, 48 insertions(+), 15 deletions(-) diff --git a/contextualized/analysis/__init__.py b/contextualized/analysis/__init__.py index 7a36aa9..bf9b00c 100644 --- a/contextualized/analysis/__init__.py +++ b/contextualized/analysis/__init__.py @@ -3,6 +3,7 @@ """ from contextualized.analysis.accuracy_split import print_acc_by_covars +from contextualized.analysis.bootstraps import select_good_bootstraps from contextualized.analysis.embeddings import ( plot_lowdim_rep, plot_embedding_for_all_covars, diff --git a/contextualized/analysis/bootstraps.py b/contextualized/analysis/bootstraps.py index a9e2eba..d967afd 100644 --- a/contextualized/analysis/bootstraps.py +++ b/contextualized/analysis/bootstraps.py @@ -1,25 +1,24 @@ -# Utility functions for bootstraps +import numpy as np +from contextualized.easy.wrappers import SKLearnWrapper -def select_good_bootstraps(sklearn_wrapper, train_errs, tol=2, **kwargs): +def select_good_bootstraps(sklearn_wrapper: SKLearnWrapper, train_errs: np.ndarray, tol: float = 2) -> SKLearnWrapper: """ Select bootstraps that are good for a given model. - Parameters - ---------- - sklearn_wrapper : contextualized.easy.wrappers.SKLearnWrapper - train_errs : np.ndarray of shape (n_bootstraps, n_samples, n_outcomes) - tol : float tolerance for the mean of the train_errs + Args: + sklearn_wrapper (contextualized.easy.wrappers.SKLearnWrapper): Wrapper for the sklearn model. + train_errs (np.ndarray): Training errors for each bootstrap (n_bootstraps, n_samples, n_outcomes). + tol (float): Only bootstraps with mean train_errs below tol * min(train_errs) are kept. - Returns - ------- - sklearn_wrapper : sklearn_wrapper with only selected bootstraps + Returns: + contextualized.easy.wrappers.SKLearnWrapper: The input model with only selected bootstraps. """ if len(train_errs.shape) == 2: train_errs = train_errs[:, :, None] train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2)) - sklearn_wrapper.models = sklearn_wrapper.models[ - train_errs_by_bootstrap < tol * np.min(train_errs_by_bootstrap) - ] + train_errs_min = np.min(train_errs_by_bootstrap) + sklearn_wrapper.models = [model for train_err, model in zip(train_errs_by_bootstrap, sklearn_wrapper.models) if train_err < train_errs_min * tol] + sklearn_wrapper.n_bootstraps = len(sklearn_wrapper.models) return sklearn_wrapper diff --git a/contextualized/analysis/tests.py b/contextualized/analysis/tests.py index 1438e4b..75e52fe 100644 --- a/contextualized/analysis/tests.py +++ b/contextualized/analysis/tests.py @@ -3,12 +3,14 @@ """ import unittest +import copy import numpy as np import pandas as pd -from unittest.mock import MagicMock + from contextualized.analysis import ( - test_each_context + test_each_context, + select_good_bootstraps ) from contextualized.easy import ContextualizedRegressor @@ -59,5 +61,33 @@ def test_expected_significant_pval(self): self.assertTrue(all(pval >= 0.05 for pval in other_pvals['Pvals']), "Other p-values are significant.") +class TestSelectGoodBootstraps(unittest.TestCase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setUp(self): + self.model = ContextualizedRegressor(n_bootstraps = 3) + C = np.random.uniform(0, 1, size=(100, 2)) + X = np.random.uniform(0, 1, size=(100, 2)) + Y = np.random.uniform(0, 1, size=(100, 2)) + self.model.fit(C, X, Y) + Y_pred = self.model.predict(C, X, individual_preds = True) + self.train_errs = np.zeros_like((Y - Y_pred) ** 2) + self.train_errs[0] = 0.1 + self.train_errs[1] = 0.2 + self.train_errs[2] = 0.3 + self.model_copy = copy.deepcopy(self.model) + select_good_bootstraps(self.model, self.train_errs) + + def test_model_has_fewer_bootstraps(self): + """ + Test that the model has fewer bootstraps after calling select_good_bootstraps. + """ + self.assertEqual(len(self.model.models), 1) + self.assertEqual(len(self.model_copy.models), 3) + self.assertLess(len(self.model.models), len(self.model_copy.models)) + + if __name__ == '__main__': unittest.main() diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst index 4887c95..f20b410 100644 --- a/docs/source/analysis.rst +++ b/docs/source/analysis.rst @@ -15,6 +15,7 @@ All functions can be loaded directly from the module, e.g. ``from contextualized pvals.test_each_context pvals.get_possible_pvals accuracy_split.print_acc_by_covars + bootstraps.select_good_bootstraps embeddings.plot_lowdim_rep embeddings.plot_embedding_for_all_covars effects.plot_homogeneous_context_effects @@ -27,6 +28,7 @@ All functions can be loaded directly from the module, e.g. ``from contextualized .. autofunction:: contextualized.analysis.pvals.test_each_context .. autofunction:: contextualized.analysis.pvals.get_possible_pvals .. autofunction:: contextualized.analysis.accuracy_split.print_acc_by_covars +.. autofunction:: contextualized.analysis.bootstraps.select_good_bootstraps .. autofunction:: contextualized.analysis.embeddings.plot_lowdim_rep .. autofunction:: contextualized.analysis.embeddings.plot_embedding_for_all_covars .. autofunction:: contextualized.analysis.effects.plot_homogeneous_context_effects diff --git a/tests.py b/tests.py index 52113b3..43ced63 100644 --- a/tests.py +++ b/tests.py @@ -1,6 +1,7 @@ import unittest from contextualized.regression.tests import * from contextualized.dags.tests import * +from contextualized.dags.tests_fast import * from contextualized.easy.tests.test_regressor import * from contextualized.easy.tests.test_classifier import * from contextualized.easy.tests.test_markov_networks import *