diff --git a/econml/bootstrap.py b/econml/bootstrap.py index 23cdeac50..d092f8fb2 100644 --- a/econml/bootstrap.py +++ b/econml/bootstrap.py @@ -44,32 +44,62 @@ class BootstrapEstimator: In case a method ending in '_interval' exists on the wrapped object, whether that should be preferred (meaning this wrapper will compute the mean of it). This option only affects behavior if `compute_means` is set to ``True``. + + stratify_treatment: bool, default False + Whether to stratify by treatment when calling fit; this will ensure that each stratum of treatment + is subsampled independently, so that each resample will have the same number of entries with each + treatment as the original sample did. """ - def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False): + def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, + compute_means=True, prefer_wrapped=False, stratify_treatment=False): self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)] self._n_bootstrap_samples = n_bootstrap_samples self._n_jobs = n_jobs self._compute_means = compute_means self._prefer_wrapped = prefer_wrapped + self._stratify_treatment = stratify_treatment # TODO: Add a __dir__ implementation? + def _stratified_indices(self, Y, T, *args, **kwargs): + assert 1 <= np.ndim(T) <= 2 + unique = np.unique(T, axis=0) + indices = [] + for el in unique: + ind, = np.where(np.all(T == el, axis=1) if np.ndim(T) == 2 else T == el) + indices.append(ind) + return indices + def fit(self, *args, **named_args): """ Fit the model. The full signature of this method is the same as that of the wrapped object's `fit` method. """ - n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0] - indices = np.random.choice(n_samples, size=(self._n_bootstrap_samples, n_samples), replace=True) + + if self._stratify_treatment: + index_chunks = self._stratified_indices(*args, **named_args) + else: + n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0] + index_chunks = [np.arange(n_samples)] # one chunk with all indices + + indices = [] + for chunk in index_chunks: + n_samples = len(chunk) + indices.append(chunk[np.random.choice(n_samples, + size=(self._n_bootstrap_samples, n_samples), + replace=True)]) + + indices = np.hstack(indices) def fit(x, *args, **kwargs): x.fit(*args, **kwargs) return x # Explicitly return x in case fit fails to return its target def convertArg(arg, inds): - return arg[inds] if arg is not None else None + return np.asarray(arg)[inds] if arg is not None else None + self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)( delayed(fit)(obj, *[convertArg(arg, inds) for arg in args], @@ -84,6 +114,11 @@ def __getattr__(self, name): Additionally, the suffix "_interval" is supported for getting an interval instead of a point estimate. """ + + # don't proxy special methods + if name.startswith('__'): + raise AttributeError(name) + def proxy(make_call, name, summary): def summarize_with(f): return summary(np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)( diff --git a/econml/inference.py b/econml/inference.py index 026aa196a..b18dc7c2c 100644 --- a/econml/inference.py +++ b/econml/inference.py @@ -52,7 +52,9 @@ def __init__(self, n_bootstrap_samples=100, n_jobs=-1): self._n_jobs = n_jobs def fit(self, estimator, *args, **kwargs): - est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False) + discrete_treatment = estimator._discrete_treatment if hasattr(estimator, '_discrete_treatment') else False + est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False, + stratify_treatment=discrete_treatment) est.fit(*args, **kwargs) self._est = est diff --git a/econml/tests/test_bootstrap.py b/econml/tests/test_bootstrap.py index cf165c83e..7e298554b 100644 --- a/econml/tests/test_bootstrap.py +++ b/econml/tests/test_bootstrap.py @@ -5,7 +5,7 @@ from econml.inference import BootstrapInference from econml.dml import LinearDMLCateEstimator from econml.two_stage_least_squares import NonparametricTwoStageLeastSquares -from sklearn.linear_model import LinearRegression +from sklearn.linear_model import LinearRegression, LogisticRegression from sklearn.preprocessing import PolynomialFeatures import numpy as np import unittest @@ -265,3 +265,18 @@ def test_internal_options(self): # TODO: test that the estimated effect is usually within the bounds # and that the true effect is also usually within the bounds + + def test_stratify(self): + """Test that we can properly stratify by treatment""" + T = [1, 0, 1, 2, 0, 2] + Y = [1, 2, 3, 4, 5, 6] + X = np.array([1, 1, 2, 2, 1, 2]).reshape(-1, 1) + est = LinearDMLCateEstimator(model_y=LinearRegression(), model_t=LogisticRegression(), discrete_treatment=True) + est.fit(Y, T, inference='bootstrap') + est.const_marginal_effect_interval() + + est.fit(Y, T, X=X, inference='bootstrap') + est.const_marginal_effect_interval(X) + + est.fit(Y, np.asarray(T).reshape(-1, 1), inference='bootstrap') # test stratifying 2D treatment + est.const_marginal_effect_interval() diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 4bc9ad53b..207d1167f 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -97,11 +97,7 @@ def make_random(is_discrete, d): model_t = LogisticRegression() if is_discrete else Lasso() - # TODO: add stratification to bootstrap so that we can use it - # even with discrete treatments - all_infs = [None, 'statsmodels'] - if not is_discrete: - all_infs.append(BootstrapInference(1)) + all_infs = [None, 'statsmodels', BootstrapInference(1)] for est, multi, infs in\ [(LinearDMLCateEstimator(model_y=Lasso(),