Skip to content

Commit

Permalink
reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
JanTeichertKluge committed Dec 8, 2023
1 parent bb9f94f commit 432ccc5
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import pytest
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import dummy_regressor, dummy_classifier

# Small synthetic IRM dataset shared by the external-prediction tests.
# Kept deliberately tiny (10 observations, 2 covariates) so the suite runs fast;
# the stale n_obs=500 variant of this line was diff residue and is removed.
df_irm = make_irm_data(n_obs=10, dim_x=2, theta=0.5, return_type="DataFrame")
# Container for externally supplied nuisance predictions, keyed by treatment "d".
ext_predictions = {"d": {}}


Expand Down
5 changes: 3 additions & 2 deletions doubleml/tests/test_pliv_external_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import math
from sklearn.linear_model import LinearRegression, LassoCV
from sklearn.linear_model import LinearRegression
from doubleml import DoubleMLPLIV, DoubleMLData
from doubleml.datasets import make_pliv_CHS2015
from doubleml.utils import dummy_regressor
Expand Down Expand Up @@ -32,6 +32,7 @@ def adapted_doubleml_fixture(score, dml_procedure, n_rep, dim_z):
# IV-type score only allows dim_z = 1, so skip testcases with dim_z > 1 for IV-type score
if dim_z > 1 and score == "IV-type":
pytest.skip("IV-type score only allows dim_z = 1")
res_dict = None
else:
ext_predictions = {"d": {}}

Expand Down Expand Up @@ -86,7 +87,7 @@ def adapted_doubleml_fixture(score, dml_procedure, n_rep, dim_z):

res_dict = {"coef_normal": DMLPLIV.coef, "coef_ext": DMLPLIV_ext.coef}

return res_dict
return res_dict


@pytest.mark.ci
Expand Down
29 changes: 22 additions & 7 deletions doubleml/tests/test_pq_external_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.linear_model import LogisticRegression
from doubleml import DoubleMLPQ, DoubleMLData
from doubleml.datasets import make_irm_data
from doubleml.utils import dummy_regressor, dummy_classifier
from doubleml.utils import dummy_classifier
from ._utils import draw_smpls


Expand All @@ -27,6 +27,7 @@ def normalize_ipw(request):
def set_ml_m_ext(request):
return request.param


@pytest.fixture(scope="module", params=[True, False])
def set_ml_g_ext(request):
    """Boolean fixture: whether the ml_g learner uses external predictions."""
    use_external = request.param
    return use_external
Expand All @@ -36,7 +37,7 @@ def set_ml_g_ext(request):
def doubleml_pq_fixture(dml_procedure, n_rep, normalize_ipw, set_ml_m_ext, set_ml_g_ext):
ext_predictions = {"d": {}}
np.random.seed(3141)
data = make_irm_data(theta=0.5, n_obs=1000, dim_x=5, return_type="DataFrame")
data = make_irm_data(theta=1, n_obs=500, dim_x=5, return_type="DataFrame")

dml_data = DoubleMLData(data, "y", "d")
all_smpls = draw_smpls(len(dml_data.y), 5, n_rep=n_rep, groups=None)
Expand All @@ -47,7 +48,7 @@ def doubleml_pq_fixture(dml_procedure, n_rep, normalize_ipw, set_ml_m_ext, set_m
"n_rep": n_rep,
"dml_procedure": dml_procedure,
"normalize_ipw": normalize_ipw,
"draw_sample_splitting": False
"draw_sample_splitting": False,
}

ml_m = LogisticRegression(random_state=42)
Expand All @@ -63,24 +64,38 @@ def doubleml_pq_fixture(dml_procedure, n_rep, normalize_ipw, set_ml_m_ext, set_m
ml_m = dummy_classifier()
else:
ml_m = LogisticRegression(random_state=42)

if set_ml_g_ext:
ext_predictions["d"]["ml_g"] = DMLPQ.predictions["ml_g"][:, :, 0]
ml_g = dummy_classifier()
else:
ml_g = LogisticRegression(random_state=42)

DMLPLQ_ext = DoubleMLPQ(ml_g = ml_g, ml_m = ml_m, **kwargs)
DMLPLQ_ext = DoubleMLPQ(ml_g=ml_g, ml_m=ml_m, **kwargs)
DMLPLQ_ext.set_sample_splitting(all_smpls)

np.random.seed(3141)
DMLPLQ_ext.fit(external_predictions=ext_predictions)

res_dict = {"coef_normal": DMLPQ.coef, "coef_ext": DMLPLQ_ext.coef}
if set_ml_m_ext and not set_ml_g_ext:
# adjust tolerance for the case that ml_m is set to external predictions
# because no preliminary results are available for ml_m, the model use the (external) final predictions for ml_m
tol_rel = 0.1
tol_abs = 0.1
else:
tol_rel = 1e-9
tol_abs = 1e-4

res_dict = {"coef_normal": DMLPQ.coef, "coef_ext": DMLPLQ_ext.coef, "tol_rel": tol_rel, "tol_abs": tol_abs}

return res_dict


@pytest.mark.ci
def test_doubleml_pq_coef(doubleml_pq_fixture):
    """Assert that coefficients from the externally-predicted fit match the normal fit.

    The tolerances are taken from the fixture rather than hard-coded, because
    the ml_m-external-only configuration only supports a looser comparison
    (no preliminary predictions are available for ml_m in that case).
    The stale pre-change assert with fixed rel_tol=1e-9/abs_tol=1e-4 was diff
    residue and is removed so the fixture-driven tolerances alone decide.
    """
    assert math.isclose(
        doubleml_pq_fixture["coef_normal"],
        doubleml_pq_fixture["coef_ext"],
        rel_tol=doubleml_pq_fixture["tol_rel"],
        abs_tol=doubleml_pq_fixture["tol_abs"],
    )
36 changes: 36 additions & 0 deletions doubleml/utils/dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@


class dummy_regressor(BaseEstimator):
"""
A dummy regressor that raises an AttributeError when attempting to access
its fit, predict, or set_params methods.
Attributes
----------
_estimator_type : str
Type of the estimator, set to "regressor".
Methods
-------
fit(*args)
Raises AttributeError: "Accessed fit method of DummyRegressor!"
predict(*args)
Raises AttributeError: "Accessed predict method of DummyRegressor!"
set_params(*args)
Raises AttributeError: "Accessed set_params method of DummyRegressor!"
"""

_estimator_type = "regressor"

def fit(*args):
Expand All @@ -15,6 +32,25 @@ def set_params(*args):


class dummy_classifier(BaseEstimator):
"""
A dummy classifier that raises an AttributeError when attempting to access
its fit, predict, set_params, or predict_proba methods.
Attributes
----------
_estimator_type : str
Type of the estimator, set to "classifier".
Methods
-------
fit(*args)
Raises AttributeError: "Accessed fit method of DummyClassifier!"
predict(*args)
Raises AttributeError: "Accessed predict method of DummyClassifier!"
set_params(*args)
Raises AttributeError: "Accessed set_params method of DummyClassifier!"
predict_proba(*args, **kwargs)
Raises AttributeError: "Accessed predict_proba method of DummyClassifier!"
"""

_estimator_type = "classifier"

def fit(*args):
Expand Down

0 comments on commit 432ccc5

Please sign in to comment.