Skip to content

Commit

Permalink
narwhals to the rescue (#678)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jun 15, 2024
1 parent 2b32533 commit ada798b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/test_meta/test_hierarchical_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,11 @@ def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage):
"""
X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)])

X_ = nw.from_native(X).drop(groups).pipe(nw.to_native)
meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y)
base_model = clone(base_estimator).fit(X.drop(columns=groups), y)
base_model = clone(base_estimator).fit(X_, y)

assert metric(y, base_model.predict(X.drop(columns=groups))) <= metric(y, meta_model.predict(X))
assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X))


@pytest.mark.parametrize(
Expand Down

0 comments on commit ada798b

Please sign in to comment.