Skip to content

Commit

Permalink
rebase gate on cate method
Browse files (browse the repository at this point in the history)
  • Loading branch information
SvenKlaassen committed Nov 29, 2023
1 parent d5788e7 commit 062e879
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 70 deletions.
24 changes: 7 additions & 17 deletions doubleml/double_ml_irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_

return res

def cate(self, basis):
def cate(self, basis, is_gate=False):
"""
Calculate conditional average treatment effects (CATE) for a given basis.
Expand All @@ -406,6 +406,9 @@ def cate(self, basis):
basis : :class:`pandas.DataFrame`
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
is_gate : bool
Indicates whether the basis is constructed for GATEs (dummy-basis).
Default is ``False``.
Returns
-------
Expand All @@ -424,8 +427,8 @@ def cate(self, basis):
# define the orthogonal signal
orth_signal = self.psi_elements['psi_b'].reshape(-1)
# fit the best linear predictor
model = DoubleMLBLP(orth_signal, basis=basis).fit()

model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
model.fit()
return model

def gate(self, groups):
Expand All @@ -444,15 +447,6 @@ def gate(self, groups):
model : :class:`doubleML.DoubleMLBLP`
Best linear Predictor model for Group Effects.
"""
valid_score = ['ATE']
if self.score not in valid_score:
raise ValueError('Invalid score ' + self.score + '. ' +
'Valid score ' + ' or '.join(valid_score) + '.')

if self.n_rep != 1:
raise NotImplementedError('Only implemented for one repetition. ' +
f'Number of repetitions is {str(self.n_rep)}.')

if not isinstance(groups, pd.DataFrame):
raise TypeError('Groups must be of DataFrame type. '
f'Groups of type {str(type(groups))} was passed.')
Expand All @@ -467,11 +461,7 @@ def gate(self, groups):
if any(groups.sum(0) <= 5):
warnings.warn('At least one group effect is estimated with less than 6 observations.')

# define the orthogonal signal
orth_signal = self.psi_elements['psi_b'].reshape(-1)
# fit the best linear predictor for GATE (different confint() method)
model = DoubleMLBLP(orth_signal, basis=groups, is_gate=True).fit()

model = self.cate(groups, is_gate=True)
return model

def policy_tree(self, features, depth=2, **tree_params):
Expand Down
25 changes: 7 additions & 18 deletions doubleml/double_ml_plr.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_

return res

def cate(self, basis):
def cate(self, basis, is_gate=False):
"""
Calculate conditional average treatment effects (CATE) for a given basis.
Expand All @@ -339,6 +339,9 @@ def cate(self, basis):
basis : :class:`pandas.DataFrame`
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
is_gate : bool
Indicates whether the basis is constructed for GATEs (dummy-basis).
Default is ``False``.
Returns
-------
Expand All @@ -358,7 +361,7 @@ def cate(self, basis):
model = DoubleMLBLP(
orth_signal=Y_tilde.reshape(-1),
basis=D_basis,
is_gate=False,
is_gate=is_gate,
)
model.fit()
return model
Expand All @@ -376,15 +379,9 @@ def gate(self, groups):
Returns
-------
model : :class:`doubleML.DoubleMLBLPGATE`
model : :class:`doubleML.DoubleMLBLP`
Best linear Predictor model for Group Effects.
"""
if self._dml_data.n_treat > 1:
raise NotImplementedError('Only implemented for single treatment. ' +
f'Number of treatments is {str(self._dml_data.n_treat)}.')
if self.n_rep != 1:
raise NotImplementedError('Only implemented for one repetition. ' +
f'Number of repetitions is {str(self.n_rep)}.')

