Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearModelFinalInference was not subtracting intercept when model_final has one #318

Merged
merged 2 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .utilities import (cross_product, broadcast_unit_treatments, reshape_treatmentwise_effects,
ndim, shape, inverse_onehot, parse_final_model_params, _safe_norm_ppf, Summary,
StatsModelsLinearRegression)
from warnings import warn


"""Options for performing inference in estimators."""
Expand Down Expand Up @@ -87,7 +88,7 @@ class GenericModelFinalInference(Inference):
Inference based on predict_interval of the model_final model. Assumes that estimator
class has a model_final method, whose predict(cross_product(X, [0, ..., 1, ..., 0])) gives
the const_marginal_effect of the treamtnent at the column with value 1 and which also supports
predict_interval(X).
prediction_stderr(X).
"""

def prefit(self, estimator, *args, **kwargs):
Expand All @@ -104,14 +105,7 @@ def fit(self, estimator, *args, **kwargs):
self.d_y = self._d_y[0] if self._d_y else 1

def const_marginal_effect_interval(self, X, *, alpha=0.1):
if X is None:
X = np.ones((1, 1))
elif self.featurizer is not None:
X = self.featurizer.transform(X)
X, T = broadcast_unit_treatments(X, self.d_t)
preds = self._predict_interval(cross_product(X, T), alpha=alpha)
return tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y)
for pred in preds)
return self.const_marginal_effect_inference(X).conf_int(alpha=alpha)

def const_marginal_effect_inference(self, X):
if X is None:
Expand All @@ -127,9 +121,6 @@ def const_marginal_effect_inference(self, X):
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
pred_stderr=pred_stderr, inf_type='effect')

def _predict_interval(self, X, alpha):
return self.model_final.predict_interval(X, alpha=alpha)

def _predict(self, X):
return self.model_final.predict(X)

Expand Down Expand Up @@ -201,14 +192,22 @@ def fit(self, estimator, *args, **kwargs):
self.bias_part_of_coef = estimator.bias_part_of_coef
self.fit_cate_intercept = estimator.fit_cate_intercept

# replacing _predict of super to fend against misuse, when the user has used a final linear model with
# an intercept even when bias is part of coef.
def _predict(self, X):
intercept = 0
if self.bias_part_of_coef:
intercept = self.model_final.predict(np.zeros((1, X.shape[1])))
if np.any(np.abs(intercept) > 0):
warn("The final model has a nonzero intercept for at least one outcome; "
"it will be subtracted, but consider fitting a model without an intercept if possible. "
"Standard errors will also be slightly incorrect if the final model used fits an intercept "
"as they will be including the variance of the intercept parameter estimate.",
UserWarning)
return self.model_final.predict(X) - intercept

def effect_interval(self, X, *, T0, T1, alpha=0.1):
# We can write effect interval as a function of predict_interval of the final method for linear models
X, T0, T1 = self._est._expand_treatments(X, T0, T1)
if X is None:
X = np.ones((T0.shape[0], 1))
elif self.featurizer is not None:
X = self.featurizer.transform(X)
return self._predict_interval(cross_product(X, T1 - T0), alpha=alpha)
return self.effect_inference(X, T0=T0, T1=T1).conf_int(alpha=alpha)

def effect_inference(self, X, *, T0, T1):
# We can write effect inference as a function of prediction and prediction standard error of
Expand Down
20 changes: 17 additions & 3 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,21 +790,35 @@ def test_can_use_statsmodel_inference(self):
def test_ignores_final_intercept(self):
"""Test that final model intercepts are ignored (with a warning)"""
class InterceptModel:
def fit(Y, X):
def fit(self, Y, X):
pass

def predict(X):
def predict(self, X):
return X + 1

def prediction_stderr(self, X):
return np.zeros(X.shape[0])

# (incorrectly) use a final model with an intercept
dml = DML(LinearRegression(), LinearRegression(),
model_final=InterceptModel)
model_final=InterceptModel())
# Because final model is fixed, actual values of T and Y don't matter
t = np.random.normal(size=100)
y = np.random.normal(size=100)
with self.assertWarns(Warning): # we should warn whenever there's an intercept
dml.fit(y, t)
assert dml.const_marginal_effect() == 1 # coefficient on X in InterceptModel is 1
assert dml.const_marginal_effect_inference().point_estimate == 1
assert dml.const_marginal_effect_inference().conf_int() == (1, 1)
assert dml.const_marginal_effect_interval() == (1, 1)
assert dml.effect() == 1
assert dml.effect_inference().point_estimate == 1
assert dml.effect_inference().conf_int() == (1, 1)
assert dml.effect_interval() == (1, 1)
assert dml.marginal_effect(1) == 1 # coefficient on X in InterceptModel is 1
assert dml.marginal_effect_inference(1).point_estimate == 1
assert dml.marginal_effect_inference(1).conf_int() == (1, 1)
assert dml.marginal_effect_interval(1) == (1, 1)

def test_sparse(self):
for _ in range(5):
Expand Down