Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jul 28, 2024
1 parent 999bb16 commit 9afe9ac
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tests/test_meta/test_decay_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_decay_weight(mod, is_clf, decay_func, decay_kwargs):
if is_clf:
y = (y < 0).astype(int)

mod = DecayEstimator(mod, decay_func=decay_func, **decay_kwargs).fit(X, y)
mod = DecayEstimator(mod, decay_func=decay_func, decay_kwargs=decay_kwargs).fit(X, y)

assert np.logical_and(mod.weights_ >= 0, mod.weights_ <= 1).all()
assert np.all(mod.weights_[:-1] <= mod.weights_[1:])
Expand Down
22 changes: 12 additions & 10 deletions tests/test_meta/test_grouped_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_constant_shrinkage(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="constant",
use_global_model=False,
alpha=0.1,
shrinkage_kwargs={"alpha": 0.1},
)

shrinkage_factors = np.array([0.01, 0.09, 0.9])
Expand Down Expand Up @@ -304,7 +304,7 @@ def test_min_n_obs_shrinkage(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="min_n_obs",
use_global_model=False,
min_n_obs=2,
shrinkage_kwargs={"min_n_obs": 2},
)

shrink_est.fit(X, y)
Expand All @@ -327,7 +327,7 @@ def test_min_n_obs_shrinkage_too_little_obs(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="min_n_obs",
use_global_model=False,
min_n_obs=too_big_n_obs,
shrinkage_kwargs={"min_n_obs": too_big_n_obs},
)

with pytest.raises(ValueError) as e:
Expand Down Expand Up @@ -459,15 +459,15 @@ def test_global_model_shrinkage(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="min_n_obs",
use_global_model=False,
min_n_obs=2,
shrinkage_kwargs={"min_n_obs": 2},
)

shrink_est_with_global = GroupedPredictor(
DummyRegressor(),
["Country", "City"],
shrinkage="min_n_obs",
use_global_model=True,
min_n_obs=2,
shrinkage_kwargs={"min_n_obs": 2},
)

shrink_est_without_global.fit(X, y)
Expand All @@ -490,7 +490,7 @@ def test_shrinkage_single_group(shrinkage_data):
"Country",
shrinkage="constant",
use_global_model=True,
alpha=0.1,
shrinkage_kwargs={"alpha": 0.1},
)

shrinkage_factors = np.array([0.1, 0.9])
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_shrinkage_single_group_no_global(shrinkage_data):
"Country",
shrinkage="constant",
use_global_model=False,
alpha=0.1,
shrinkage_kwargs={"alpha": 0.1},
)
shrink_est.fit(X, y)

Expand Down Expand Up @@ -548,7 +548,9 @@ def test_unseen_groups_shrinkage(shrinkage_data):

X, y = df.drop(columns="Target"), df["Target"]

shrink_est = GroupedPredictor(DummyRegressor(), ["Planet", "Country", "City"], shrinkage="constant", alpha=0.1)
shrink_est = GroupedPredictor(
DummyRegressor(), ["Planet", "Country", "City"], shrinkage="constant", shrinkage_kwargs={"alpha": 0.1}
)

shrink_est.fit(X, y)

Expand All @@ -569,7 +571,7 @@ def test_predict_missing_group_column(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="constant",
use_global_model=False,
alpha=0.1,
shrinkage_kwargs={"alpha": 0.1},
)

shrink_est.fit(X, y)
Expand All @@ -592,7 +594,7 @@ def test_predict_missing_value_column(shrinkage_data):
["Planet", "Country", "City"],
shrinkage="constant",
use_global_model=False,
alpha=0.1,
shrinkage_kwargs={"alpha": 0.1},
)

shrink_est.fit(X, y)
Expand Down
42 changes: 24 additions & 18 deletions tests/test_meta/test_hierarchical_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,28 @@ def make_hierarchical_dummy(frame_func):
)
@pytest.mark.parametrize("fallback_method", ["raise", "parent"])
@pytest.mark.parametrize(
"shrinkage",
("shrinkage", "kwargs"),
[
{"shrinkage": None},
{"shrinkage": "equal"},
{"shrinkage": "relative"},
{"shrinkage": "min_n_obs", "min_n_obs": 10},
{"shrinkage": "constant", "alpha": 0.5},
(None, None),
("equal", None),
("relative", None),
("min_n_obs", {"min_n_obs": 10}),
("constant", {"alpha": 0.5}),
],
)
def test_fit_predict(meta_cls, base_estimator, task, fallback_method, shrinkage):
def test_fit_predict(meta_cls, base_estimator, task, fallback_method, shrinkage, kwargs):
"""Tests that the model can be fit and predict with different configurations of fallback and shrinkage methods if
X to predict contains same groups as X used to fit.
"""
X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)])

meta_model = meta_cls(estimator=base_estimator, groups=groups, fallback_method=fallback_method, **shrinkage).fit(
X, y
)
meta_model = meta_cls(
estimator=base_estimator,
groups=groups,
fallback_method=fallback_method,
shrinkage=shrinkage,
shrinkage_kwargs=kwargs,
).fit(X, y)

assert meta_model.estimators_ is not None
assert meta_model.predict(X) is not None
Expand Down Expand Up @@ -173,23 +177,25 @@ def test_fallback(meta_cls, base_estimator, task, fallback_method, context):
],
)
@pytest.mark.parametrize(
"shrinkage",
("shrinkage", "kwargs"),
[
{"shrinkage": None},
{"shrinkage": "equal"},
{"shrinkage": "relative"},
{"shrinkage": "min_n_obs", "min_n_obs": 10},
{"shrinkage": "constant", "alpha": 0.5},
(None, None),
("equal", None),
("relative", None),
("min_n_obs", {"min_n_obs": 10}),
("constant", {"alpha": 0.5}),
],
)
def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage):
def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage, kwargs):
"""Tests that the model performance is better than the base estimator when predicting with different shrinkage
methods.
"""
X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)])

X_ = nw.from_native(X).drop(groups).pipe(nw.to_native)
meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y)
meta_model = meta_cls(
estimator=clone(base_estimator), groups=groups, shrinkage=shrinkage, shrinkage_kwargs=kwargs
).fit(X, y)
base_model = clone(base_estimator).fit(X_, y)

assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X))
Expand Down
2 changes: 2 additions & 0 deletions tests/test_pandas_utils/test_pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ def test_add_lags_correct_df(data, frame_func):
ans = add_lags(test_df, "X1", -1)
if isinstance(ans, pl.LazyFrame):
ans = ans.collect()
print("HERE\n", ans)
if isinstance(expected, pl.LazyFrame):
expected = expected.collect()
assert [x for x in ans.columns] == [x for x in expected.columns]

assert (ans.to_numpy() == expected.to_numpy()).all()


Expand Down

0 comments on commit 9afe9ac

Please sign in to comment.