if not isinstance(groups, pd.DataFrame):
raise TypeError('Groups must be of DataFrame type. '
Expand All @@ -398,16 +395,8 @@ def gate(self, groups):

if any(groups.sum(0) <= 5):
warnings.warn('At least one group effect is estimated with less than 6 observations.')
Y_tilde, D_tilde = self._partial_out()

D_basis = groups * D_tilde
# fit the best linear predictor for GATE (different confint() method)
model = DoubleMLBLP(
orth_signal=Y_tilde.reshape(-1),
basis=D_basis,
is_gate=True,
)
model.fit()
model = self.cate(groups, is_gate=True)
return model

def _partial_out(self):
Expand Down
40 changes: 5 additions & 35 deletions doubleml/tests/test_doubleml_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,11 @@ def test_doubleml_exception_gate():
msg = "Groups must be of DataFrame type. Groups of type <class 'int'> was passed."
with pytest.raises(TypeError, match=msg):
dml_irm_obj.gate(groups=2)
groups = pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3)))
msg = (r'Columns of groups must be of bool type or int type \(dummy coded\). '
'Alternatively, groups should only contain one column.')
with pytest.raises(TypeError, match=msg):
dml_irm_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3))))
dml_irm_obj.gate(groups=groups)

dml_irm_obj = DoubleMLIRM(dml_data_irm,
ml_g=Lasso(),
Expand All @@ -1280,10 +1281,10 @@ def test_doubleml_exception_gate():
n_folds=5,
score='ATTE')
dml_irm_obj.fit()

groups = pd.DataFrame(np.random.choice([True, False], size=dml_data_irm.n_obs))
msg = 'Invalid score ATTE. Valid score ATE.'
with pytest.raises(ValueError, match=msg):
dml_irm_obj.gate(groups=2)
dml_irm_obj.gate(groups=groups)

dml_irm_obj = DoubleMLIRM(dml_data_irm,
ml_g=Lasso(),
Expand All @@ -1296,7 +1297,7 @@ def test_doubleml_exception_gate():

msg = 'Only implemented for one repetition. Number of repetitions is 2.'
with pytest.raises(NotImplementedError, match=msg):
dml_irm_obj.gate(groups=2)
dml_irm_obj.gate(groups=groups)


@pytest.mark.ci
Expand Down Expand Up @@ -1360,17 +1361,6 @@ def test_doubleml_exception_plr_cate():

@pytest.mark.ci
def test_doubleml_exception_plr_gate():
dml_plr_obj = DoubleMLPLR(dml_data,
ml_l=Lasso(),
ml_m=Lasso(),
n_folds=2,
n_rep=2)
dml_plr_obj.fit()

msg = 'Only implemented for one repetition. Number of repetitions is 2.'
with pytest.raises(NotImplementedError, match=msg):
dml_plr_obj.gate(groups=2)

dml_plr_obj = DoubleMLPLR(dml_data,
ml_l=Lasso(),
ml_m=Lasso(),
Expand All @@ -1384,26 +1374,6 @@ def test_doubleml_exception_plr_gate():
'Alternatively, groups should only contain one column.')
with pytest.raises(TypeError, match=msg):
dml_plr_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data.n_obs, 3))))
dml_plr_obj = DoubleMLPLR(dml_data,
ml_l=Lasso(),
ml_m=Lasso(),
n_folds=2,
n_rep=1)
dml_plr_obj.fit(store_predictions=False)
msg = r'predictions are None. Call .fit\(store_predictions=True\) to store the predictions.'
with pytest.raises(ValueError, match=msg):
dml_plr_obj.gate(groups=pd.DataFrame(np.random.choice([True, False], (dml_data.n_obs, 2))))

dml_data_multiple_treat = DoubleMLData(dml_data.data, y_col="y", d_cols=['d', 'X1'])
dml_plr_obj_multiple = DoubleMLPLR(dml_data_multiple_treat,
ml_l=Lasso(),
ml_m=Lasso(),
n_folds=2)
dml_plr_obj_multiple.fit()
msg = 'Only implemented for single treatment. Number of treatments is 2.'
with pytest.raises(NotImplementedError, match=msg):
dml_plr_obj_multiple.gate(groups=pd.DataFrame(np.random.choice([True, False], (dml_data.n_obs, 2))))


@pytest.mark.ci
def test_double_ml_exception_evaluate_learner():
Expand Down

0 comments on commit 062e879

Please sign in to comment.