Skip to content

Commit

Permalink
Merge branch 'main' into feat/zir-score-samples
Browse files: browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jun 20, 2024
2 parents d7f086c + ada798b commit 253d41b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
32 changes: 20 additions & 12 deletions sklego/preprocessing/pandastransformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from typing import Any

import narwhals as nw
from narwhals.dependencies import get_pandas
Expand Down Expand Up @@ -34,7 +35,7 @@ def _nw_match_dtype(dtype, selection):
raise ValueError(msg)


def _nw_select_dtypes(df, include: str | list[str], exclude: str | list[str]):
def _nw_select_dtypes(include: str | list[str], exclude: str | list[str], schema: dict[str, Any]):
if not include and not exclude:
raise ValueError("Must provide at least one of `include` or `exclude`")

Expand All @@ -43,16 +44,20 @@ def _nw_select_dtypes(df, include: str | list[str], exclude: str | list[str]):
if isinstance(exclude, str):
exclude = [exclude]

include = include or ["string", "number", "bool", "category"]
exclude = exclude or []

feature_names = [
name
for name, dtype in df.schema.items()
if any(_nw_match_dtype(dtype, _include) for _include in include)
and not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude)
]
return df.select(feature_names)
if include:
feature_names = [
name
for name, dtype in schema.items()
if any(_nw_match_dtype(dtype, _include) for _include in include)
and not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude)
]
else:
feature_names = [
name for name, dtype in schema.items() if not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude)
]
return feature_names


class ColumnDropper(BaseEstimator, TransformerMixin):
Expand Down Expand Up @@ -330,7 +335,7 @@ def fit(self, X, y=None):
else:
X = nw.from_native(X)
self.X_dtypes_ = X.schema
self.feature_names_ = _nw_select_dtypes(X, include=self.include, exclude=self.exclude).columns
self.feature_names_ = _nw_select_dtypes(include=self.include, exclude=self.exclude, schema=self.X_dtypes_)

if len(self.feature_names_) == 0:
raise ValueError("Provided type(s) results in empty dataframe")
Expand Down Expand Up @@ -377,14 +382,17 @@ def transform(self, X):
transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude)
else:
X = nw.from_native(X)
if self.X_dtypes_ != X.schema:
X_schema = X.schema
if self.X_dtypes_ != X_schema:
raise ValueError(
f"Column dtypes were not equal during fit and transform. Fit types: \n"
f"{self.X_dtypes_}\n"
f"transform: \n"
f"{X.schema}"
)
transformed_df = _nw_select_dtypes(X, include=self.include, exclude=self.exclude)
transformed_df = X.select(
_nw_select_dtypes(include=self.include, exclude=self.exclude, schema=X_schema)
).pipe(nw.to_native)

return transformed_df

Expand Down
5 changes: 3 additions & 2 deletions tests/test_meta/test_hierarchical_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,11 @@ def test_shrinkage(meta_cls, base_estimator, task, metric, shrinkage):
"""
X, y, groups = make_hierarchical_dataset(task, frame_func=frame_funcs[randint(0, 1)])

X_ = nw.from_native(X).drop(groups).pipe(nw.to_native)
meta_model = meta_cls(estimator=clone(base_estimator), groups=groups, **shrinkage).fit(X, y)
base_model = clone(base_estimator).fit(X.drop(columns=groups), y)
base_model = clone(base_estimator).fit(X_, y)

assert metric(y, base_model.predict(X.drop(columns=groups))) <= metric(y, meta_model.predict(X))
assert metric(y, base_model.predict(X_)) <= metric(y, meta_model.predict(X))


@pytest.mark.parametrize(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_preprocessing/test_pandastypeselector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ def test_get_feature_names(frame_func):
assert transformer_number.get_feature_names() == ["b"]


@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
@pytest.mark.parametrize(
    ("include", "exclude", "expected"),
    [
        ("number", None, ["a", "c"]),
        ("bool", None, ["b"]),
        (None, "number", ["b", "d"]),
        (None, ["number", "bool"], ["d"]),
    ],
)
def test_include_vs_exclude(frame_func, include, exclude, expected):
    """Check column selection when only `include` or only `exclude` is given.

    The frame has one integer column ("a"), one boolean column ("b"), one
    float column ("c") and one string column ("d"). Each parametrized case
    passes exactly one of `include` / `exclude` (the other is `None`) and
    lists the column names expected to survive the selection, in frame order.

    The test asserts that:
    - the fitted selector reports the expected feature names,
    - `transform` returns a frame of the same native type as its input,
    - the transformed frame actually contains the expected columns.
    """
    df = frame_func({"a": [4, 5, 6], "b": [True, False, True], "c": [4.0, 5.0, 6.0], "d": ["a", "b", "c"]})
    type_selector = TypeSelector(include=include, exclude=exclude).fit(df)
    assert type_selector.get_feature_names() == expected

    result = type_selector.transform(df)
    # The output must come back in the caller's native frame type.
    assert isinstance(result, frame_func)
    # `columns` is an Index for pandas and a list for polars; `list(...)`
    # normalises both so the selected columns themselves are verified too.
    assert list(result.columns) == expected


@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
def test_get_feature_names_deprecated(frame_func):
df = frame_func({"a": [4, 5, 6], "b": ["4", "5", "6"]})
Expand Down

0 comments on commit 253d41b

Please sign in to comment.