Skip to content

Commit

Permalink
Merge branch 'j-external-predictions' of https://github.com/DoubleML/…
Browse files Browse the repository at this point in the history
…doubleml-for-py into j-external-predictions
  • Loading branch information
JanTeichertKluge committed Dec 8, 2023
2 parents d3cdedf + 824aabc commit a6be4d0
Show file tree
Hide file tree
Showing 20 changed files with 2,494 additions and 1,966 deletions.
1,071 changes: 608 additions & 463 deletions doubleml/double_ml.py

Large diffs are not rendered by default.

330 changes: 174 additions & 156 deletions doubleml/double_ml_did.py

Large diffs are not rendered by default.

572 changes: 326 additions & 246 deletions doubleml/double_ml_did_cs.py

Large diffs are not rendered by default.

476 changes: 268 additions & 208 deletions doubleml/double_ml_iivm.py

Large diffs are not rendered by default.

380 changes: 197 additions & 183 deletions doubleml/double_ml_irm.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions doubleml/double_ml_lpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(
stratify=strata,
)
self._smpls = obj_dml_resampling.split_samples()

self._external_predictions_implemented = True

@property
Expand Down Expand Up @@ -385,9 +385,9 @@ def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models=Fa

# preliminary propensity for z
ml_m_z_prelim = clone(fitted_models["ml_m_z"][i_fold])
m_z_hat_prelim = _dml_cv_predict(ml_m_z_prelim, x_train_1, z_train_1, method="predict_proba", smpls=smpls_prelim)[
"preds"
]
m_z_hat_prelim = _dml_cv_predict(
ml_m_z_prelim, x_train_1, z_train_1, method="predict_proba", smpls=smpls_prelim
)["preds"]

m_z_hat_prelim = _trimm(m_z_hat_prelim, self.trimming_rule, self.trimming_threshold)
if self._normalize_ipw:
Expand Down
890 changes: 514 additions & 376 deletions doubleml/double_ml_pliv.py

Large diffs are not rendered by default.

305 changes: 166 additions & 139 deletions doubleml/double_ml_plr.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions doubleml/double_ml_pq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import copy
from sklearn.base import clone
from sklearn.utils import check_X_y
from sklearn.model_selection import StratifiedKFold, train_test_split
Expand Down Expand Up @@ -182,7 +181,10 @@ def __init__(
stratify=self._dml_data.d,
)
self._smpls = obj_dml_resampling.split_samples()

<<<<<<< HEAD
=======

>>>>>>> 5d59ac29a02b6034a9e550709069398ea177a30d
self._external_predictions_implemented = True

@property
Expand Down
341 changes: 179 additions & 162 deletions doubleml/double_ml_qte.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions doubleml/tests/test_did_external_predictions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np
import pytest
import math
from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegression
from doubleml import DoubleMLData, DoubleMLDID
from sklearn.linear_model import LinearRegression, LogisticRegression
from doubleml import DoubleMLDID
from doubleml.datasets import make_did_SZ2020
from doubleml.utils import dummy_regressor, dummy_classifier
from ._utils import draw_smpls


@pytest.fixture(scope="module", params=["observational", "experimental"])
def did_score(request):
return request.param
Expand All @@ -32,7 +33,7 @@ def doubleml_did_fixture(did_score, dml_procedure, n_rep):
"score": did_score,
"n_rep": n_rep,
"dml_procedure": dml_procedure,
"draw_sample_splitting": False
"draw_sample_splitting": False,
}
DMLDID = DoubleMLDID(ml_g=LinearRegression(), ml_m=LogisticRegression(), **kwargs)
DMLDID.set_sample_splitting(all_smpls)
Expand Down
6 changes: 3 additions & 3 deletions doubleml/tests/test_didcs_external_predictions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest
import math
from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegression
from doubleml import DoubleMLData, DoubleMLDIDCS
from sklearn.linear_model import LinearRegression, LogisticRegression
from doubleml import DoubleMLDIDCS
from doubleml.datasets import make_did_SZ2020
from doubleml.utils import dummy_regressor, dummy_classifier
from ._utils import draw_smpls
Expand Down Expand Up @@ -34,7 +34,7 @@ def doubleml_didcs_fixture(did_score, dml_procedure, n_rep):
"n_rep": n_rep,
"n_folds": 5,
"dml_procedure": dml_procedure,
"draw_sample_splitting": False
"draw_sample_splitting": False,
}
DMLDIDCS = DoubleMLDIDCS(ml_g=LinearRegression(), ml_m=LogisticRegression(), **kwargs)
DMLDIDCS.set_sample_splitting(all_smpls)
Expand Down
2 changes: 1 addition & 1 deletion doubleml/tests/test_doubleml_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import numpy as np

