diff --git a/doubleml/double_ml_irm.py b/doubleml/double_ml_irm.py index 6ca09dd6..6a33c42e 100644 --- a/doubleml/double_ml_irm.py +++ b/doubleml/double_ml_irm.py @@ -397,7 +397,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_ return res - def cate(self, basis): + def cate(self, basis, is_gate=False): """ Calculate conditional average treatment effects (CATE) for a given basis. @@ -406,6 +406,9 @@ def cate(self, basis): basis : :class:`pandas.DataFrame` The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of predictors. + is_gate : bool + Indicates whether the basis is constructed for GATEs (dummy-basis). + Default is ``False``. Returns ------- @@ -424,8 +427,8 @@ def cate(self, basis): # define the orthogonal signal orth_signal = self.psi_elements['psi_b'].reshape(-1) # fit the best linear predictor - model = DoubleMLBLP(orth_signal, basis=basis).fit() - + model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate) + model.fit() return model def gate(self, groups): @@ -444,15 +447,6 @@ def gate(self, groups): model : :class:`doubleML.DoubleMLBLP` Best linear Predictor model for Group Effects. """ - valid_score = ['ATE'] - if self.score not in valid_score: - raise ValueError('Invalid score ' + self.score + '. ' + - 'Valid score ' + ' or '.join(valid_score) + '.') - - if self.n_rep != 1: - raise NotImplementedError('Only implemented for one repetition. ' + - f'Number of repetitions is {str(self.n_rep)}.') - if not isinstance(groups, pd.DataFrame): raise TypeError('Groups must be of DataFrame type. ' f'Groups of type {str(type(groups))} was passed.') @@ -467,11 +461,7 @@ def gate(self, groups): if any(groups.sum(0) <= 5): warnings.warn('At least one group effect is estimated with less than 6 observations.') - # define the orthogonal signal - orth_signal = self.psi_elements['psi_b'].reshape(-1) - # fit the best linear predictor for GATE (different confint() method) - model = DoubleMLBLP(orth_signal, basis=groups, is_gate=True).fit() - + model = self.cate(groups, is_gate=True) return model def policy_tree(self, features, depth=2, **tree_params): diff --git a/doubleml/double_ml_plr.py b/doubleml/double_ml_plr.py index e82fbaf2..acd98f26 100644 --- a/doubleml/double_ml_plr.py +++ b/doubleml/double_ml_plr.py @@ -330,7 +330,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_ return res - def cate(self, basis): + def cate(self, basis, is_gate=False): """ Calculate conditional average treatment effects (CATE) for a given basis. @@ -339,6 +339,9 @@ def cate(self, basis): basis : :class:`pandas.DataFrame` The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of predictors. + is_gate : bool + Indicates whether the basis is constructed for GATEs (dummy-basis). + Default is ``False``. Returns ------- @@ -358,7 +361,7 @@ def cate(self, basis): model = DoubleMLBLP( orth_signal=Y_tilde.reshape(-1), basis=D_basis, - is_gate=False, + is_gate=is_gate, ) model.fit() return model @@ -376,15 +379,9 @@ def gate(self, groups): Returns ------- - model : :class:`doubleML.DoubleMLBLPGATE` + model : :class:`doubleML.DoubleMLBLP` Best linear Predictor model for Group Effects. """ - if self._dml_data.n_treat > 1: - raise NotImplementedError('Only implemented for single treatment. ' + - f'Number of treatments is {str(self._dml_data.n_treat)}.') - if self.n_rep != 1: - raise NotImplementedError('Only implemented for one repetition. ' + - f'Number of repetitions is {str(self.n_rep)}.') if not isinstance(groups, pd.DataFrame): raise TypeError('Groups must be of DataFrame type. ' @@ -398,16 +395,8 @@ def gate(self, groups): if any(groups.sum(0) <= 5): warnings.warn('At least one group effect is estimated with less than 6 observations.') - Y_tilde, D_tilde = self._partial_out() - D_basis = groups * D_tilde - # fit the best linear predictor for GATE (different confint() method) - model = DoubleMLBLP( - orth_signal=Y_tilde.reshape(-1), - basis=D_basis, - is_gate=True, - ) - model.fit() + model = self.cate(groups, is_gate=True) return model def _partial_out(self): diff --git a/doubleml/tests/test_doubleml_exceptions.py b/doubleml/tests/test_doubleml_exceptions.py index c9dd9e21..b4b6125d 100644 --- a/doubleml/tests/test_doubleml_exceptions.py +++ b/doubleml/tests/test_doubleml_exceptions.py @@ -1268,10 +1268,11 @@ def test_doubleml_exception_gate(): msg = "Groups must be of DataFrame type. Groups of type was passed." with pytest.raises(TypeError, match=msg): dml_irm_obj.gate(groups=2) + groups = pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3))) msg = (r'Columns of groups must be of bool type or int type \(dummy coded\). ' 'Alternatively, groups should only contain one column.') with pytest.raises(TypeError, match=msg): - dml_irm_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3)))) + dml_irm_obj.gate(groups=groups) dml_irm_obj = DoubleMLIRM(dml_data_irm, ml_g=Lasso(), @@ -1280,10 +1281,10 @@ def test_doubleml_exception_gate(): n_folds=5, score='ATTE') dml_irm_obj.fit() - + groups = pd.DataFrame(np.random.choice([True, False], size=dml_data_irm.n_obs)) msg = 'Invalid score ATTE. Valid score ATE.' with pytest.raises(ValueError, match=msg): - dml_irm_obj.gate(groups=2) + dml_irm_obj.gate(groups=groups) dml_irm_obj = DoubleMLIRM(dml_data_irm, ml_g=Lasso(), @@ -1296,7 +1297,7 @@ def test_doubleml_exception_gate(): msg = 'Only implemented for one repetition. Number of repetitions is 2.' with pytest.raises(NotImplementedError, match=msg): - dml_irm_obj.gate(groups=2) + dml_irm_obj.gate(groups=groups) @pytest.mark.ci @@ -1360,17 +1361,6 @@ def test_doubleml_exception_plr_cate(): @pytest.mark.ci def test_doubleml_exception_plr_gate(): - dml_plr_obj = DoubleMLPLR(dml_data, - ml_l=Lasso(), - ml_m=Lasso(), - n_folds=2, - n_rep=2) - dml_plr_obj.fit() - - msg = 'Only implemented for one repetition. Number of repetitions is 2.' - with pytest.raises(NotImplementedError, match=msg): - dml_plr_obj.gate(groups=2) - dml_plr_obj = DoubleMLPLR(dml_data, ml_l=Lasso(), ml_m=Lasso(), @@ -1384,26 +1374,6 @@ def test_doubleml_exception_plr_gate(): 'Alternatively, groups should only contain one column.') with pytest.raises(TypeError, match=msg): dml_plr_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data.n_obs, 3)))) - dml_plr_obj = DoubleMLPLR(dml_data, - ml_l=Lasso(), - ml_m=Lasso(), - n_folds=2, - n_rep=1) - dml_plr_obj.fit(store_predictions=False) - msg = r'predictions are None. Call .fit\(store_predictions=True\) to store the predictions.' - with pytest.raises(ValueError, match=msg): - dml_plr_obj.gate(groups=pd.DataFrame(np.random.choice([True, False], (dml_data.n_obs, 2)))) - - dml_data_multiple_treat = DoubleMLData(dml_data.data, y_col="y", d_cols=['d', 'X1']) - dml_plr_obj_multiple = DoubleMLPLR(dml_data_multiple_treat, - ml_l=Lasso(), - ml_m=Lasso(), - n_folds=2) - dml_plr_obj_multiple.fit() - msg = 'Only implemented for single treatment. Number of treatments is 2.' - with pytest.raises(NotImplementedError, match=msg): - dml_plr_obj_multiple.gate(groups=pd.DataFrame(np.random.choice([True, False], (dml_data.n_obs, 2)))) - @pytest.mark.ci def test_double_ml_exception_evaluate_learner():