Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PandasTypeSelector selector dataframe-agnostic #670

Merged
merged 9 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,8 @@
options:
show_root_full_path: true
show_root_heading: true

:::sklego.preprocessing.pandastransformers.TypeSelector
options:
show_root_full_path: true
show_root_heading: true
2 changes: 1 addition & 1 deletion docs/contribution.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ When a new feature is introduced, it should be documented, and typically there a
- [x] A user guide in the `docs/user-guide/` folder.
- [x] A python script in the `docs/_scripts/` folder to generate plots and code snippets (see [next section](#working-with-pymdown-snippets-extension))
- [x] Relevant static files, such as images, plots, tables and html's, should be saved in the `docs/_static/` folder.
- [x] Edit the `mkdocs.yaml` file to include the new pages in the navigation.
- [x] Edit the `mkdocs.yaml` file to include the new pages in the navigation.

### Working with pymdown snippets extension

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ maintainers = [
]

dependencies = [
"narwhals>=0.8.12",
"narwhals>=0.8.13",
"pandas>=1.1.5",
"scikit-learn>=1.0",
"importlib-metadata >= 1.0; python_version < '3.8'",
Expand Down
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Here's a list of features that this library currently offers:
- `sklego.preprocessing.InformationFilter` transformer that can de-correlate features
- `sklego.preprocessing.IdentityTransformer` returns the same data, allows for concatenating pipelines
- `sklego.preprocessing.OrthogonalTransformer` makes all features linearly independent
- `sklego.preprocessing.PandasTypeSelector` selects columns based on pandas type
- `sklego.preprocessing.TypeSelector` selects columns based on type
- `sklego.preprocessing.RandomAdder` adds randomness in training
- `sklego.preprocessing.RepeatingBasisFunction` repeating feature engineering, useful for timeseries
- `sklego.preprocessing.DictMapper` assign numeric values on categorical columns
Expand Down
3 changes: 2 additions & 1 deletion sklego/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"OrthogonalTransformer",
"OutlierRemover",
"PandasTypeSelector",
"TypeSelector",
"RandomAdder",
"RepeatingBasisFunction",
]
Expand All @@ -20,7 +21,7 @@
from sklego.preprocessing.identitytransformer import IdentityTransformer
from sklego.preprocessing.intervalencoder import IntervalEncoder
from sklego.preprocessing.outlier_remover import OutlierRemover
from sklego.preprocessing.pandastransformers import ColumnDropper, ColumnSelector, PandasTypeSelector
from sklego.preprocessing.pandastransformers import ColumnDropper, ColumnSelector, PandasTypeSelector, TypeSelector
from sklego.preprocessing.projections import InformationFilter, OrthogonalTransformer
from sklego.preprocessing.randomadder import RandomAdder
from sklego.preprocessing.repeatingbasis import RepeatingBasisFunction
147 changes: 113 additions & 34 deletions sklego/preprocessing/pandastransformers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
from __future__ import annotations

import warnings

import narwhals as nw
import pandas as pd
from narwhals.dependencies import get_pandas
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted

from sklego.common import as_list


def _nw_match_dtype(dtype, selection):
    """Tell whether a narwhals dtype belongs to the type group named by `selection`.

    Parameters
    ----------
    dtype : narwhals dtype
        The dtype to test.
    selection : str
        Name of the type group; one of 'number', 'bool', 'string' or 'category'.

    Returns
    -------
    bool
        Result of comparing `dtype` against the narwhals dtype(s) of the requested group.

    Raises
    ------
    ValueError
        If `selection` is not one of the four supported group names.
    """
    # Reject unknown group names first; the branches below then only see supported names.
    if selection not in ("number", "bool", "string", "category"):
        msg = f"Expected {{'number', 'bool', 'string', 'category'}}, got: {selection}, which is not (yet!) supported."
        raise ValueError(msg)

    if selection == "category":
        return dtype == nw.Categorical
    if selection == "string":
        return dtype == nw.String
    if selection == "bool":
        return dtype == nw.Boolean

    # Only 'number' remains: signed/unsigned integers and floats all count as numeric.
    numeric_dtypes = (
        nw.Int64,
        nw.Int32,
        nw.Int16,
        nw.Int8,
        nw.UInt64,
        nw.UInt32,
        nw.UInt16,
        nw.UInt8,
        nw.Float64,
        nw.Float32,
    )
    return dtype in numeric_dtypes


def _nw_select_dtypes(df, include: str | list[str] | None, exclude: str | list[str] | None):
    """Select the columns of a narwhals DataFrame whose dtype falls in the requested type groups.

    A column is kept when its dtype matches at least one `include` group and none of the `exclude` groups.

    Parameters
    ----------
    df : narwhals DataFrame
        Frame to select from; its `schema` mapping (column name -> dtype) and `select` method are used.
    include : str | list[str] | None
        Type group(s) to keep. When empty/None, all four supported groups are used.
    exclude : str | list[str] | None
        Type group(s) to drop. When empty/None, nothing is excluded.

    Returns
    -------
    DataFrame
        The result of `df.select` on the names of the matching columns.

    Raises
    ------
    ValueError
        If both `include` and `exclude` are empty, or if any requested group is not one of
        'number', 'bool', 'string', 'category'.
    """
    if not include and not exclude:
        raise ValueError("Must provide at least one of `include` or `exclude`")

    supported = ("string", "number", "bool", "category")

    if isinstance(include, str):
        include = [include]
    if isinstance(exclude, str):
        exclude = [exclude]

    # Materialise as lists: both collections are iterated once per column below.
    include = list(include) if include else list(supported)
    exclude = list(exclude) if exclude else []

    # Validate every requested group up front. The per-column matching below short-circuits, so without this
    # check an unsupported name would be accepted or rejected depending on the data.
    for selection in (*include, *exclude):
        if selection not in supported:
            msg = (
                f"Expected {{'number', 'bool', 'string', 'category'}}, got: {selection}, "
                "which is not (yet!) supported."
            )
            raise ValueError(msg)

    # NOTE(review): with only `exclude` given, columns whose dtype is outside the four supported groups are
    # dropped as well (they match no `include` group), unlike pandas' `select_dtypes` - confirm this is intended.
    feature_names = [
        name
        for name, dtype in df.schema.items()
        if any(_nw_match_dtype(dtype, _include) for _include in include)
        and not any(_nw_match_dtype(dtype, _exclude) for _exclude in exclude)
    ]
    return df.select(feature_names)


class ColumnDropper(BaseEstimator, TransformerMixin):
"""The `ColumnDropper` transformer allows dropping specific columns from a DataFrame by name.
Can be useful in a sklearn Pipeline.
Expand Down Expand Up @@ -172,13 +221,21 @@ def _check_column_names(self, X):
raise KeyError(f"{list(non_existent_columns)} column(s) not in DataFrame")


class PandasTypeSelector(BaseEstimator, TransformerMixin):
"""The `PandasTypeSelector` transformer allows to select columns in a pandas DataFrame based on their type.
class TypeSelector(BaseEstimator, TransformerMixin):
"""The `TypeSelector` transformer allows to select columns in a DataFrame based on their type.
Can be useful in a sklearn Pipeline.

It uses
[pandas.DataFrame.select_dtypes](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.select_dtypes.html)
method.
- For pandas, it uses
[pandas.DataFrame.select_dtypes](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.select_dtypes.html)
method.
- For non-pandas dataframes (e.g. Polars), the following inputs are allowed:

- 'number'
- 'string'
- 'bool'
- 'category'

!!! info "New in version 0.9.0"

Parameters
----------
Expand All @@ -191,7 +248,7 @@ class PandasTypeSelector(BaseEstimator, TransformerMixin):
----------
feature_names_ : list[str]
The names of the features to keep during transform.
X_dtypes_ : pd.Series
X_dtypes_ : Series | dict[str, DType]
The dtypes of the columns in the input DataFrame.

!!! warning
Expand All @@ -202,7 +259,7 @@ class PandasTypeSelector(BaseEstimator, TransformerMixin):
--------
```py
import pandas as pd
from sklego.preprocessing import PandasTypeSelector
from sklego.preprocessing import TypeSelector

df = pd.DataFrame({
"name": ["Swen", "Victor", "Alex"],
Expand All @@ -211,14 +268,14 @@ class PandasTypeSelector(BaseEstimator, TransformerMixin):
})

#Excluding single column
PandasTypeSelector(exclude="int64").fit_transform(df)
TypeSelector(exclude="int64").fit_transform(df)
# name length
#0 Swen 1.82
#1 Victor 1.85
#2 Alex 1.80

#Including multiple columns
PandasTypeSelector(include=["int64", "object"]).fit_transform(df)
TypeSelector(include=["int64", "object"]).fit_transform(df)
# name shoesize
#0 Swen 42
#1 Victor 44
Expand All @@ -235,26 +292,30 @@ def fit(self, X, y=None):

Parameters
----------
X : pd.DataFrame
X : DataFrame
The data on which we apply the column selection.
y : pd.Series, default=None
y : Series, default=None
Ignored, present for compatibility.

Returns
-------
self : PandasTypeSelector
self : TypeSelector
The fitted transformer.

Raises
------
TypeError
If `X` is not a `pd.DataFrame` object.
If `X` is not a supported DataFrame.
ValueError
If provided type(s) results in empty dataframe.
"""
self._check_X_for_type(X)
self.X_dtypes_ = X.dtypes
self.feature_names_ = list(X.select_dtypes(include=self.include, exclude=self.exclude).columns)
if (pd := get_pandas()) is not None and isinstance(X, pd.DataFrame):
self.X_dtypes_ = X.dtypes
self.feature_names_ = list(X.select_dtypes(include=self.include, exclude=self.exclude).columns)
else:
X = nw.from_native(X)
self.X_dtypes_ = X.schema
self.feature_names_ = _nw_select_dtypes(X, include=self.include, exclude=self.exclude).columns

if len(self.feature_names_) == 0:
raise ValueError("Provided type(s) results in empty dataframe")
Expand All @@ -266,48 +327,66 @@ def get_feature_names(self, *args, **kwargs):
return self.feature_names_

def transform(self, X):
"""Returns a pandas DataFrame with columns (de)selected based on their dtype.
"""Returns a DataFrame with columns (de)selected based on their dtype.

Parameters
----------
X : pd.DataFrame
X : DataFrame
The data to select dtype for.

Returns
-------
pd.DataFrame
DataFrame
The data with the specified columns selected.

Raises
------
TypeError
If `X` is not a `pd.DataFrame` object.
If `X` is not a supported DataFrame.
ValueError
If column dtypes were not equal during fit and transform.
"""
check_is_fitted(self, ["X_dtypes_", "feature_names_"])

try:
if (self.X_dtypes_ != X.dtypes).any():
if (pd := get_pandas()) is not None and isinstance(X, pd.DataFrame):
try:
if (self.X_dtypes_ != X.dtypes).any():
raise ValueError(
f"Column dtypes were not equal during fit and transform. Fit types: \n"
f"{self.X_dtypes_}\n"
f"transform: \n"
f"{X.dtypes}"
)
except ValueError as e:
raise ValueError("Column dtypes were not equal during fit and transform") from e
transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude)
else:
X = nw.from_native(X)
if self.X_dtypes_ != X.schema:
raise ValueError(
f"Column dtypes were not equal during fit and transform. Fit types: \n"
f"{self.X_dtypes_}\n"
f"transform: \n"
f"{X.dtypes}"
f"{X.schema}"
)
except ValueError as e:
raise ValueError("Columns were not equal during fit and transform") from e

self._check_X_for_type(X)
transformed_df = X.select_dtypes(include=self.include, exclude=self.exclude)
transformed_df = _nw_select_dtypes(X, include=self.include, exclude=self.exclude)

return transformed_df

@staticmethod
def _check_X_for_type(X):
"""Checks if input of the Selector is of the required dtype"""
if not isinstance(X, pd.DataFrame):
raise TypeError("Provided variable X is not of type pandas.DataFrame")

class PandasTypeSelector(TypeSelector):
    """
    !!! warning "Deprecated since version 0.9.0, please use TypeSelector instead"

    Subclass of `TypeSelector` kept under the old name. It only overrides `__init__`, which emits a
    `DeprecationWarning` and then forwards `include` and `exclude` to `TypeSelector.__init__`.
    """

    def __init__(self, include=None, exclude=None):
        deprecation_message = (
            "PandasTypeSelector is deprecated and will be removed in a future version. "
            "Please use `from sklego.preprocessing import TypeSelector` instead."
        )
        # stacklevel=2 attributes the warning to the code instantiating the class rather than to this method.
        warnings.warn(deprecation_message, category=DeprecationWarning, stacklevel=2)
        super().__init__(include=include, exclude=exclude)


class ColumnSelector(BaseEstimator, TransformerMixin):
Expand Down
Loading
Loading