From ada798bf0c9f548135c7adbf296f464ea14ebbfb Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Sat, 15 Jun 2024 12:24:58 +0200 Subject: [PATCH] narwhals to the rescue (#678) --- tests/test_meta/test_hierarchical_predictor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_meta/test_hierarchical_predictor.py b/tests/test_meta/test_hierarchical_predictor.py index 1cf29317..ec8d70f3 100644 --- a/tests/test_meta/test_hierarchical_predictor.py +++ b/tests/test_meta/test_hierarchical_predictor.py @@ -188,10 +188,11 @@ def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage): """ X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)]) + X_ = nw.from_native(X).drop(groups).pipe(nw.to_native) meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y) - base_model = clone(base_estimator).fit(X.drop(columns=groups), y) + base_model = clone(base_estimator).fit(X_, y) - assert metric(y, base_model.predict(X.drop(columns=groups))) <= metric(y, meta_model.predict(X)) + assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X)) @pytest.mark.parametrize(