Skip to content

Commit

Permalink
added select_good_bootstraps tests and reference, fixed model list in…
Browse files Browse the repository at this point in the history
…dexing bug
  • Loading branch information
cnellington committed Apr 19, 2024
1 parent f84c12a commit b471c12
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 15 deletions.
1 change: 1 addition & 0 deletions contextualized/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from contextualized.analysis.accuracy_split import print_acc_by_covars
from contextualized.analysis.bootstraps import select_good_bootstraps
from contextualized.analysis.embeddings import (
plot_lowdim_rep,
plot_embedding_for_all_covars,
Expand Down
25 changes: 12 additions & 13 deletions contextualized/analysis/bootstraps.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
# Utility functions for bootstraps
import numpy as np
from contextualized.easy.wrappers import SKLearnWrapper


def select_good_bootstraps(sklearn_wrapper, train_errs, tol=2, **kwargs):
def select_good_bootstraps(sklearn_wrapper: SKLearnWrapper, train_errs: np.ndarray, tol: float = 2) -> SKLearnWrapper:
"""
Select bootstraps that are good for a given model.
Parameters
----------
sklearn_wrapper : contextualized.easy.wrappers.SKLearnWrapper
train_errs : np.ndarray of shape (n_bootstraps, n_samples, n_outcomes)
tol : float tolerance for the mean of the train_errs
Args:
sklearn_wrapper (contextualized.easy.wrappers.SKLearnWrapper): Wrapper for the sklearn model.
train_errs (np.ndarray): Training errors for each bootstrap (n_bootstraps, n_samples, n_outcomes).
tol (float): Only bootstraps with mean train_errs below tol * min(train_errs) are kept.
Returns
-------
sklearn_wrapper : sklearn_wrapper with only selected bootstraps
Returns:
contextualized.easy.wrappers.SKLearnWrapper: The input model with only selected bootstraps.
"""
if len(train_errs.shape) == 2:
train_errs = train_errs[:, :, None]

train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2))
sklearn_wrapper.models = sklearn_wrapper.models[
train_errs_by_bootstrap < tol * np.min(train_errs_by_bootstrap)
]
train_errs_min = np.min(train_errs_by_bootstrap)
sklearn_wrapper.models = [model for train_err, model in zip(train_errs_by_bootstrap, sklearn_wrapper.models) if train_err < train_errs_min * tol]
sklearn_wrapper.n_bootstraps = len(sklearn_wrapper.models)
return sklearn_wrapper
34 changes: 32 additions & 2 deletions contextualized/analysis/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
"""

import unittest
import copy
import numpy as np
import pandas as pd
from unittest.mock import MagicMock


from contextualized.analysis import (
test_each_context
test_each_context,
select_good_bootstraps
)

from contextualized.easy import ContextualizedRegressor
Expand Down Expand Up @@ -59,5 +61,33 @@ def test_expected_significant_pval(self):
self.assertTrue(all(pval >= 0.05 for pval in other_pvals['Pvals']), "Other p-values are significant.")


class TestSelectGoodBootstraps(unittest.TestCase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def setUp(self):
self.model = ContextualizedRegressor(n_bootstraps = 3)
C = np.random.uniform(0, 1, size=(100, 2))
X = np.random.uniform(0, 1, size=(100, 2))
Y = np.random.uniform(0, 1, size=(100, 2))
self.model.fit(C, X, Y)
Y_pred = self.model.predict(C, X, individual_preds = True)
self.train_errs = np.zeros_like((Y - Y_pred) ** 2)
self.train_errs[0] = 0.1
self.train_errs[1] = 0.2
self.train_errs[2] = 0.3
self.model_copy = copy.deepcopy(self.model)
select_good_bootstraps(self.model, self.train_errs)

def test_model_has_fewer_bootstraps(self):
"""
Test that the model has fewer bootstraps after calling select_good_bootstraps.
"""
self.assertEqual(len(self.model.models), 1)
self.assertEqual(len(self.model_copy.models), 3)
self.assertLess(len(self.model.models), len(self.model_copy.models))


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions docs/source/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ All functions can be loaded directly from the module, e.g. ``from contextualized
pvals.test_each_context
pvals.get_possible_pvals
accuracy_split.print_acc_by_covars
bootstraps.select_good_bootstraps
embeddings.plot_lowdim_rep
embeddings.plot_embedding_for_all_covars
effects.plot_homogeneous_context_effects
Expand All @@ -27,6 +28,7 @@ All functions can be loaded directly from the module, e.g. ``from contextualized
.. autofunction:: contextualized.analysis.pvals.test_each_context
.. autofunction:: contextualized.analysis.pvals.get_possible_pvals
.. autofunction:: contextualized.analysis.accuracy_split.print_acc_by_covars
.. autofunction:: contextualized.analysis.bootstraps.select_good_bootstraps
.. autofunction:: contextualized.analysis.embeddings.plot_lowdim_rep
.. autofunction:: contextualized.analysis.embeddings.plot_embedding_for_all_covars
.. autofunction:: contextualized.analysis.effects.plot_homogeneous_context_effects
Expand Down
1 change: 1 addition & 0 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from contextualized.regression.tests import *
from contextualized.dags.tests import *
from contextualized.dags.tests_fast import *
from contextualized.easy.tests.test_regressor import *
from contextualized.easy.tests.test_classifier import *
from contextualized.easy.tests.test_markov_networks import *
Expand Down

0 comments on commit b471c12

Please sign in to comment.