-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make PandasTypeSelector selector dataframe-agnostic #670
Changes from 3 commits
2697b2d
d2e703c
d96e427
4f6b1ea
243f0a5
a5334cc
3ad9a10
070e2fe
d5f0413
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
import narwhals as nw | ||
import pandas as pd | ||
from narwhals.dependencies import get_pandas | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
from sklego.common import as_list | ||
|
||
|
||
def _nw_match_dtype(dtype, selection): | ||
if selection == "number": | ||
return dtype in ( | ||
nw.Int64, | ||
nw.Int32, | ||
nw.Int16, | ||
nw.Int8, | ||
nw.UInt64, | ||
nw.UInt32, | ||
nw.UInt16, | ||
nw.UInt8, | ||
nw.Float64, | ||
nw.Float32, | ||
) | ||
if selection == "bool": | ||
return dtype == nw.Boolean | ||
if selection == "string": | ||
return dtype == nw.String | ||
if selection == "category": | ||
return dtype == nw.Categorical | ||
msg = f"Expected {{'number', 'bool', 'string', 'category'}}, got: {selection}, which is not (yet!) supported." | ||
raise ValueError(msg) | ||
|
||
|
||
def _nw_select_dtypes(df, include: str | list[str], exclude: str | list[str]): | ||
feature_names = [] | ||
if isinstance(include, str): | ||
include = [include] | ||
if isinstance(exclude, str): | ||
exclude = [exclude] | ||
for name, dtype in df.schema.items(): | ||
if include and exclude: | ||
if any(_nw_match_dtype(dtype, _include) for _include in include) and not any( | ||
_nw_match_dtype(dtype, _exclude) for _exclude in exclude | ||
): | ||
feature_names.append(name) | ||
elif include: | ||
if any(_nw_match_dtype(dtype, _include) for _include in include): | ||
feature_names.append(name) | ||
elif exclude: | ||
if not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude): | ||
feature_names.append(name) | ||
else: | ||
raise ValueError("Must provide at least one of `include` or `exclude`") | ||
return df.select(feature_names) | ||
MarcoGorelli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class ColumnDropper(BaseEstimator, TransformerMixin): | ||
"""The `ColumnDropper` transformer allows dropping specific columns from a DataFrame by name. | ||
Can be useful in a sklearn Pipeline. | ||
|
@@ -173,12 +222,18 @@ def _check_column_names(self, X): | |
|
||
|
||
class PandasTypeSelector(BaseEstimator, TransformerMixin): | ||
"""The `PandasTypeSelector` transformer allows to select columns in a pandas DataFrame based on their type. | ||
"""The `PandasTypeSelector` transformer allows to select columns in a DataFrame based on their type. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Considering its name, we could do the following: class PandasTypeSelector(BaseEstimator, TransformerMixin):
def __init__(self, include=None, exclude=None):
warn(
"Please use `TypeSelector` instead of `PandasTypeSelector`, `PandasTypeSelector` will be deprecated in future versions",
DeprecationWarning,
)
return TypeSelector(include, exclude) and then class TypeSelector(BaseEstimator, TransformerMixin):
...
!!! info "New in version 0.9.0" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, and I think the whole OK to do it all in one go in a separate PR, so that all the ones in EDIT: I noticed that this is already exported from The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes we can rename it to have a more intuitive naming path, but as you spotted, it shouldn't matter too much as they are exported into |
||
Can be useful in a sklearn Pipeline. | ||
|
||
It uses | ||
[pandas.DataFrame.select_dtypes](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.select_dtypes.html) | ||
method. | ||
- For pandas, it uses | ||
[pandas.DataFrame.select_dtypes](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.select_dtypes.html) | ||
method. | ||
- For non-pandas dataframes (e.g. Polars), the following inputs are allowed: | ||
|
||
- 'number' | ||
- 'string' | ||
- 'bool' | ||
- 'category' | ||
|
||
Parameters | ||
---------- | ||
|
@@ -191,7 +246,7 @@ class PandasTypeSelector(BaseEstimator, TransformerMixin): | |
---------- | ||
feature_names_ : list[str] | ||
The names of the features to keep during transform. | ||
X_dtypes_ : pd.Series | ||
X_dtypes_ : Series | dict[str, DType] | ||
The dtypes of the columns in the input DataFrame. | ||
|
||
!!! warning | ||
|
@@ -235,9 +290,9 @@ def fit(self, X, y=None): | |
|
||
Parameters | ||
---------- | ||
X : pd.DataFrame | ||
X : DataFrame | ||
The data on which we apply the column selection. | ||
y : pd.Series, default=None | ||
y : Series, default=None | ||
Ignored, present for compatibility. | ||
|
||
Returns | ||
|
@@ -248,13 +303,17 @@ def fit(self, X, y=None): | |
Raises | ||
------ | ||
TypeError | ||
If `X` is not a `pd.DataFrame` object. | ||
If `X` is not a supported DataFrame. | ||
ValueError | ||
If provided type(s) results in empty dataframe. | ||
""" | ||
self._check_X_for_type(X) | ||
self.X_dtypes_ = X.dtypes | ||
self.feature_names_ = list(X.select_dtypes(include=self.include, exclude=self.exclude).columns) | ||
if (pd := get_pandas()) is not None and isinstance(X, pd.DataFrame): | ||
self.X_dtypes_ = X.dtypes | ||
self.feature_names_ = list(X.select_dtypes(include=self.include, exclude=self.exclude).columns) | ||
else: | ||
X = nw.from_native(X) | ||
self.X_dtypes_ = X.schema | ||
self.feature_names_ = _nw_select_dtypes(X, include=self.include, exclude=self.exclude).columns | ||
|
||
if len(self.feature_names_) == 0: | ||
raise ValueError("Provided type(s) results in empty dataframe") | ||
|
@@ -266,49 +325,52 @@ def get_feature_names(self, *args, **kwargs): | |
return self.feature_names_ | ||
|
||
def transform(self, X): | ||
"""Returns a pandas DataFrame with columns (de)selected based on their dtype. | ||
"""Returns a DataFrame with columns (de)selected based on their dtype. | ||
|
||
Parameters | ||
---------- | ||
X : pd.DataFrame | ||
X : DataFrame | ||
The data to select dtype for. | ||
|
||
Returns | ||
------- | ||
pd.DataFrame | ||
DataFrame | ||
The data with the specified columns selected. | ||
|
||
Raises | ||
------ | ||
TypeError | ||
If `X` is not a `pd.DataFrame` object. | ||
If `X` is not a supported DataFrame. | ||
ValueError | ||
If column dtypes were not equal during fit and transform. | ||
""" | ||
check_is_fitted(self, ["X_dtypes_", "feature_names_"]) | ||
|
||
try: | ||
if (self.X_dtypes_ != X.dtypes).any(): | ||
if (pd := get_pandas()) is not None and isinstance(X, pd.DataFrame): | ||
try: | ||
if (self.X_dtypes_ != X.dtypes).any(): | ||
raise ValueError( | ||
f"Column dtypes were not equal during fit and transform. Fit types: \n" | ||
f"{self.X_dtypes_}\n" | ||
f"transform: \n" | ||
f"{X.dtypes}" | ||
) | ||
except ValueError as e: | ||
raise ValueError("Columns were not equal during fit and transform") from e | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this happen? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yup, the last test in I've unified the messages and included the error message in the test |
||
transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude) | ||
else: | ||
X = nw.from_native(X) | ||
if self.X_dtypes_ != X.schema: | ||
raise ValueError( | ||
f"Column dtypes were not equal during fit and transform. Fit types: \n" | ||
f"{self.X_dtypes_}\n" | ||
f"transform: \n" | ||
f"{X.dtypes}" | ||
f"{X.schema}" | ||
) | ||
except ValueError as e: | ||
raise ValueError("Columns were not equal during fit and transform") from e | ||
|
||
self._check_X_for_type(X) | ||
transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude) | ||
transformed_df = _nw_select_dtypes(X, include=self.include, exclude=self.exclude) | ||
|
||
return transformed_df | ||
|
||
@staticmethod | ||
def _check_X_for_type(X): | ||
"""Checks if input of the Selector is of the required dtype""" | ||
if not isinstance(X, pd.DataFrame): | ||
raise TypeError("Provided variable X is not of type pandas.DataFrame") | ||
|
||
|
||
class ColumnSelector(BaseEstimator, TransformerMixin): | ||
"""The `ColumnSelector` transformer allows selecting specific columns from a DataFrame by name. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was line 23 the intended target?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah i probably shouldn't make commits in a hurry whilst on a train sorry