Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bootstrap inference #299

Merged
merged 2 commits into from
Nov 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
dependsOn: 'EvalChanges'
condition: eq(dependencies.EvalChanges.outputs['output.buildNbs'], 'True')
variables:
python.version: '3.6'
python.version: '3.8'
pool:
vmImage: 'ubuntu-16.04'
steps:
Expand Down Expand Up @@ -142,7 +142,7 @@ jobs:
dependsOn: 'EvalChanges'
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
variables:
python.version: '3.6'
python.version: '3.8'
pool:
vmImage: 'macOS-10.15'
steps:
Expand All @@ -158,15 +158,6 @@ jobs:
condition: eq(dependencies.EvalChanges.outputs['output.testCode'], 'True')
strategy:
matrix:
Linux, Python 3.5:
imageName: 'ubuntu-16.04'
python.version: '3.5'
macOS, Python 3.5:
imageName: 'macOS-10.15'
python.version: '3.5'
Windows, Python 3.5:
imageName: 'vs2017-win2016'
python.version: '3.5'
Linux, Python 3.6:
imageName: 'ubuntu-16.04'
python.version: '3.6'
Expand Down
96 changes: 55 additions & 41 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,18 @@ class BootstrapEstimator:
Whether to pass calls through to the underlying collection and return the mean. Setting this
to ``False`` can avoid ambiguities if the wrapped object itself has method names with an `_interval` suffix.

prefer_wrapped: bool, default: False
In case a method ending in '_interval' exists on the wrapped object, whether
that should be preferred (meaning this wrapper will compute the mean of it).
This option only affects behavior if `compute_means` is set to ``True``.

bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
Bootstrap method used to compute results. 'percentile' will result in using the empiracal CDF of
the replicated computations of the statistics. 'pivot' will also use the replicates but create a pivot
interval that also relies on the estimate over the entire dataset. 'normal' will instead compute an interval
assuming the replicates are normally distributed.
"""

def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False,
bootstrap_type='pivot'):
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, bootstrap_type='pivot'):
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
self._n_bootstrap_samples = n_bootstrap_samples
self._n_jobs = n_jobs
self._compute_means = compute_means
self._prefer_wrapped = prefer_wrapped
self._bootstrap_type = bootstrap_type
self._wrapped = wrapped

Expand Down Expand Up @@ -187,46 +180,74 @@ def call(lower=5, upper=95):
return call

def get_inference():
# can't import from econml.inference at top level without creating mutual dependencies
from .inference import EmpiricalInferenceResults
# can't import from econml.inference at top level without creating cyclical dependencies
from .inference import EmpiricalInferenceResults, NormalInferenceResults
from .cate_estimator import LinearModelFinalCateEstimatorDiscreteMixin

prefix = name[: - len("_inference")]

def fname_transformer(x):
return x

if prefix in ['const_marginal_effect', 'marginal_effect', 'effect']:
inf_type = 'effect'
elif prefix == 'coef_':
inf_type = 'coefficient'
if (hasattr(self._instances[0], 'cate_feature_names') and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we have no test that calls coef__inference or summary() but such that the estimator has no cate_feature_names method!

We either need to add such a test and make sure the behavior is as we want:

  1. I see this here being problematic because we would ideally want when I call summary(feat_name=..), even if the method has no cate_feature_names method, that the feat_name list that I give, will appear in the summary table.

Or maybe even better: we make cate_feature_names a mandatory method for any estimator that inherits from LinearModelFinalInference class and the LinearModelFinalInferenceDiscrete class. In this case, summary should perform ok as we will always have cate_feature_names implemented, whenever summary is available.

callable(self._instances[0].cate_feature_names)):
def fname_transformer(x):
return self._instances[0].cate_feature_names(x)
elif prefix == 'intercept_':
inf_type = 'intercept'
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add a test that checks that the attribute error is raised in this loop and nothing else happens.

raise AttributeError("Unsupported inference: " + name)

d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1
d_t = 1 if prefix == 'effect' else d_t
if prefix == 'effect' or (isinstance(self._wrapped, LinearModelFinalCateEstimatorDiscreteMixin) and
(inf_type == 'coefficient' or inf_type == 'intercept')):
d_t = 1
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1

def get_inference_nonparametric(kind):
can_call = callable(getattr(self._instances[0], prefix))

kind = self._bootstrap_type
if kind == 'percentile' or kind == 'pivot':
def get_dist(est, arr):
if kind == 'percentile':
return arr
elif kind == 'pivot':
return 2 * est - arr
else:
raise ValueError("Invalid kind, must be either 'percentile' or 'pivot'")
kbattocchi marked this conversation as resolved.
Show resolved Hide resolved
return proxy(callable(getattr(self._instances[0], prefix)), prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type, fname_transformer=None))

def get_inference_parametric():
pred = getattr(self._wrapped, prefix)
stderr = getattr(self, prefix + '_std')
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
pred_dist=None, fname_transformer=None)

return {'normal': get_inference_parametric,
'percentile': lambda: get_inference_nonparametric('percentile'),
'pivot': lambda: get_inference_nonparametric('pivot')}[self._bootstrap_type]

def get_result():
return proxy(can_call, prefix,
lambda arr, est: EmpiricalInferenceResults(d_t=d_t, d_y=d_y,
pred=est, pred_dist=get_dist(est, arr),
inf_type=inf_type,
fname_transformer=fname_transformer))

# Note that inference results are always methods even if the inference is for a property
# (e.g. coef__inference() is a method but coef_ is a property)
# Therefore we must insert a lambda if getting inference for a non-callable
return get_result() if can_call else get_result

else:
assert kind == 'normal'

def normal_inference(*args, **kwargs):
pred = getattr(self._wrapped, prefix)
if can_call:
pred = pred(*args, **kwargs)
stderr = getattr(self, prefix + '_std')
if can_call:
stderr = stderr(*args, **kwargs)
return NormalInferenceResults(d_t=d_t, d_y=d_y, pred=pred,
pred_stderr=stderr, inf_type=inf_type,
fname_transformer=fname_transformer)

# If inference is for a property, create a fresh lambda to avoid passing args through
return normal_inference if can_call else lambda: normal_inference()

caught = None
m = None
kbattocchi marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -236,22 +257,15 @@ def get_inference_parametric():
m = get_std
elif name.endswith("_inference"):
m = get_inference
if self._compute_means and self._prefer_wrapped:

# try to get interval/std first if appropriate,
# since we don't prefer a wrapped method with this name
if m is not None:
try:
return get_mean()
return m()
except AttributeError as err:
caught = err
if m is not None:
m()
else:
# try to get interval/std first if appropriate,
# since we don't prefer a wrapped method with this name
if m is not None:
try:
return m()
except AttributeError as err:
caught = err
if self._compute_means:
return get_mean()
if self._compute_means:
return get_mean()

raise (caught if caught else AttributeError(name))
60 changes: 54 additions & 6 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from warnings import warn
from .inference import BootstrapInference
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot
from .utilities import tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot, Summary
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults

Expand Down Expand Up @@ -585,7 +585,6 @@ def intercept__inference(self):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.
Expand All @@ -608,7 +607,34 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
d_t = self._d_t[0] if self._d_t else 1
d_y = self._d_y[0] if self._d_y else 1
try:
coef_table = self.coef__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = [i + '\n' +
j for (i, j) in coef_table.columns] if d_t > 1 else coef_table.columns.tolist()
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if d_y > 1 else coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", str(e))
try:
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = [i + '\n' + j for (i, j)
in intercept_table.columns] if d_t > 1 else intercept_table.columns.tolist()
intercept_stubs = [i + ' | ' + j for (i, j)
in intercept_table.index] if d_y > 1 else intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
print("Intercept Results: ", str(e))
if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
Expand Down Expand Up @@ -684,7 +710,7 @@ def intercept_(self, T):
_, T = self._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
assert ind >= 0, "No model was fitted for the control"
return self.fitted_models_final[ind].intercept_
return self.fitted_models_final[ind].intercept_.reshape(self._d_y)

@BaseCateEstimator._defer_to_inference
def coef__interval(self, T, *, alpha=0.1):
Expand Down Expand Up @@ -761,7 +787,6 @@ def intercept__inference(self, T):
"""
pass

@BaseCateEstimator._defer_to_inference
def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Expand All @@ -784,7 +809,30 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None):
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
pass
smry = Summary()
try:
coef_table = self.coef__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=feat_name)
coef_array = coef_table.values
coef_headers = coef_table.columns.tolist()
coef_stubs = coef_table.index.tolist()
coef_title = 'Coefficient Results'
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
except Exception as e:
print("Coefficient Results: ", e)
try:
intercept_table = self.intercept__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=None)
intercept_array = intercept_table.values
intercept_headers = intercept_table.columns.tolist()
intercept_stubs = intercept_table.index.tolist()
intercept_title = 'Intercept Results'
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
except Exception as e:
kbattocchi marked this conversation as resolved.
Show resolved Hide resolved
print("Intercept Results: ", e)

if len(smry.tables) > 0:
return smry


class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscreteMixin):
Expand Down
Loading