diff --git a/econml/bootstrap.py b/econml/bootstrap.py index 459c75687..309a2c213 100644 --- a/econml/bootstrap.py +++ b/econml/bootstrap.py @@ -176,8 +176,31 @@ def call(lower=5, upper=95): return call def get_inference(): - raise NotImplementedError("The {0} method is not yet supported by bootstrap inference; " - "consider using a different inference method if available.".format(name)) + # can't import from econml.inference at top level without creating mutual dependencies + from .inference import InferenceResults + # TODO: consider treating percentile bootstrap differently since we can work directly with + # the empirical distribution + prefix = name[: - len("_inference")] + if prefix in ['const_marginal_effect', 'effect']: + inf_type = 'effect' + elif prefix == 'coef_': + inf_type = 'coefficient' + elif prefix == 'intercept_': + inf_type = 'intercept' + else: + raise AttributeError("Unsupported inference: " + name) + + def get_inference(): + pred = getattr(self._wrapped, prefix) + stderr = getattr(self, prefix + '_std') + d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1 + d_t = 1 if prefix == 'effect' else d_t + d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1 + return InferenceResults(d_t=d_t, d_y=d_y, pred=pred, + pred_stderr=stderr, inf_type=inf_type, + pred_dist=None, fname_transformer=None) + + return get_inference caught = None m = None @@ -202,11 +225,6 @@ def get_inference(): return m() except AttributeError as err: caught = err - if name.endswith("_inference"): - try: - return get_inference() - except AttributeError as err: - caught = err if self._compute_means: return get_mean() diff --git a/econml/inference.py b/econml/inference.py index 55c42ea35..5d64f9436 100644 --- a/econml/inference.py +++ b/econml/inference.py @@ -62,16 +62,50 @@ def fit(self, estimator, *args, **kwargs): bootstrap_type=self._bootstrap_type) est.fit(*args, **kwargs) self._est = est + self._d_t = estimator._d_t + self._d_y = estimator._d_y + self.d_t = self._d_t[0] if self._d_t else 1 + self.d_y = self._d_y[0] if self._d_y else 1 def __getattr__(self, name): if name.startswith('__'): raise AttributeError() m = getattr(self._est, name) + if name.endswith('_interval'): # convert alpha to lower/upper + def wrapped(*args, alpha=0.1, **kwargs): + return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs) + return wrapped + else: + return m - def wrapped(*args, alpha=0.1, **kwargs): - return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs) - return wrapped + def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None): + smry = Summary() + try: + coef_table = self.coef__inference().summary_frame(alpha=alpha, + value=value, decimals=decimals, feat_name=feat_name) + coef_array = coef_table.values + coef_headers = [i + '\n' + + j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist() + coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist() + coef_title = 'Coefficient Results' + smry.add_table(coef_array, coef_headers, coef_stubs, coef_title) + except Exception as e: + print("Coefficient Results: ", str(e)) + try: + intercept_table = self.intercept__inference().summary_frame(alpha=alpha, + value=value, decimals=decimals, feat_name=None) + intercept_array = intercept_table.values + intercept_headers = [i + '\n' + j for (i, j) + in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist() + intercept_stubs = [i + ' | ' + j for (i, j) + in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist() + intercept_title = 'Intercept Results' + smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title) + except Exception as e: + print("Intercept Results: ", str(e)) + if len(smry.tables) > 0: + return smry class GenericModelFinalInference(Inference): diff --git a/econml/tests/test_bootstrap.py b/econml/tests/test_bootstrap.py index 8a5cf9923..91e667b15 100644 --- a/econml/tests/test_bootstrap.py +++ b/econml/tests/test_bootstrap.py @@ -295,15 +295,3 @@ def test_stratify_orthoiv(self): inference = BootstrapInference(n_bootstrap_samples=20) est.fit(Y, T, Z, X=X, inference=inference) est.const_marginal_effect_interval(X) - - def test_inference_throws_helpful_error(self): - """Test that we see that inference methods are not yet implemented""" - T = np.random.normal(size=(1000, 1)) - Y = T + np.random.normal(size=(1000, 1)) - - opts = BootstrapInference(5, 2) - - est = LinearDMLCateEstimator().fit(Y, T, inference=opts) - - with self.assertRaises(NotImplementedError): - eff = est.const_marginal_effect_inference() diff --git a/econml/tests/test_inference.py b/econml/tests/test_inference.py index 4426d6f75..3f67816c6 100644 --- a/econml/tests/test_inference.py +++ b/econml/tests/test_inference.py @@ -6,6 +6,7 @@ from sklearn.base import clone from sklearn.preprocessing import PolynomialFeatures from econml.dml import LinearDMLCateEstimator +from econml.inference import BootstrapInference class TestInference(unittest.TestCase): @@ -26,21 +27,23 @@ def setUpClass(cls): def test_inference_results(self): """Tests the inference results summary.""" # Test inference results when `cate_feature_names` doesn not exist - cate_est = LinearDMLCateEstimator( - featurizer=PolynomialFeatures(degree=1, - include_bias=False) - ) - wrapped_est = self._NoFeatNamesEst(cate_est) - wrapped_est.fit( - TestInference.Y, - TestInference.T, - TestInference.X, - TestInference.W, - inference='statsmodels' - ) - summary_results = wrapped_est.summary() - coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0] - np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)]) + + for inference in [BootstrapInference(n_bootstrap_samples=5), 'statsmodels']: + cate_est = LinearDMLCateEstimator( + featurizer=PolynomialFeatures(degree=1, + include_bias=False) + ) + wrapped_est = self._NoFeatNamesEst(cate_est) + wrapped_est.fit( + TestInference.Y, + TestInference.T, + TestInference.X, + TestInference.W, + inference=inference + ) + summary_results = wrapped_est.summary() + coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0] + np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)]) class _NoFeatNamesEst: def __init__(self, cate_est):