Skip to content

Commit

Permalink
Fix bootstrap inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Nov 4, 2020
1 parent 6fbc585 commit 9a153bc
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 156 deletions.
40 changes: 29 additions & 11 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,20 @@ def call(lower=5, upper=95):
return call

def get_inference():
# can't import from econml.inference at top level without creating mutual dependencies
from .inference import EmpiricalInferenceResults
# can't import from econml.inference at top level without creating cyclical dependencies
from .inference import EmpiricalInferenceResults, NormalInferenceResults

prefix = name[: - len("_inference")]
fname_transformer = None

if prefix in ['const_marginal_effect', 'marginal_effect', 'effect']:
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
if (hasattr(self._instances[0], 'cate_feature_names') and
callable(self._instances[0].cate_feature_names)):
def fname_transformer(x):
return self._instances[0].cate_feature_names(x)
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
Expand All @@ -204,6 +210,8 @@ def get_inference():
d_t = 1 if prefix == 'effect' else d_t
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1

can_call = callable(getattr(self._instances[0], prefix))

def get_inference_nonparametric(kind):
def get_dist(est, arr):
if kind == 'percentile':
Expand All @@ -212,21 +220,31 @@ def get_dist(est, arr):
return 2 * est - arr
else:
raise ValueError("Invalid kind, must be either 'percentile' or 'pivot'")
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type, fname_transformer=None))

def get_inference_parametric():
result = proxy(can_call, prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type,
fname_transformer=fname_transformer))
# Note that inference results are always methods even if the inference is for a property
# (e.g. coef__inference() is a method but coef_ is a property)
# Therefore we must insert a lambda if getting inference for a non-callable
return result if can_call else lambda: result

def get_inference_parametric(*args, **kwargs):
pred = getattr(self._wrapped, prefix)
if can_call:
pred = pred(*args, **kwargs)
stderr = getattr(self, prefix + '_std')
if can_call:
stderr = stderr(*args, **kwargs)
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_dist=None, fname_transformer=None)
fname_transformer=fname_transformer)

return {'normal': get_inference_parametric,
'percentile': lambda: get_inference_nonparametric('percentile'),
'pivot': lambda: get_inference_nonparametric('pivot')}[self._bootstrap_type]
return {'normal': get_inference_parametric if can_call else lambda: get_inference_parametric(),
'percentile': get_inference_nonparametric('percentile'),
'pivot': get_inference_nonparametric('pivot')}[self._bootstrap_type]

caught = None
m = None
Expand Down
58 changes: 53 additions & 5 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from warnings import warn
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot, Summary
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults

Expand Down Expand Up @@ -585,7 +585,6 @@ def intercept__inference(self):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.
Expand All @@ -608,7 +607,34 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
try:
coef_table = self.coef__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = [i + '\n' +
j for (i, j) in coef_table.columns] if d_t > 1 else coef_table.columns.tolist()
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if d_y > 1 else coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", str(e))
try:
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = [i + '\n' + j for (i, j)
in intercept_table.columns] if d_t > 1 else intercept_table.columns.tolist()
intercept_stubs = [i + ' | ' + j for (i, j)
in intercept_table.index] if d_y > 1 else intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", str(e))
if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
Expand Down Expand Up @@ -761,7 +787,6 @@ def intercept__inference(self, T):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Expand All @@ -784,7 +809,30 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
try:
coef_table = self.coef__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = coef_table.columns.tolist()
coef_stubs = coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", e)
try:
intercept_table = self.intercept__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = intercept_table.columns.tolist()
intercept_stubs = intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", e)

if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
Expand Down
76 changes: 13 additions & 63 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,37 +32,7 @@ def fit(self, estimator, *args, **kwargs):
pass


