diff --git a/econml/inference.py b/econml/inference.py index 2302d88a9..63f856877 100644 --- a/econml/inference.py +++ b/econml/inference.py @@ -12,6 +12,7 @@ from .utilities import (cross_product, broadcast_unit_treatments, reshape_treatmentwise_effects, ndim, shape, inverse_onehot, parse_final_model_params, _safe_norm_ppf, Summary, StatsModelsLinearRegression) +from warnings import warn """Options for performing inference in estimators.""" @@ -87,7 +88,7 @@ class GenericModelFinalInference(Inference): Inference based on predict_interval of the model_final model. Assumes that estimator class has a model_final method, whose predict(cross_product(X, [0, ..., 1, ..., 0])) gives the const_marginal_effect of the treamtnent at the column with value 1 and which also supports - predict_interval(X). + prediction_stderr(X). """ def prefit(self, estimator, *args, **kwargs): @@ -104,14 +105,7 @@ def fit(self, estimator, *args, **kwargs): self.d_y = self._d_y[0] if self._d_y else 1 def const_marginal_effect_interval(self, X, *, alpha=0.1): - if X is None: - X = np.ones((1, 1)) - elif self.featurizer is not None: - X = self.featurizer.transform(X) - X, T = broadcast_unit_treatments(X, self.d_t) - preds = self._predict_interval(cross_product(X, T), alpha=alpha) - return tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y) - for pred in preds) + return self.const_marginal_effect_inference(X).conf_int(alpha=alpha) def const_marginal_effect_inference(self, X): if X is None: @@ -127,9 +121,6 @@ def const_marginal_effect_inference(self, X): return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred, pred_stderr=pred_stderr, inf_type='effect') - def _predict_interval(self, X, alpha): - return self.model_final.predict_interval(X, alpha=alpha) - def _predict(self, X): return self.model_final.predict(X) @@ -168,8 +159,6 @@ def effect_interval(self, X, *, T0, T1, alpha=0.1): def effect_inference(self, X, *, T0, T1): # We can write effect inference as a function of const_marginal_effect_inference for a single treatment X, T0, T1 = self._est._expand_treatments(X, T0, T1) - if (T0 == T1).all(): - raise AttributeError("T0 is the same as T1, please input different treatment!") cme_pred = self.const_marginal_effect_inference(X).point_estimate cme_stderr = self.const_marginal_effect_inference(X).stderr dT = T1 - T0 @@ -201,21 +190,27 @@ def fit(self, estimator, *args, **kwargs): self.bias_part_of_coef = estimator.bias_part_of_coef self.fit_cate_intercept = estimator.fit_cate_intercept + # replacing _predict of super to fend against misuse, when the user has used a final linear model with + # an intercept even when bias is part of coef. + def _predict(self, X): + intercept = 0 + if self.bias_part_of_coef: + intercept = self.model_final.predict(np.zeros((1, X.shape[1]))) + if np.any(np.abs(intercept) > 0): + warn("The final model has a nonzero intercept for at least one outcome; " + "it will be subtracted, but consider fitting a model without an intercept if possible. " + "Standard errors will also be slightly incorrect if the final model used fits an intercept " + "as they will be including the variance of the intercept parameter estimate.", + UserWarning) + return self.model_final.predict(X) - intercept + def effect_interval(self, X, *, T0, T1, alpha=0.1): - # We can write effect interval as a function of predict_interval of the final method for linear models - X, T0, T1 = self._est._expand_treatments(X, T0, T1) - if X is None: - X = np.ones((T0.shape[0], 1)) - elif self.featurizer is not None: - X = self.featurizer.transform(X) - return self._predict_interval(cross_product(X, T1 - T0), alpha=alpha) + return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha) def effect_inference(self, X, *, T0, T1): # We can write effect inference as a function of prediction and prediction standard error of # the final method for linear models X, T0, T1 = self._est._expand_treatments(X, T0, T1) - if (T0 == T1).all(): - raise AttributeError("T0 is the same as T1, please input different treatment!") if X is None: X = np.ones((T0.shape[0], 1)) elif self.featurizer is not None: @@ -373,10 +368,9 @@ def effect_interval(self, X, *, T0, T1, alpha=0.1): def effect_inference(self, X, *, T0, T1): X, T0, T1 = self._est._expand_treatments(X, T0, T1) - if (T0 == T1).all(): - raise AttributeError("T0 is the same with T1, please input different treatment!") - if np.any(np.any(T0 > 0, axis=1)): - raise AttributeError("Can only calculate inference of effects with respect to baseline treatment!") + if np.any(np.any(T0 > 0, axis=1)) or np.any(np.all(T1 == 0, axis=1)): + raise AttributeError("Can only calculate inference of effects between a non-baseline treatment " + "and the baseline treatment!") ind = inverse_onehot(T1) pred = self.const_marginal_effect_inference(X).point_estimate pred = np.concatenate([np.zeros(pred.shape[0:-1] + (1,)), pred], -1) diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 1e32a76b6..381b81902 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -790,21 +790,35 @@ def test_can_use_statsmodel_inference(self): def test_ignores_final_intercept(self): """Test that final model intercepts are ignored (with a warning)""" class InterceptModel: - def fit(Y, X): + def fit(self, Y, X): pass - def predict(X): + def predict(self, X): return X + 1 + def prediction_stderr(self, X): + return np.zeros(X.shape[0]) + # (incorrectly) use a final model with an intercept dml = DML(LinearRegression(), LinearRegression(), - model_final=InterceptModel) + model_final=InterceptModel()) # Because final model is fixed, actual values of T and Y don't matter t = np.random.normal(size=100) y = np.random.normal(size=100) with self.assertWarns(Warning): # we should warn whenever there's an intercept dml.fit(y, t) assert dml.const_marginal_effect() == 1 # coefficient on X in InterceptModel is 1 + assert dml.const_marginal_effect_inference().point_estimate == 1 + assert dml.const_marginal_effect_inference().conf_int() == (1, 1) + assert dml.const_marginal_effect_interval() == (1, 1) + assert dml.effect() == 1 + assert dml.effect_inference().point_estimate == 1 + assert dml.effect_inference().conf_int() == (1, 1) + assert dml.effect_interval() == (1, 1) + assert dml.marginal_effect(1) == 1 # coefficient on X in InterceptModel is 1 + assert dml.marginal_effect_inference(1).point_estimate == 1 + assert dml.marginal_effect_inference(1).conf_int() == (1, 1) + assert dml.marginal_effect_interval(1) == (1, 1) def test_sparse(self): for _ in range(5):