RFC Use dataframe api to also support Polars and other dataframe libraries #597

Closed · wants to merge 24 commits
2 changes: 2 additions & 0 deletions setup.py
@@ -7,6 +7,7 @@
     "scikit-learn>=1.0",
     "pandas>=1.1.5",
     "Deprecated>=1.2.6",
+    "dataframe-api-compat>=0.1.26",
 ]
 cvxpy_packages = ["cvxpy>=1.1.8"]
 umap_packages = ["umap-learn>=0.4.6"]
@@ -31,6 +32,7 @@
     "pytest-cov>=2.6.1",
     "pytest-mock>=1.6.3",
     "pre-commit>=1.18.3",
+    "polars>=0.19.13",
 ]
 util_packages = [
     "matplotlib>=3.0.2",
15 changes: 10 additions & 5 deletions sklego/common.py
@@ -1,12 +1,12 @@
 import collections
 import hashlib
 from warnings import warn

 import numpy as np
-import pandas as pd
 from sklearn.base import TransformerMixin
 from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

+from sklego.dataframe_agnostic_utils import try_convert_to_standard_compliant_dataframe
+

class TrainOnlyTransformerMixin(TransformerMixin):
"""Mixin class for transformers that can handle training and test data differently.
@@ -58,9 +58,9 @@ def transform_train(self, X, y=None):
"""

     _HASHERS = {
-        pd.DataFrame: lambda X: hashlib.sha256(
-            pd.util.hash_pandas_object(X, index=True).values
-        ).hexdigest(),
+        '__dataframe_namespace__': lambda X: hash(
+            X.to_array().data.tobytes()
+        ),
         np.ndarray: lambda X: hash(X.data.tobytes()),
         np.memmap: lambda X: hash(X.data.tobytes()),
     }
@@ -113,6 +113,11 @@ def _hash(X):
         ValueError
             If the type of `X` is not supported.
         """
+        X = try_convert_to_standard_compliant_dataframe(X)
+        if hasattr(X, '__dataframe_namespace__'):
+            hasher = TrainOnlyTransformerMixin._HASHERS['__dataframe_namespace__']
+            X = X.persist()
+            return hasher(X)
         try:
             hasher = TrainOnlyTransformerMixin._HASHERS[type(X)]
         except KeyError:
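For illustration, a minimal sketch (not part of the diff) of what the new `'__dataframe_namespace__'` branch does for a Polars frame. It assumes `polars` and `dataframe-api-compat` are installed as pinned in setup.py, and mirrors the `.persist()`/`.to_array()` calls the diff takes from the draft standard:

```python
import polars as pl

from sklego.dataframe_agnostic_utils import try_convert_to_standard_compliant_dataframe

# Sample frame; any library exposing the DataFrame API standard behaves the same.
df = pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# Wrap in the standard-compliant API, materialise (Polars may be lazy),
# then hash the underlying buffer exactly as the new hasher does.
std_df = try_convert_to_standard_compliant_dataframe(df).persist()
digest = hash(std_df.to_array().data.tobytes())
```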
57 changes: 57 additions & 0 deletions sklego/dataframe_agnostic_utils.py
@@ -0,0 +1,57 @@
"""
We use the DataFrame API Standard to write dataframe-agnostic code.

The full spec can be found here: https://data-apis.org/dataframe-api/draft/API_specification/index.html.

pandas and Polars expose the entrypoint `__dataframe_consortium_standard__` on their DataFrame and Series objects,
but only as of versions 2.1 and 0.18.13 respectively.

In order to support earlier versions, we import `convert_to_standard_compliant_dataframe` from
the `dataframe-api-compat` package.
"""

def try_convert_to_standard_compliant_dataframe(df, *, strict=False):
    if hasattr(df, '__dataframe_consortium_standard__'):
        return df.__dataframe_consortium_standard__(api_version='2023.11-beta')
    try:
        import pandas as pd
    except ModuleNotFoundError:
        pass
    else:
        if isinstance(df, pd.DataFrame):
            from dataframe_api_compat.pandas_standard import convert_to_standard_compliant_dataframe
            return convert_to_standard_compliant_dataframe(df)
    try:
        import polars as pl
    except ModuleNotFoundError:
        pass
    else:
        if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
            from dataframe_api_compat.polars_standard import convert_to_standard_compliant_dataframe
            return convert_to_standard_compliant_dataframe(df)
    if strict:
        raise TypeError(f"Could not convert {type(df)} to a standard compliant dataframe")
    return df

def try_convert_to_standard_compliant_column(df, *, strict=False):
    if hasattr(df, '__column_consortium_standard__'):
        return df.__column_consortium_standard__(api_version='2023.11-beta')
    try:
        import pandas as pd
    except ModuleNotFoundError:
        pass
    else:
        if isinstance(df, pd.Series):
            from dataframe_api_compat.pandas_standard import convert_to_standard_compliant_column
            return convert_to_standard_compliant_column(df)
    try:
        import polars as pl
    except ModuleNotFoundError:
        pass
    else:
        if isinstance(df, pl.Series):
            from dataframe_api_compat.polars_standard import convert_to_standard_compliant_column
            return convert_to_standard_compliant_column(df)
    if strict:
        raise TypeError(f"Could not convert {type(df)} to a standard compliant column")
    return df
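For reviewers, a short usage sketch of the two helpers above (sample data is illustrative; attribute names follow the draft spec linked in the docstring):

```python
import numpy as np
import pandas as pd
import polars as pl

from sklego.dataframe_agnostic_utils import (
    try_convert_to_standard_compliant_column,
    try_convert_to_standard_compliant_dataframe,
)

# DataFrames from either library come back wrapped in the standard API.
std_pd = try_convert_to_standard_compliant_dataframe(pd.DataFrame({"a": [1, 2]}))
std_pl = try_convert_to_standard_compliant_dataframe(pl.DataFrame({"a": [1, 2]}))
print(std_pd.column_names, std_pl.column_names)  # ['a'] ['a']

# Unsupported inputs pass through unchanged unless strict=True,
# in which case a TypeError is raised.
arr = try_convert_to_standard_compliant_dataframe(np.ones((2, 2)))  # returned as-is

# Series/columns follow the same pattern via the column entrypoint.
std_col = try_convert_to_standard_compliant_column(pd.Series([1, 2, 3]))
```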
15 changes: 14 additions & 1 deletion sklego/datasets.py
@@ -1,7 +1,6 @@
 import os

 import numpy as np
-import pandas as pd
 from pkg_resources import resource_filename
 from sklearn.datasets import fetch_openml
@@ -92,6 +91,8 @@ def load_penguins(return_X_y=False, as_frame=False):

     (Accessed 2020-06-08).
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "penguins.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -151,6 +152,8 @@ def load_arrests(return_X_y=False, as_frame=False):

     - Personal communication from Michael Friendly, York University.
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "arrests.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -198,6 +201,8 @@ def load_chicken(return_X_y=False, as_frame=False):
     - Crowder, M. and Hand, D. (1990), Analysis of Repeated Measures, Chapman and Hall (example 5.3)
     - Hand, D. and Crowder, M. (1996), Practical Longitudinal Data Analysis, Chapman and Hall (table A.2)
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "chickweight.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -244,6 +249,8 @@ def load_abalone(return_X_y=False, as_frame=False):

     Sea Fisheries Division, Technical Report No. 48 (ISSN 1034-3288)
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "abalone.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -294,6 +301,8 @@ def load_heroes(return_X_y=False, as_frame=False):
     # Index(['name', 'attack_type', 'role', 'health', 'attack', 'attack_spd'], dtype='object')
     ```
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "heroes.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -351,6 +360,8 @@ def load_hearts(return_X_y=False, as_frame=False):
     The documentation of the dataset can be viewed at:
     https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/heart-disease.names
     """
+    import pandas as pd
+
     filepath = resource_filename("sklego", os.path.join("data", "hearts.zip"))
     df = pd.read_csv(filepath)
     if as_frame:
@@ -434,6 +445,8 @@ def make_simpleseries(
     '''
     ```
     """
+    import pandas as pd
+
     if seed:
         np.random.seed(seed)
     time = np.arange(0, n_samples)
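The recurring change in this file is the same in every loader: `import pandas as pd` moves from module scope into each function body, so importing `sklego.datasets` no longer pulls in pandas eagerly and the dependency is only needed when a loader actually runs, presumably as groundwork for making pandas optional. A hypothetical minimal loader (`load_example` is not part of the diff) showing the pattern:

```python
def load_example(as_frame=False):
    # Lazy import: pandas is only needed when this loader is called,
    # not when sklego.datasets is imported.
    import pandas as pd

    df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [0, 1, 0]})
    if as_frame:
        return df
    return df[["x"]].to_numpy(), df["y"].to_numpy()
```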
10 changes: 7 additions & 3 deletions sklego/linear_model.py
@@ -9,7 +9,6 @@
 from warnings import warn

 import numpy as np
-import pandas as pd
 from deprecated.sphinx import deprecated
 from scipy.optimize import minimize
 from scipy.special._ufuncs import expit
@@ -26,6 +25,8 @@
     column_or_1d,
 )

+from sklego.dataframe_agnostic_utils import try_convert_to_standard_compliant_dataframe
+

class LowessRegression(BaseEstimator, RegressorMixin):
"""`LowessRegression` estimator: LOWESS (Locally Weighted Scatterplot Smoothing) is a type of
@@ -496,10 +497,13 @@ def fit(self, X, y):
         )

         self.sensitive_col_idx_ = self.sensitive_cols
-        if isinstance(X, pd.DataFrame):
+
+        X = try_convert_to_standard_compliant_dataframe(X)
+        if hasattr(X, '__dataframe_namespace__'):
             self.sensitive_col_idx_ = [
-                i for i, name in enumerate(X.columns) if name in self.sensitive_cols
+                i for i, name in enumerate(X.column_names) if name in self.sensitive_cols
             ]
+            X = X.dataframe
         X, y = check_X_y(X, y, accept_large_sparse=False)

         sensitive = X[:, self.sensitive_col_idx_]
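A minimal sketch (illustrative data, not from the diff) of how the rewritten `fit` resolves sensitive column names to positional indices before `check_X_y` strips the dataframe away:

```python
import polars as pl

from sklego.dataframe_agnostic_utils import try_convert_to_standard_compliant_dataframe

X = pl.DataFrame({"age": [25, 40], "income": [30_000, 60_000], "gender": [0, 1]})
sensitive_cols = ["gender"]

# Wrap, then map names to positions via the standard `column_names` property.
std_X = try_convert_to_standard_compliant_dataframe(X)
sensitive_col_idx = [
    i for i, name in enumerate(std_X.column_names) if name in sensitive_cols
]
print(sensitive_col_idx)  # [2]

# `.dataframe` unwraps back to the native Polars frame for check_X_y.
X_native = std_X.dataframe
```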
38 changes: 27 additions & 11 deletions sklego/meta/_grouped_utils.py
@@ -7,6 +7,7 @@
 from sklearn.utils.validation import _ensure_no_complex_data

 from sklego.common import as_list
+from sklego.dataframe_agnostic_utils import try_convert_to_standard_compliant_dataframe


 def constant_shrinkage(group_sizes: list, alpha: float) -> np.ndarray:
@@ -45,21 +46,25 @@ def _split_groups_and_values(
     _shape_check(X, min_value_cols)

     try:
-        if isinstance(X, pd.DataFrame):
-            X_group = X.loc[:, as_list(groups)]
-            X_value = X.drop(columns=groups).values
-        else:
+        try:
+            X = try_convert_to_standard_compliant_dataframe(X, strict=True).persist()
+        except TypeError:
             X_group = pd.DataFrame(X[:, as_list(groups)])
             pos_indexes = range(X.shape[1])
             X_value = np.delete(X, [pos_indexes[g] for g in as_list(groups)], axis=1)
-    except (KeyError, IndexError):
-        raise ValueError(f"Could not drop groups {groups} from columns of X")
+        else:
+            X_group = X.select(*as_list(groups))
+            X_value = X.drop_columns(*as_list(groups)).to_array()
+    except (KeyError, IndexError, TypeError) as exc:
+        raise ValueError(f"Could not drop groups {groups} from columns of X") from exc

     X_group = _check_grouping_columns(X_group, **kwargs)

     if check_X:
         X_value = check_array(X_value, **kwargs)

+    if hasattr(X_group, '__dataframe_namespace__'):
+        X_group = X_group.dataframe
     return X_group, X_value


@@ -82,13 +87,24 @@ def _shape_check(X, min_value_cols):
 def _check_grouping_columns(X_group, **kwargs) -> pd.DataFrame:
     """Do basic checks on grouping columns"""
     # Do regular checks on numeric columns
-    X_group_num = X_group.select_dtypes(include="number")
-    if X_group_num.shape[1]:
-        check_array(X_group_num, **kwargs)
+    if not hasattr(X_group, '__dataframe_namespace__'):
+        X_group = try_convert_to_standard_compliant_dataframe(X_group).persist()
+    pdx = X_group.__dataframe_namespace__()
+    X_group_num = X_group.select(
+        *[col.name for col in X_group.iter_columns()
+          if pdx.is_dtype(col, 'numeric')]
+    )
+    if len(X_group_num.column_names):
+        check_array(X_group_num.to_array(), **kwargs)

     # Only check missingness in object columns
-    if X_group.select_dtypes(exclude="number").isnull().any(axis=None):
+    if (
+        X_group.select(
+            *[col.name for col in X_group.iter_columns()
+              if not pdx.is_dtype(col, 'numeric')]
+        ).is_null().to_array().any()
+    ):
         raise ValueError("X has NaN values")

     # The grouping part we always want as a DataFrame with range index
-    return X_group.reset_index(drop=True)
+    return X_group.dataframe
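And a usage sketch of the patched helper with a Polars frame (sample data illustrative; the keyword follows the `groups` parameter of the signature shown above):

```python
import polars as pl

from sklego.meta._grouped_utils import _split_groups_and_values

X = pl.DataFrame(
    {"g": ["a", "a", "b"], "x1": [1.0, 2.0, 3.0], "x2": [4.0, 5.0, 6.0]}
)

# X_group comes back as a native dataframe holding the grouping column(s);
# X_value is a numpy array of the remaining feature columns.
X_group, X_value = _split_groups_and_values(X, groups="g")
print(X_value.shape)  # (3, 2)
```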