from doubleml import DoubleMLPLR, DoubleMLIRM, DoubleMLIIVM, DoubleMLPLIV, DoubleMLData,\
from doubleml import DoubleMLPLR, DoubleMLIRM, DoubleMLIIVM, DoubleMLPLIV, DoubleMLData, \
DoubleMLClusterData, DoubleMLPQ, DoubleMLLPQ, DoubleMLCVAR, DoubleMLQTE, DoubleMLDID, DoubleMLDIDCS
from doubleml.datasets import make_plr_CCDDHNR2018, make_irm_data, make_pliv_CHS2015, make_iivm_data, \
make_pliv_multiway_cluster_CKMS2021, make_did_SZ2020
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import pytest
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import dummy_regressor, dummy_classifier

df_irm = make_irm_data(n_obs=500, dim_x=20, theta=0.5, return_type="DataFrame")
df_irm = make_irm_data(n_obs=10, dim_x=10, theta=0.5, return_type="DataFrame")
ext_predictions = {"d": {}}


Expand Down
2 changes: 1 addition & 1 deletion doubleml/tests/test_dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def test_clone(dl_fixture):
try:
_ = clone(dl_fixture["dummy_regressor"])
_ = clone(dl_fixture["dummy_classifier"])
except Error as e:
except Exception as e:
pytest.fail(f"clone() raised an exception:\n{str(e)}\n")
9 changes: 3 additions & 6 deletions doubleml/tests/test_iivm_external_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import math
from sklearn.linear_model import LinearRegression, LassoCV, LogisticRegression
from sklearn.linear_model import LinearRegression, LogisticRegression
from doubleml import DoubleMLIIVM, DoubleMLData
from doubleml.datasets import make_iivm_data
from doubleml.utils import dummy_regressor, dummy_classifier
Expand All @@ -21,9 +21,7 @@ def n_rep(request):
def adapted_doubleml_fixture(dml_procedure, n_rep):
ext_predictions = {"d": {}}

data = make_iivm_data(
n_obs=500, dim_x=20, theta=0.5, alpha_x=1.0, return_type="DataFrame"
)
data = make_iivm_data(n_obs=500, dim_x=20, theta=0.5, alpha_x=1.0, return_type="DataFrame")

np.random.seed(3141)

Expand All @@ -45,14 +43,13 @@ def adapted_doubleml_fixture(dml_procedure, n_rep):
np.random.seed(3141)

DMLIIVM.fit(store_predictions=True)

ext_predictions["d"]["ml_g0"] = DMLIIVM.predictions["ml_g0"][:, :, 0]
ext_predictions["d"]["ml_g1"] = DMLIIVM.predictions["ml_g1"][:, :, 0]
ext_predictions["d"]["ml_m"] = DMLIIVM.predictions["ml_m"][:, :, 0]
ext_predictions["d"]["ml_r0"] = DMLIIVM.predictions["ml_r0"][:, :, 0]
ext_predictions["d"]["ml_r1"] = DMLIIVM.predictions["ml_r1"][:, :, 0]


DMLIIVM_ext = DoubleMLIIVM(
ml_g=dummy_regressor(), ml_m=dummy_classifier(), ml_r=dummy_classifier(), **kwargs
)
Expand Down
2 changes: 1 addition & 1 deletion doubleml/tests/test_lpq_external_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.linear_model import LogisticRegression
from doubleml import DoubleMLLPQ, DoubleMLData
from doubleml.datasets import make_iivm_data
from doubleml.utils import dummy_regressor, dummy_classifier
from doubleml.utils import dummy_classifier
from ._utils import draw_smpls


