patch: scikit-learn 1.6 compatibility #726

Merged
11 commits merged on Dec 17, 2024
520 changes: 520 additions & 0 deletions sklego/_sklearn_compat.py

Large diffs are not rendered by default.
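The new compatibility module is not rendered here. For orientation only: a shim of this kind usually dispatches on the installed scikit-learn version and backports the 1.6 helpers. A minimal sketch, assuming the `packaging` dependency is available, and greatly simplified relative to the real 520-line file:

```python
# Minimal sketch of what a shim like sklego/_sklearn_compat.py typically contains.
# This is NOT the actual module; it only shows the dispatch idea. The real file
# also wraps check_array and the tags machinery used elsewhere in this PR.
import sklearn
from packaging.version import Version  # assumed available

if Version(sklearn.__version__) >= Version("1.6"):
    # scikit-learn 1.6 ships validate_data as a public helper.
    from sklearn.utils.validation import check_array, validate_data
else:
    from sklearn.utils.validation import check_array, check_X_y

    def validate_data(estimator, X, y=None, reset=True, **check_params):
        """Simplified backport: validate input and track n_features_in_."""
        if y is None:
            out = X_checked = check_array(X, estimator=estimator, **check_params)
        else:
            X_checked, y_checked = check_X_y(X, y, estimator=estimator, **check_params)
            out = (X_checked, y_checked)
        if reset:
            estimator.n_features_in_ = X_checked.shape[1]
        elif X_checked.shape[1] != estimator.n_features_in_:
            raise ValueError(
                f"X has {X_checked.shape[1]} features, but "
                f"{estimator.__class__.__name__} is expecting "
                f"{estimator.n_features_in_} features as input."
            )
        return out
```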

15 changes: 7 additions & 8 deletions sklego/common.py
@@ -5,7 +5,9 @@
 import numpy as np
 import pandas as pd
 from sklearn.base import BaseEstimator, TransformerMixin
-from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
+from sklearn.utils.validation import check_is_fitted
+
+from sklego._sklearn_compat import validate_data
 
 
 class TrainOnlyTransformerMixin(TransformerMixin, BaseEstimator):
@@ -79,11 +81,11 @@ def fit(self, X, y=None):
             The fitted transformer.
         """
         if y is None:
-            check_array(X, estimator=self)
+            validate_data(self, X=X, reset=True)
         else:
-            check_X_y(X, y, estimator=self, multi_output=True)
+            validate_data(self, X=X, y=y, multi_output=True, reset=True)
 
         self.X_hash_ = self._hash(X)
-        self.n_features_in_ = X.shape[1]
         return self
 
     @staticmethod
@@ -145,10 +147,7 @@ def transform(self, X, y=None):
             If the input dimension does not match the training dimension.
         """
         check_is_fitted(self, ["X_hash_", "n_features_in_"])
-        check_array(X, estimator=self)
-
-        if X.shape[1] != self.n_features_in_:
-            raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self.n_features_in_}")
+        validate_data(self, X=X, reset=False)
 
         if self._hash(X) == self.X_hash_:
             return self.transform_train(X)
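The change above is the recurring pattern in this PR: `validate_data(..., reset=True)` in `fit` both validates and records `n_features_in_`, while `reset=False` at transform/predict time re-validates against the stored value, replacing the hand-rolled dimension checks. A small illustration of the reset semantics (the transformer name is made up):

```python
# Illustration of the reset semantics (EchoTransformer is a made-up example).
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

from sklego._sklearn_compat import validate_data


class EchoTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y=None):
        validate_data(self, X=X, reset=True)   # validates and sets n_features_in_
        return self

    def transform(self, X):
        # reset=False checks X against the n_features_in_ recorded during fit
        return validate_data(self, X=X, reset=False)


t = EchoTransformer().fit(np.ones((5, 3)))
t.transform(np.ones((2, 3)))    # fine
# t.transform(np.ones((2, 4)))  # ValueError: 4 features, but 3 were seen in fit
```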
12 changes: 7 additions & 5 deletions sklego/decomposition/pca_reconstruction.py
@@ -1,7 +1,9 @@
 import numpy as np
 from sklearn.base import BaseEstimator, OutlierMixin
 from sklearn.decomposition import PCA
-from sklearn.utils.validation import FLOAT_DTYPES, check_array, check_is_fitted
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted
+
+from sklego._sklearn_compat import validate_data
 
 
 class PCAOutlierDetection(OutlierMixin, BaseEstimator):
@@ -94,7 +96,7 @@ def fit(self, X, y=None):
         ValueError
             If `threshold` is `None`.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=True)
         if not self.threshold:
             raise ValueError("The `threshold` value cannot be `None`.")
 
@@ -108,8 +110,6 @@ def fit(self, X, y=None):
         )
         self.pca_.fit(X, y)
         self.offset_ = -self.threshold
-
-        self.n_features_in_ = X.shape[1]
         return self
 
     def difference(self, X):
@@ -126,6 +126,8 @@ def difference(self, X):
             The calculated difference.
         """
         check_is_fitted(self, ["pca_", "offset_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
+
         reduced = self.pca_.transform(X)
         diff = np.sum(np.abs(self.pca_.inverse_transform(reduced) - X), axis=1)
         if self.variant == "relative":
@@ -157,8 +159,8 @@ def predict(self, X):
         array-like of shape (n_samples,)
             The predicted data. 1 for inliers, -1 for outliers.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         check_is_fitted(self, ["pca_", "offset_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
         result = np.ones(X.shape[0])
         result[self.difference(X) > self.threshold] = -1
         return result.astype(int)
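For orientation on what these hunks touch: `PCAOutlierDetection` flags a row as an outlier when its PCA reconstruction error exceeds `threshold`; `difference` computes that error and `predict` thresholds it. A quick usage sketch on toy data (parameter values are arbitrary):

```python
# Quick usage sketch of the reconstruction-error logic (toy data, arbitrary params).
import numpy as np

from sklego.decomposition import PCAOutlierDetection

X = np.random.default_rng(0).normal(size=(100, 4))
det = PCAOutlierDetection(n_components=2, threshold=0.2).fit(X)

print(det.difference(X[:3]))   # per-row reconstruction error
print(det.predict(X[:3]))      # 1 = inlier, -1 = outlier
```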
19 changes: 14 additions & 5 deletions sklego/decomposition/umap_reconstruction.py
@@ -8,7 +8,9 @@
 
 import numpy as np
 from sklearn.base import BaseEstimator, OutlierMixin
-from sklearn.utils.validation import FLOAT_DTYPES, check_array, check_is_fitted
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted
+
+from sklego._sklearn_compat import validate_data
 
 
 class UMAPOutlierDetection(OutlierMixin, BaseEstimator):
@@ -100,9 +102,10 @@ def fit(self, X, y=None):
             - If `n_components` is less than 2.
             - If `threshold` is `None`.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         if y is not None:
-            y = check_array(y, estimator=self, ensure_2d=False)
+            X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
+        else:
+            X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=True)
 
         if not self.threshold:
             raise ValueError("The `threshold` value cannot be `None`.")
@@ -116,7 +119,6 @@ def fit(self, X, y=None):
         )
         self.umap_.fit(X, y)
         self.offset_ = -self.threshold
-        self.n_features_in_ = X.shape[1]
         return self
 
     def difference(self, X):
@@ -133,6 +135,8 @@ def difference(self, X):
             The calculated difference.
         """
         check_is_fitted(self, ["umap_", "offset_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
+
         reduced = self.umap_.transform(X)
         diff = np.sum(np.abs(self.umap_.inverse_transform(reduced) - X), axis=1)
         if self.variant == "relative":
@@ -155,8 +159,8 @@ def predict(self, X):
         array-like of shape (n_samples,)
             The predicted data. 1 for inliers, -1 for outliers.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         check_is_fitted(self, ["umap_", "offset_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
         result = np.ones(X.shape[0])
         result[self.difference(X) > self.threshold] = -1
         return result.astype(int)
@@ -172,3 +176,8 @@ def score_samples(self, X):
 
     def _more_tags(self):
         return {"non_deterministic": True}
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.non_deterministic = True
+        return tags
23 changes: 11 additions & 12 deletions sklego/dummy.py
@@ -2,13 +2,9 @@
 
 import numpy as np
 from sklearn.base import BaseEstimator, RegressorMixin
-from sklearn.utils import check_X_y
-from sklearn.utils.validation import (
-    FLOAT_DTYPES,
-    check_array,
-    check_is_fitted,
-    check_random_state,
-)
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state
+
+from sklego._sklearn_compat import validate_data
 
 
 class RandomRegressor(RegressorMixin, BaseEstimator):
@@ -72,8 +68,7 @@ def fit(self, X: np.array, y: np.array) -> "RandomRegressor":
         """
         if self.strategy not in self._ALLOWED_STRATEGIES:
             raise ValueError(f"strategy {self.strategy} is not in {self._ALLOWED_STRATEGIES}")
-        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
-        self.n_features_in_ = X.shape[1]
+        X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
 
         self.min_ = np.min(y)
         self.max_ = np.max(y)
@@ -99,9 +94,7 @@ def predict(self, X):
         rs = check_random_state(self.random_state)
         check_is_fitted(self, ["n_features_in_", "min_", "max_", "mu_", "sigma_"])
 
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
-        if X.shape[1] != self.n_features_in_:
-            raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self.dim_}")
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
 
         if self.strategy == "normal":
             return rs.normal(self.mu_, self.sigma_, X.shape[0])
@@ -127,3 +120,9 @@ def allowed_strategies(self):
 
     def _more_tags(self):
         return {"poor_score": True, "non_deterministic": True}
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.non_deterministic = True
+        tags.regressor_tags.poor_score = True
+        return tags
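As in `umap_reconstruction.py` above, the estimator keeps the dict-based `_more_tags` for scikit-learn < 1.6 while adding the dataclass-based `__sklearn_tags__` that 1.6 reads instead. A minimal sketch of how the new API is consumed, assuming scikit-learn >= 1.6 where `sklearn.utils.get_tags` is public:

```python
# Minimal sketch: reading estimator tags through the scikit-learn >= 1.6 API.
from sklearn.utils import get_tags  # public as of 1.6

from sklego.dummy import RandomRegressor

tags = get_tags(RandomRegressor())       # resolved via __sklearn_tags__()
print(tags.non_deterministic)            # True
print(tags.regressor_tags.poor_score)    # True

# On scikit-learn < 1.6, the old dict is what the estimator checks consume:
# RandomRegressor()._more_tags() == {"poor_score": True, "non_deterministic": True}
```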
7 changes: 4 additions & 3 deletions sklego/feature_selection/mrmr.py
@@ -4,7 +4,9 @@
 from sklearn.base import BaseEstimator
 from sklearn.feature_selection import f_classif, f_regression
 from sklearn.feature_selection._base import SelectorMixin
-from sklearn.utils.validation import check_is_fitted, check_X_y
+from sklearn.utils.validation import check_is_fitted
+
+from sklego._sklearn_compat import validate_data
 
 
 def _redundancy_pearson(X, selected, left):
@@ -201,13 +203,12 @@ def fit(self, X, y):
 
             k parameter is not integer type or is < n_features_in (X.shape[1]) or < 1
         """
-        X, y = check_X_y(X, y, dtype="numeric", y_numeric=True)
+        X, y = validate_data(self, X=X, y=y, dtype="numeric", y_numeric=True, reset=True)
         self._y_dtype = y.dtype
 
         relevance = self._get_relevance
         redundancy = self._get_redundancy
 
-        self.n_features_in_ = X.shape[1]
         left_features = list(range(self.n_features_in_))
         selected_features = []
         selected_scores = []
30 changes: 17 additions & 13 deletions sklego/linear_model.py
@@ -21,11 +21,12 @@
 from sklearn.utils.validation import (
     FLOAT_DTYPES,
     _check_sample_weight,
-    check_array,
     check_is_fitted,
     column_or_1d,
 )
 
+from sklego._sklearn_compat import check_array, validate_data
+
 
 class LowessRegression(RegressorMixin, BaseEstimator):
     """`LowessRegression` estimator: LOWESS (Locally Weighted Scatterplot Smoothing) is a type of
@@ -96,7 +97,7 @@ def fit(self, X, y):
             - If `span` is not between 0 and 1.
             - If `sigma` is negative.
         """
-        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
+        X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
         if self.span is not None:
             if not 0 <= self.span <= 1:
                 raise ValueError(f"Param `span` must be 0 <= span <= 1, got: {self.span}")
@@ -138,8 +139,8 @@ def predict(self, X):
         array-like of shape (n_samples,)
             The predicted values.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         check_is_fitted(self, ["X_", "y_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
 
         try:
             results = np.stack([np.average(self.y_, weights=self._calc_wts(x_i=x_i)) for x_i in X])
@@ -233,7 +234,7 @@ def fit(self, X, y):
         self : ProbWeightRegression
             The fitted estimator.
         """
-        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
+        X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
 
         # Construct the problem.
         betas = cp.Variable(X.shape[1])
@@ -263,8 +264,8 @@ def predict(self, X):
         array-like of shape (n_samples,)
             The predicted data.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         check_is_fitted(self, ["coef_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
         return np.dot(X, self.coef_)
 
     @property
@@ -345,8 +346,6 @@ class DeadZoneRegressor(RegressorMixin, BaseEstimator):
 
     print(y_pred)
     ```
-
-
     """
 
     _ALLOWED_EFFECTS = ("linear", "quadratic", "constant")
@@ -381,7 +380,8 @@ def fit(self, X, y):
         ValueError
             If `effect` is not one of "linear", "quadratic" or "constant".
         """
-        X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
+        X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
+
         if self.effect not in self._ALLOWED_EFFECTS:
             raise ValueError(f"effect {self.effect} must be in {self._ALLOWED_EFFECTS}")
 
@@ -458,8 +458,9 @@ def predict(self, X):
         array-like of shape (n_samples,)
             The predicted data.
         """
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
         check_is_fitted(self, ["coef_"])
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
+
         return np.dot(X, self.coef_)
 
     @property
@@ -970,8 +971,6 @@ def __init__(
         self.fit_intercept = fit_intercept
         self.copy_X = copy_X
         self.positive = positive
-        if method not in ("SLSQP", "TNC", "L-BFGS-B"):
-            raise ValueError(f'method should be one of "SLSQP", "TNC", "L-BFGS-B", ' f"got {method} instead")
         self.method = method
 
     @abstractmethod
@@ -1021,6 +1020,10 @@ def fit(self, X, y, sample_weight=None):
         self : BaseScipyMinimizeRegressor
             Fitted linear model.
         """
+        if self.method not in {"SLSQP", "TNC", "L-BFGS-B"}:
+            msg = f"method should be one of 'SLSQP', 'TNC', 'L-BFGS-B', got {self.method} instead"
+            raise ValueError(msg)
+
         X_, grad_loss, loss = self._prepare_inputs(X, sample_weight, y)
 
         d = X_.shape[1] - self.n_features_in_  # This is either zero or one.
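Note the `method` check moved from `__init__` (removed above) into `fit`: scikit-learn requires constructors to only store parameters, so `clone`/`set_params` round-trips never raise. A sketch of the resulting behaviour, assuming `LADRegression`, one of the `BaseScipyMinimizeRegressor` subclasses, exposes the `method` parameter:

```python
# Sketch: invalid hyperparameters now surface at fit time, not at construction.
import numpy as np

from sklego.linear_model import LADRegression  # assumed subclass with `method`

est = LADRegression(method="bogus")   # __init__ only stores the value, no error
X, y = np.random.rand(10, 2), np.random.rand(10)
try:
    est.fit(X, y)                     # the validation in fit raises here
except ValueError as exc:
    print(exc)  # method should be one of 'SLSQP', 'TNC', 'L-BFGS-B', got bogus instead
```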
@@ -1051,7 +1054,8 @@ def _prepare_inputs(self, X, sample_weight, y):
         This method is called by `fit` to prepare the inputs for the optimization problem. It adds an intercept column
         to `X` if `fit_intercept=True`, and returns the loss function and its gradient.
         """
-        X, y = check_X_y(X, y, y_numeric=True)
+        X, y = validate_data(self, X=X, y=y, y_numeric=True, reset=True)
+
         sample_weight = _check_sample_weight(sample_weight, X)
         self.n_features_in_ = X.shape[1]
 
@@ -1081,7 +1085,7 @@ def predict(self, X):
             The predicted data.
         """
         check_is_fitted(self)
-        X = check_array(X)
+        X = validate_data(self, X=X, reset=False)
 
         return X @ self.coef_ + self.intercept_
 
3 changes: 2 additions & 1 deletion sklego/meta/_grouped_utils.py
@@ -5,9 +5,10 @@
 import narwhals.stable.v1 as nw
 import pandas as pd
 from scipy.sparse import issparse
-from sklearn.utils import check_array
 from sklearn.utils.validation import _ensure_no_complex_data
 
+from sklego._sklearn_compat import check_array
+
 
 def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame:
     """Converts X, y to narwhals dataframe.
11 changes: 6 additions & 5 deletions sklego/meta/confusion_balancer.py
@@ -2,8 +2,9 @@
 from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin
 from sklearn.metrics import confusion_matrix
 from sklearn.utils.multiclass import unique_labels
-from sklearn.utils.validation import FLOAT_DTYPES, check_array, check_is_fitted, check_X_y
+from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted
 
+from sklego._sklearn_compat import validate_data
 from sklego.base import ProbabilisticClassifier
 
 
@@ -63,7 +64,8 @@ def fit(self, X, y):
             If the underlying estimator does not have a `predict_proba` method.
         """
 
-        X, y = check_X_y(X, y, estimator=self.estimator, dtype=FLOAT_DTYPES)
+        X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True)
+
         if not isinstance(self.estimator, ProbabilisticClassifier):
             raise ValueError(
                 "The ConfusionBalancer meta model only works on classification models with .predict_proba."
@@ -72,7 +74,6 @@ def fit(self, X, y):
         self.classes_ = unique_labels(y)
         cfm = confusion_matrix(y, self.estimator_.predict(X)).T + self.cfm_smooth
         self.cfm_ = cfm / cfm.sum(axis=1).reshape(-1, 1)
-        self.n_features_in_ = X.shape[1]
         return self
 
     def predict_proba(self, X):
@@ -90,7 +91,7 @@ def predict_proba(self, X):
             The predicted values.
         """
         check_is_fitted(self, ["cfm_", "classes_", "estimator_"])
-        X = check_array(X, dtype=FLOAT_DTYPES)
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
         preds = self.estimator_.predict_proba(X)
         return (1 - self.alpha) * preds + self.alpha * preds @ self.cfm_

@@ -108,5 +109,5 @@ def predict(self, X):
             The predicted values.
         """
         check_is_fitted(self, ["cfm_", "classes_", "estimator_"])
-        X = check_array(X, dtype=FLOAT_DTYPES)
+        X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False)
         return self.classes_[self.predict_proba(X).argmax(axis=1)]