class _SummaryMixin:
    """Mixin that builds a printable summary from coefficient and intercept inference.

    The host class is expected to supply ``coef__inference()``,
    ``intercept__inference()`` and the ``d_t`` / ``d_y`` attributes read below.
    """

    def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
        """Collect the coefficient and intercept summary frames into one ``Summary``.

        ``alpha``, ``value`` and ``decimals`` are forwarded unchanged to
        ``summary_frame`` for both tables; ``feat_name`` is forwarded only for the
        coefficient table (the intercept table always receives ``None``).

        Returns the ``Summary`` when at least one table could be built; otherwise
        the method falls through and returns ``None``.
        """
        report = Summary()
        # (message prefix printed on failure, table title, accessor name, feat_name)
        sections = (
            ("Coefficient Results: ", 'Coefficient Results', 'coef__inference', feat_name),
            ("Intercept Results: ", 'Intercept Results', 'intercept__inference', None),
        )
        for failure_prefix, title, accessor, names in sections:
            # Best effort: a section that cannot be produced is reported on stdout
            # and skipped, so the remaining section can still be rendered.
            try:
                frame = getattr(self, accessor)().summary_frame(alpha=alpha,
                                                                value=value,
                                                                decimals=decimals,
                                                                feat_name=names)
                cells = frame.values
                if self.d_t > 1:
                    # column labels are pairs here; join them on a newline
                    headers = [top + '\n' + sub for (top, sub) in frame.columns]
                else:
                    headers = frame.columns.tolist()
                if self.d_y > 1:
                    # row labels are pairs here; join them with a pipe separator
                    stubs = [outer + ' | ' + inner for (outer, inner) in frame.index]
                else:
                    stubs = frame.index.tolist()
                report.add_table(cells, headers, stubs, title)
            except Exception as e:
                print(failure_prefix, str(e))
        if len(report.tables) > 0:
            return report


class BootstrapInference(_SummaryMixin, Inference):
class BootstrapInference(Inference):
"""
Inference instance to perform bootstrapping.
Expand Down Expand Up @@ -216,7 +186,7 @@ def effect_inference(self, X, *, T0, T1):
pred_stderr=e_stderr, inf_type='effect', fname_transformer=None)


class LinearModelFinalInference(_SummaryMixin, GenericModelFinalInference):
class LinearModelFinalInference(GenericModelFinalInference):
"""
Inference based on predict_interval of the model_final model. Assumes that estimator
class has a model_final method and that model is linear. Thus, the predict(cross_product(X, T1 - T0)) gives
Expand Down Expand Up @@ -468,32 +438,6 @@ def intercept__inference(self, T):
pred_stderr=self.fitted_models_final[ind].intercept_stderr_,
inf_type='intercept', fname_transformer=None)

def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
    """Build a ``Summary`` of the coefficient and intercept tables for treatment ``T``.

    ``T`` is passed to ``coef__inference`` and ``intercept__inference``; ``alpha``,
    ``value`` and ``decimals`` are forwarded to ``summary_frame`` for both tables,
    while ``feat_name`` is forwarded only for the coefficient table.

    Returns the ``Summary`` when at least one table was added; otherwise the
    method falls through and returns ``None``.
    """
    report = Summary()
    # (message prefix printed on failure, table title, accessor name, feat_name)
    sections = (
        ("Coefficient Results: ", 'Coefficient Results', 'coef__inference', feat_name),
        ("Intercept Results: ", 'Intercept Results', 'intercept__inference', None),
    )
    for failure_prefix, title, accessor, names in sections:
        # Best effort: a failing section is reported on stdout and skipped so the
        # other section can still be rendered.
        try:
            frame = getattr(self, accessor)(T).summary_frame(
                alpha=alpha, value=value, decimals=decimals, feat_name=names)
            report.add_table(frame.values,
                             frame.columns.tolist(),
                             frame.index.tolist(),
                             title)
        except Exception as e:
            print(failure_prefix, e)

    if len(report.tables) > 0:
        return report


class StatsModelsInferenceDiscrete(LinearModelFinalInferenceDiscrete):
"""
Expand Down Expand Up @@ -893,7 +837,7 @@ class EmpiricalInferenceResults(InferenceResults):

def __init__(self, d_t, d_y, pred, pred_dist, inf_type, fname_transformer):
self.pred_dist = pred_dist
super().__init__(d_y, d_t, pred, inf_type, fname_transformer)
super().__init__(d_t, d_y, pred, inf_type, fname_transformer)

@property
def stderr(self):
Expand Down Expand Up @@ -930,7 +874,7 @@ def conf_int(self, alpha=0.1):
"""
lower = alpha / 2
upper = 1 - alpha / 2
return np.percentile(self.pred_dist, lower, axis=0), np.percentile(self.pred_dist, upper, axis=0)
return np.percentile(self.pred_dist, lower * 100, axis=0), np.percentile(self.pred_dist, upper * 100, axis=0)

def pvalue(self, value=0):
"""
Expand All @@ -949,7 +893,8 @@ def pvalue(self, value=0):
the corresponding singleton dimensions in the output will be collapsed
(e.g. if both are vectors, then the output of this method will also be a vector)
"""
return min((self.pred_dist < value).sum(), (self.pred_dist > value).sum()) / self.pred_dist.shape[0]
return np.minimum((self.pred_dist <= value).sum(axis=0),
(self.pred_dist >= value).sum(axis=0)) / self.pred_dist.shape[0]

