Skip to content

Commit

Permalink
Enable expanded inference interface for bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Jun 9, 2020
1 parent d130fca commit 23db5ad
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 37 deletions.
32 changes: 25 additions & 7 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,31 @@ def call(lower=5, upper=95):
return call

def get_inference():
raise NotImplementedError("The {0} method is not yet supported by bootstrap inference; "
"consider using a different inference method if available.".format(name))
# can't import from econml.inference at top level without creating mutual dependencies
from .inference import InferenceResults
# TODO: consider treating percentile bootstrap differently since we can work directly with
# the empirical distribution
prefix = name[: - len("_inference")]
if prefix in ['const_marginal_effect', 'effect']:
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
raise AttributeError("Unsupported inference: " + name)

def get_inference():
pred = getattr(self._wrapped, prefix)
stderr = getattr(self, prefix + '_std')
d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1
d_t = 1 if prefix == 'effect' else d_t
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1
return InferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_dist=None, fname_transformer=None)

return get_inference

caught = None
m = None
Expand All @@ -202,11 +225,6 @@ def get_inference():
return m()
except AttributeError as err:
caught = err
if name.endswith("_inference"):
try:
return get_inference()
except AttributeError as err:
caught = err
if self._compute_means:
return get_mean()

Expand Down
40 changes: 37 additions & 3 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,50 @@ def fit(self, estimator, *args, **kwargs):
bootstrap_type=self._bootstrap_type)
est.fit(*args, **kwargs)
self._est = est
self._d_t = estimator._d_t
self._d_y = estimator._d_y
self.d_t = self._d_t[0] if self._d_t else 1
self.d_y = self._d_y[0] if self._d_y else 1

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError()

m = getattr(self._est, name)
if name.endswith('_interval'): # convert alpha to lower/upper
def wrapped(*args, alpha=0.1, **kwargs):
return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs)
return wrapped
else:
return m

def wrapped(*args, alpha=0.1, **kwargs):
return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs)
return wrapped
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
smry = Summary()
try:
coef_table = self.coef__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = [i + '\n' +
j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist()
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", str(e))
try:
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = [i + '\n' + j for (i, j)
in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist()
intercept_stubs = [i + ' | ' + j for (i, j)
in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", str(e))
if len(smry.tables) > 0:
return smry


class GenericModelFinalInference(Inference):
Expand Down
12 changes: 0 additions & 12 deletions econml/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,3 @@ def test_stratify_orthoiv(self):
inference = BootstrapInference(n_bootstrap_samples=20)
est.fit(Y, T, Z, X=X, inference=inference)
est.const_marginal_effect_interval(X)

def test_inference_throws_helpful_error(self):
"""Test that we see that inference methods are not yet implemented"""
T = np.random.normal(size=(1000, 1))
Y = T + np.random.normal(size=(1000, 1))

opts = BootstrapInference(5, 2)

est = LinearDMLCateEstimator().fit(Y, T, inference=opts)

with self.assertRaises(NotImplementedError):
eff = est.const_marginal_effect_inference()
33 changes: 18 additions & 15 deletions econml/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sklearn.base import clone
from sklearn.preprocessing import PolynomialFeatures
from econml.dml import LinearDMLCateEstimator
from econml.inference import BootstrapInference


class TestInference(unittest.TestCase):
Expand All @@ -26,21 +27,23 @@ def setUpClass(cls):
def test_inference_results(self):
"""Tests the inference results summary."""
# Test inference results when `cate_feature_names` doesn not exist
cate_est = LinearDMLCateEstimator(
featurizer=PolynomialFeatures(degree=1,
include_bias=False)
)
wrapped_est = self._NoFeatNamesEst(cate_est)
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
inference='statsmodels'
)
summary_results = wrapped_est.summary()
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)])

for inference in [BootstrapInference(n_bootstrap_samples=5), 'statsmodels']:
cate_est = LinearDMLCateEstimator(
featurizer=PolynomialFeatures(degree=1,
include_bias=False)
)
wrapped_est = self._NoFeatNamesEst(cate_est)
wrapped_est.fit(
TestInference.Y,
TestInference.T,
TestInference.X,
TestInference.W,
inference=inference
)
summary_results = wrapped_est.summary()
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)])

class _NoFeatNamesEst:
def __init__(self, cate_est):
Expand Down

0 comments on commit 23db5ad

Please sign in to comment.