Expand Down
13 changes: 5 additions & 8 deletions doubleml/tests/test_pliv_external_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import math
from sklearn.linear_model import LinearRegression, LassoCV
from sklearn.linear_model import LinearRegression
from doubleml import DoubleMLPLIV, DoubleMLData
from doubleml.datasets import make_pliv_CHS2015
from doubleml.utils import dummy_regressor
Expand Down Expand Up @@ -32,12 +32,11 @@ def adapted_doubleml_fixture(score, dml_procedure, n_rep, dim_z):
# IV-type score only allows dim_z = 1, so skip testcases with dim_z > 1 for IV-type score
if dim_z > 1 and score == "IV-type":
pytest.skip("IV-type score only allows dim_z = 1")
res_dict = None
else:
ext_predictions = {"d": {}}

data = make_pliv_CHS2015(
n_obs=500, dim_x=20, alpha=0.5, dim_z=dim_z, return_type="DataFrame"
)
data = make_pliv_CHS2015(n_obs=500, dim_x=20, alpha=0.5, dim_z=dim_z, return_type="DataFrame")

np.random.seed(3141)

Expand Down Expand Up @@ -77,16 +76,14 @@ def adapted_doubleml_fixture(score, dml_procedure, n_rep, dim_z):
ml_m_key = "ml_m_" + "Z" + str(instr + 1)
ext_predictions["d"][ml_m_key] = DMLPLIV.predictions[ml_m_key][:, :, 0]

DMLPLIV_ext = DoubleMLPLIV(
ml_m=dummy_regressor(), ml_l=dummy_regressor(), ml_r=dummy_regressor(), **kwargs
)
DMLPLIV_ext = DoubleMLPLIV(ml_m=dummy_regressor(), ml_l=dummy_regressor(), ml_r=dummy_regressor(), **kwargs)

np.random.seed(3141)
DMLPLIV_ext.fit(external_predictions=ext_predictions)

res_dict = {"coef_normal": DMLPLIV.coef, "coef_ext": DMLPLIV_ext.coef}

return res_dict
return res_dict


@pytest.mark.ci
Expand Down
5 changes: 3 additions & 2 deletions doubleml/tests/test_pq_external_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.linear_model import LogisticRegression
from doubleml import DoubleMLPQ, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import dummy_regressor, dummy_classifier
from doubleml.utils import dummy_classifier
from ._utils import draw_smpls


Expand Down Expand Up @@ -79,7 +79,8 @@ def doubleml_pq_fixture(dml_procedure, n_rep, normalize_ipw, set_ml_m_ext, set_m

if set_ml_m_ext and not set_ml_g_ext:
# adjust tolerance for the case that ml_m is set to external predictions
# because no preliminary results are available for ml_m, the model use the (external) final predictions for ml_m
# because no preliminary results are available for ml_m,
# the model use the (external) final predictions for ml_m for calculating the ipw estimate
tol_rel = 0.1
tol_abs = 0.1
else:
Expand Down
32 changes: 32 additions & 0 deletions doubleml/utils/dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@


class dummy_regressor(BaseEstimator):
"""
A dummy regressor that raises an AttributeError when attempting to access
its fit, predict, or set_params methods.
Attributes
----------
_estimator_type : str
Type of the estimator, set to "regressor".
Methods
-------
fit(*args)
Raises AttributeError: "Accessed fit method of DummyRegressor!"
predict(*args)
Raises AttributeError: "Accessed predict method of DummyRegressor!"
set_params(*args)
Raises AttributeError: "Accessed set_params method of DummyRegressor!"
"""

_estimator_type = "regressor"

def fit(*args):
Expand All @@ -15,6 +32,21 @@ def set_params(*args):


class dummy_classifier(BaseEstimator):
"""
A dummy classifier that raises an AttributeError when attempting to access
its fit, predict, set_params, or predict_proba methods.
Attributes
----------
_estimator_type : str
Type of the estimator, set to "classifier".
predict(*args)
Raises AttributeError: "Accessed predict method of DummyClassifier!"
set_params(*args)
Raises AttributeError: "Accessed set_params method of DummyClassifier!"
predict_proba(*args, **kwargs)
Raises AttributeError: "Accessed predict_proba method of DummyClassifier!"
"""

_estimator_type = "classifier"

def fit(*args):
Expand Down

0 comments on commit a6be4d0

Please sign in to comment.