def _expand_outputs(self, n_rows):
assert shape(self.pred)[0] == shape(self.pred_dist)[1] == 1
Expand Down Expand Up @@ -1216,11 +1161,16 @@ def _mixture_ppf(self, alpha, mean, stderr, tol):
"""
Helper function to get the confidence interval of mixture gaussian distribution
"""
done = False
# if stderr is zero, ppf will return nans and the loop below would never terminate
# so bail out early; note that it might be possible to correct the algorithm for
# this scenario, but since scipy's cdf returns nan whenever scale is zero it won't
# be clean
if (np.any(stderr == 0)):
return np.full(shape(mean)[1:], np.nan)
mix_ppf = scipy.stats.norm.ppf(alpha, loc=mean, scale=stderr)
lower = np.min(mix_ppf, axis=0)
upper = np.max(mix_ppf, axis=0)
while not done:
while True:
cur = (lower + upper) / 2
cur_mean = np.mean(scipy.stats.norm.cdf(cur, loc=mean, scale=stderr), axis=0)
if np.isscalar(cur):
Expand Down
34 changes: 34 additions & 0 deletions econml/tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,37 @@ def test_stratify_orthoiv(self):
inference = BootstrapInference(n_bootstrap_samples=20)
est.fit(Y, T, Z, X=X, inference=inference)
est.const_marginal_effect_interval(X)

def test_all_kinds(self):
    """Intervals and inference results must agree for every bootstrap type."""

    def check_agreement(bounds, results):
        # The interval bounds must have the point estimate's shape and must
        # coincide with the confidence interval reported by the inference object.
        lower, upper = bounds[0], bounds[1]
        assert lower.shape == upper.shape == results.point_estimate.shape
        assert np.allclose(lower, results.conf_int()[0])
        assert np.allclose(upper, results.conf_int()[1])

    T = [1, 0, 1, 2, 0, 2] * 5
    Y = [1, 2, 3, 4, 5, 6] * 5
    X = np.array([1, 1, 2, 2, 1, 2] * 5).reshape(-1, 1)
    est = LinearDMLCateEstimator(n_splits=2)
    for kind in ['percentile', 'pivot', 'normal']:
        with self.subTest(kind=kind):
            inference = BootstrapInference(n_bootstrap_samples=5, bootstrap_type=kind)

            # first fit without features
            est.fit(Y, T, inference=inference)
            check_agreement(est.const_marginal_effect_interval(),
                            est.const_marginal_effect_inference())

            # refit with features and exercise every inference entry point
            est.fit(Y, T, X=X, inference=inference)
            check_agreement(est.const_marginal_effect_interval(X),
                            est.const_marginal_effect_inference(X))
            check_agreement(est.coef__interval(),
                            est.coef__inference())
            check_agreement(est.effect_interval(X),
                            est.effect_inference(X))
11 changes: 3 additions & 8 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def make_random(n, is_discrete, d):

model_t = LogisticRegression() if is_discrete else Lasso()

all_infs = [None, 'statsmodels', BootstrapInference(1)]
all_infs = [None, 'statsmodels', BootstrapInference(2)]

for est, multi, infs in\
[(DMLCateEstimator(model_y=Lasso(),
Expand Down Expand Up @@ -203,7 +203,7 @@ def make_random(n, is_discrete, d):
with pytest.raises(AttributeError):
self.assertEqual(shape(est.intercept__interval()),
(2,) + intercept_shape)
if inf in ['statsmodels', 'debiasedlasso', 'blb']:

const_marg_effect_inf = est.const_marginal_effect_inference(X)
T1 = np.full_like(T, 'b') if is_discrete else T
effect_inf = est.effect_inference(X, T0=T0, T1=T1)
Expand Down Expand Up @@ -269,12 +269,7 @@ def make_random(n, is_discrete, d):

# test coef__inference and intercept__inference
if not isinstance(est, KernelDMLCateEstimator):
if X is None:
cm = pytest.raises(AttributeError)
else:
cm = ExitStack()
# ExitStack can be used as a "do nothing" ContextManager
with cm:
if X is not None:
self.assertEqual(
shape(est.coef__inference().summary_frame()),
coef_summaryframe_shape)
Expand Down
Loading

0 comments on commit 9a153bc

Please sign in to comment.