Skip to content

Commit

Permalink
Fix bootstrap inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Nov 6, 2020
1 parent 6fbc585 commit feb5bfc
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 204 deletions.
96 changes: 55 additions & 41 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,18 @@ class BootstrapEstimator:
Whether to pass calls through to the underlying collection and return the mean. Setting this
to ``False`` can avoid ambiguities if the wrapped object itself has method names with an `_interval` suffix.
prefer_wrapped: bool, default: False
In case a method ending in '_interval' exists on the wrapped object, whether
that should be preferred (meaning this wrapper will compute the mean of it).
This option only affects behavior if `compute_means` is set to ``True``.
bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
Bootstrap method used to compute results. 'percentile' will result in using the empirical CDF of
the replicated computations of the statistics. 'pivot' will also use the replicates but create a pivot
interval that also relies on the estimate over the entire dataset. 'normal' will instead compute an interval
assuming the replicates are normally distributed.
"""

def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False,
bootstrap_type='pivot'):
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, bootstrap_type='pivot'):
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._compute_means = compute_means
self._prefer_wrapped = prefer_wrapped
self._bootstrap_type = bootstrap_type
self._wrapped = wrapped

Expand Down Expand Up @@ -187,46 +180,74 @@ def call(lower=5, upper=95):
return call

def get_inference():
# can't import from econml.inference at top level without creating mutual dependencies
from .inference import EmpiricalInferenceResults
# can't import from econml.inference at top level without creating cyclical dependencies
from .inference import EmpiricalInferenceResults, NormalInferenceResults
from .cate_estimator import LinearModelFinalCateEstimatorDiscreteMixin

prefix = name[: - len("_inference")]

def fname_transformer(x):
return x

if prefix in ['const_marginal_effect', 'marginal_effect', 'effect']:
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
if (hasattr(self._instances[0], 'cate_feature_names') and
callable(self._instances[0].cate_feature_names)):
def fname_transformer(x):
return self._instances[0].cate_feature_names(x)
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
raise AttributeError("Unsupported inference: " + name)

d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1
d_t = 1 if prefix == 'effect' else d_t
if prefix == 'effect' or (isinstance(self._wrapped, LinearModelFinalCateEstimatorDiscreteMixin) and
(inf_type == 'coefficient' or inf_type == 'intercept')):
d_t = 1
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1

def get_inference_nonparametric(kind):
can_call = callable(getattr(self._instances[0], prefix))

kind = self._bootstrap_type
if kind == 'percentile' or kind == 'pivot':
def get_dist(est, arr):
if kind == 'percentile':
return arr
elif kind == 'pivot':
return 2 * est - arr
else:
raise ValueError("Invalid kind, must be either 'percentile' or 'pivot'")
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type, fname_transformer=None))

def get_inference_parametric():
pred = getattr(self._wrapped, prefix)
stderr = getattr(self, prefix + '_std')
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_dist=None, fname_transformer=None)

return {'normal': get_inference_parametric,
'percentile': lambda: get_inference_nonparametric('percentile'),
'pivot': lambda: get_inference_nonparametric('pivot')}[self._bootstrap_type]

def get_result():
return proxy(can_call, prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type,
fname_transformer=fname_transformer))

# Note that inference results are always methods even if the inference is for a property
# (e.g. coef__inference() is a method but coef_ is a property)
# Therefore we must insert a lambda if getting inference for a non-callable
return get_result() if can_call else get_result

else:
assert kind == 'normal'

def normal_inference(*args, **kwargs):
pred = getattr(self._wrapped, prefix)
if can_call:
pred = pred(*args, **kwargs)
stderr = getattr(self, prefix + '_std')
if can_call:
stderr = stderr(*args, **kwargs)
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
fname_transformer=fname_transformer)

# If inference is for a property, create a fresh lambda to avoid passing args through
return normal_inference if can_call else lambda: normal_inference()

caught = None
m = None
Expand All @@ -236,22 +257,15 @@ def get_inference_parametric():
m = get_std
elif name.endswith("_inference"):
m = get_inference
if self._compute_means and self._prefer_wrapped:

# try to get interval/std first if appropriate,
# since we don't prefer a wrapped method with this name
if m is not None:
try:
return get_mean()
return m()
except AttributeError as err:
caught = err
if m is not None:
m()
else:
# try to get interval/std first if appropriate,
# since we don't prefer a wrapped method with this name
if m is not None:
try:
return m()
except AttributeError as err:
caught = err
if self._compute_means:
return get_mean()
if self._compute_means:
return get_mean()

raise (caught if caught else AttributeError(name))
60 changes: 54 additions & 6 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from warnings import warn
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot, Summary
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults

Expand Down Expand Up @@ -585,7 +585,6 @@ def intercept__inference(self):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.
Expand All @@ -608,7 +607,34 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
try:
coef_table = self.coef__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = [i + '\n' +
j for (i, j) in coef_table.columns] if d_t > 1 else coef_table.columns.tolist()
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if d_y > 1 else coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", str(e))
try:
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = [i + '\n' + j for (i, j)
in intercept_table.columns] if d_t > 1 else intercept_table.columns.tolist()
intercept_stubs = [i + ' | ' + j for (i, j)
in intercept_table.index] if d_y > 1 else intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", str(e))
if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
Expand Down Expand Up @@ -684,7 +710,7 @@ def intercept_(self, T):
_, T = self._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
assert ind >= 0, "No model was fitted for the control"
return self.fitted_models_final[ind].intercept_
return self.fitted_models_final[ind].intercept_.reshape(self._d_y)

@BaseCateEstimator._defer_to_inference
def coef__interval(self, T, *, alpha=0.1):
Expand Down Expand Up @@ -761,7 +787,6 @@ def intercept__inference(self, T):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Expand All @@ -784,7 +809,30 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
try:
coef_table = self.coef__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = coef_table.columns.tolist()
coef_stubs = coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", e)
try:
intercept_table = self.intercept__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = intercept_table.columns.tolist()
intercept_stubs = intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", e)

if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
Expand Down
Loading

0 comments on commit feb5bfc

Please sign in to comment.