From 9afe9ace409693ec605e71ac61705959b9f16ea4 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 28 Jul 2024 17:24:49 +0200 Subject: [PATCH] refactor tests --- tests/test_meta/test_decay_estimator.py | 2 +- tests/test_meta/test_grouped_predictor.py | 22 +++++----- .../test_meta/test_hierarchical_predictor.py | 42 +++++++++++-------- tests/test_pandas_utils/test_pandas_utils.py | 2 + 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/tests/test_meta/test_decay_estimator.py b/tests/test_meta/test_decay_estimator.py index fabf8061..a0d6fc5e 100644 --- a/tests/test_meta/test_decay_estimator.py +++ b/tests/test_meta/test_decay_estimator.py @@ -53,7 +53,7 @@ def test_decay_weight(mod, is_clf, decay_func, decay_kwargs): if is_clf: y = (y < 0).astype(int) - mod = DecayEstimator(mod, decay_func=decay_func, **decay_kwargs).fit(X, y) + mod = DecayEstimator(mod, decay_func=decay_func, decay_kwargs=decay_kwargs).fit(X, y) assert np.logical_and(mod.weights_ >= 0, mod.weights_ <= 1).all() assert np.all(mod.weights_[:-1] <= mod.weights_[1:]) diff --git a/tests/test_meta/test_grouped_predictor.py b/tests/test_meta/test_grouped_predictor.py index 5d20a23a..21606a46 100644 --- a/tests/test_meta/test_grouped_predictor.py +++ b/tests/test_meta/test_grouped_predictor.py @@ -249,7 +249,7 @@ def test_constant_shrinkage(shrinkage_data): ["Planet", "Country", "City"], shrinkage="constant", use_global_model=False, - alpha=0.1, + shrinkage_kwargs={"alpha": 0.1}, ) shrinkage_factors = np.array([0.01, 0.09, 0.9]) @@ -304,7 +304,7 @@ def test_min_n_obs_shrinkage(shrinkage_data): ["Planet", "Country", "City"], shrinkage="min_n_obs", use_global_model=False, - min_n_obs=2, + shrinkage_kwargs={"min_n_obs": 2}, ) shrink_est.fit(X, y) @@ -327,7 +327,7 @@ def test_min_n_obs_shrinkage_too_little_obs(shrinkage_data): ["Planet", "Country", "City"], shrinkage="min_n_obs", use_global_model=False, - min_n_obs=too_big_n_obs, + shrinkage_kwargs={"min_n_obs": too_big_n_obs}, ) with pytest.raises(ValueError) as e: @@ -459,7 +459,7 @@ def test_global_model_shrinkage(shrinkage_data): ["Planet", "Country", "City"], shrinkage="min_n_obs", use_global_model=False, - min_n_obs=2, + shrinkage_kwargs={"min_n_obs": 2}, ) shrink_est_with_global = GroupedPredictor( @@ -467,7 +467,7 @@ def test_global_model_shrinkage(shrinkage_data): ["Country", "City"], shrinkage="min_n_obs", use_global_model=True, - min_n_obs=2, + shrinkage_kwargs={"min_n_obs": 2}, ) shrink_est_without_global.fit(X, y) @@ -490,7 +490,7 @@ def test_shrinkage_single_group(shrinkage_data): "Country", shrinkage="constant", use_global_model=True, - alpha=0.1, + shrinkage_kwargs={"alpha": 0.1}, ) shrinkage_factors = np.array([0.1, 0.9]) @@ -519,7 +519,7 @@ def test_shrinkage_single_group_no_global(shrinkage_data): "Country", shrinkage="constant", use_global_model=False, - alpha=0.1, + shrinkage_kwargs={"alpha": 0.1}, ) shrink_est.fit(X, y) @@ -548,7 +548,9 @@ def test_unseen_groups_shrinkage(shrinkage_data): X, y = df.drop(columns="Target"), df["Target"] - shrink_est = GroupedPredictor(DummyRegressor(), ["Planet", "Country", "City"], shrinkage="constant", alpha=0.1) + shrink_est = GroupedPredictor( + DummyRegressor(), ["Planet", "Country", "City"], shrinkage="constant", shrinkage_kwargs={"alpha": 0.1} + ) shrink_est.fit(X, y) @@ -569,7 +571,7 @@ def test_predict_missing_group_column(shrinkage_data): ["Planet", "Country", "City"], shrinkage="constant", use_global_model=False, - alpha=0.1, + shrinkage_kwargs={"alpha": 0.1}, ) shrink_est.fit(X, y) @@ -592,7 +594,7 @@ def test_predict_missing_value_column(shrinkage_data): ["Planet", "Country", "City"], shrinkage="constant", use_global_model=False, - alpha=0.1, + shrinkage_kwargs={"alpha": 0.1}, ) shrink_est.fit(X, y) diff --git a/tests/test_meta/test_hierarchical_predictor.py b/tests/test_meta/test_hierarchical_predictor.py index ff5c6cd2..94bd0161 100644 --- a/tests/test_meta/test_hierarchical_predictor.py +++ b/tests/test_meta/test_hierarchical_predictor.py @@ -116,24 +116,28 @@ def make_hierarchical_dummy(frame_func): ) @pytest.mark.parametrize("fallback_method", ["raise", "parent"]) @pytest.mark.parametrize( - "shrinkage", + ("shrinkage", "kwargs"), [ - {"shrinkage": None}, - {"shrinkage": "equal"}, - {"shrinkage": "relative"}, - {"shrinkage": "min_n_obs", "min_n_obs": 10}, - {"shrinkage": "constant", "alpha": 0.5}, + (None, None), + ("equal", None), + ("relative", None), + ("min_n_obs", {"min_n_obs": 10}), + ("constant", {"alpha": 0.5}), ], ) -def test_fit_predict(meta_cls, base_estimator, task, fallback_method, shrinkage): +def test_fit_predict(meta_cls, base_estimator, task, fallback_method, shrinkage, kwargs): """Tests that the model can be fit and predict with different configurations of fallback and shrinkage methods if X to predict contains same groups as X used to fit. """ X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)]) - meta_model = meta_cls(estimator=base_estimator, groups=groups, fallback_method=fallback_method, **shrinkage).fit( - X, y - ) + meta_model = meta_cls( + estimator=base_estimator, + groups=groups, + fallback_method=fallback_method, + shrinkage=shrinkage, + shrinkage_kwargs=kwargs, + ).fit(X, y) assert meta_model.estimators_ is not None assert meta_model.predict(X) is not None @@ -173,23 +177,25 @@ def test_fallback(meta_cls, base_estimator, task, fallback_method, context): ], ) @pytest.mark.parametrize( - "shrinkage", + ("shrinkage", "kwargs"), [ - {"shrinkage": None}, - {"shrinkage": "equal"}, - {"shrinkage": "relative"}, - {"shrinkage": "min_n_obs", "min_n_obs": 10}, - {"shrinkage": "constant", "alpha": 0.5}, + (None, None), + ("equal", None), + ("relative", None), + ("min_n_obs", {"min_n_obs": 10}), + ("constant", {"alpha": 0.5}), ], ) -def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage): +def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage, kwargs): """Tests that the model performance is better than the base estimator when predicting with different shrinkage methods. """ X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)]) X_ = nw.from_native(X).drop(groups).pipe(nw.to_native) - meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y) + meta_model = meta_cls( + estimator=clone(base_estimator), groups=groups, shrinkage=shrinkage, shrinkage_kwargs=kwargs + ).fit(X, y) base_model = clone(base_estimator).fit(X_, y) assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X)) diff --git a/tests/test_pandas_utils/test_pandas_utils.py b/tests/test_pandas_utils/test_pandas_utils.py index bff49ff0..19c8ce98 100644 --- a/tests/test_pandas_utils/test_pandas_utils.py +++ b/tests/test_pandas_utils/test_pandas_utils.py @@ -48,9 +48,11 @@ def test_add_lags_correct_df(data, frame_func): ans = add_lags(test_df, "X1", -1) if isinstance(ans, pl.LazyFrame): ans = ans.collect() + print("HERE\n", ans) if isinstance(expected, pl.LazyFrame): expected = expected.collect() assert [x for x in ans.columns] == [x for x in expected.columns] + assert (ans.to_numpy() == expected.to_numpy()).all()