diff --git a/sklego/preprocessing/pandastransformers.py b/sklego/preprocessing/pandastransformers.py index 6df4d9a7..1d516c3e 100644 --- a/sklego/preprocessing/pandastransformers.py +++ b/sklego/preprocessing/pandastransformers.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from typing import Any import narwhals as nw from narwhals.dependencies import get_pandas @@ -34,7 +35,7 @@ def _nw_match_dtype(dtype, selection): raise ValueError(msg) -def _nw_select_dtypes(df, include: str | list[str], exclude: str | list[str]): +def _nw_select_dtypes(include: str | list[str], exclude: str | list[str], schema: dict[str, Any]): if not include and not exclude: raise ValueError("Must provide at least one of `include` or `exclude`") @@ -43,16 +44,20 @@ def _nw_select_dtypes(df, include: str | list[str], exclude: str | list[str]): if isinstance(exclude, str): exclude = [exclude] - include = include or ["string", "number", "bool", "category"] exclude = exclude or [] - feature_names = [ - name - for name, dtype in df.schema.items() - if any(_nw_match_dtype(dtype, _include) for _include in include) - and not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude) - ] - return df.select(feature_names) + if include: + feature_names = [ + name + for name, dtype in schema.items() + if any(_nw_match_dtype(dtype, _include) for _include in include) + and not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude) + ] + else: + feature_names = [ + name for name, dtype in schema.items() if not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude) + ] + return feature_names class ColumnDropper(BaseEstimator, TransformerMixin): @@ -330,7 +335,7 @@ def fit(self, X, y=None): else: X = nw.from_native(X) self.X_dtypes_ = X.schema - self.feature_names_ = _nw_select_dtypes(X, include=self.include, exclude=self.exclude).columns + self.feature_names_ = _nw_select_dtypes(include=self.include, exclude=self.exclude, schema=self.X_dtypes_) if len(self.feature_names_) == 0: raise ValueError("Provided type(s) results in empty dataframe") @@ -377,14 +382,17 @@ def transform(self, X): transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude) else: X = nw.from_native(X) - if self.X_dtypes_ != X.schema: + X_schema = X.schema + if self.X_dtypes_ != X_schema: raise ValueError( f"Column dtypes were not equal during fit and transform. Fit types: \n" f"{self.X_dtypes_}\n" f"transform: \n" f"{X.schema}" ) - transformed_df = _nw_select_dtypes(X, include=self.include, exclude=self.exclude) + transformed_df = X.select( + _nw_select_dtypes(include=self.include, exclude=self.exclude, schema=X_schema) + ).pipe(nw.to_native) return transformed_df diff --git a/tests/test_meta/test_hierarchical_predictor.py b/tests/test_meta/test_hierarchical_predictor.py index 1cf29317..ec8d70f3 100644 --- a/tests/test_meta/test_hierarchical_predictor.py +++ b/tests/test_meta/test_hierarchical_predictor.py @@ -188,10 +188,11 @@ def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage): """ X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)]) + X_ = nw.from_native(X).drop(groups).pipe(nw.to_native) meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y) - base_model = clone(base_estimator).fit(X.drop(columns=groups), y) + base_model = clone(base_estimator).fit(X_, y) - assert metric(y, base_model.predict(X.drop(columns=groups))) <= metric(y, meta_model.predict(X)) + assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X)) @pytest.mark.parametrize( diff --git a/tests/test_preprocessing/test_pandastypeselector.py b/tests/test_preprocessing/test_pandastypeselector.py index 71644dc2..8d4ea5cd 100644 --- a/tests/test_preprocessing/test_pandastypeselector.py +++ b/tests/test_preprocessing/test_pandastypeselector.py @@ -68,6 +68,24 @@ def test_get_feature_names(frame_func): assert transformer_number.get_feature_names() == ["b"] +@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame]) +@pytest.mark.parametrize( + ("include", "exclude", "expected"), + [ + ("number", None, ["a", "c"]), + ("bool", None, ["b"]), + (None, "number", ["b", "d"]), + (None, ["number", "bool"], ["d"]), + ], +) +def test_include_vs_exclude(frame_func, include, exclude, expected): + df = frame_func({"a": [4, 5, 6], "b": [True, False, True], "c": [4.0, 5.0, 6.0], "d": ["a", "b", "c"]}) + type_selector = TypeSelector(include=include, exclude=exclude).fit(df) + assert type_selector.get_feature_names() == expected + result = type_selector.transform(df) + assert isinstance(result, frame_func) + + @pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame]) def test_get_feature_names_deprecated(frame_func): df = frame_func({"a": [4, 5, 6], "b": ["4", "5", "6"]})