Skip to content

Commit

Permalink
MAINT compatibility sklearn 1.4 (#1058)
Browse files Browse the repository at this point in the history
* MAINT compatibility sklearn 1.4

* iter

* fix

* doc

* compat numpydoc

* update changelog

* fix
  • Loading branch information
glemaitre authored Jan 19, 2024
1 parent 0a659af commit c7a1838
Show file tree
Hide file tree
Showing 17 changed files with 166 additions and 352 deletions.
10 changes: 5 additions & 5 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ jobs:
ne(variables['Build.Reason'], 'Schedule')
)
matrix:
py38_conda_forge_openblas_ubuntu_1804:
py39_conda_forge_openblas_ubuntu_1804:
DISTRIB: 'conda'
CONDA_CHANNEL: 'conda-forge'
PYTHON_VERSION: '3.8'
PYTHON_VERSION: '3.9'
BLAS: 'openblas'
COVERAGE: 'false'

Expand Down Expand Up @@ -188,7 +188,7 @@ jobs:
pylatest_conda_tensorflow:
DISTRIB: 'conda-latest-tensorflow'
CONDA_CHANNEL: 'conda-forge'
PYTHON_VERSION: '3.8'
PYTHON_VERSION: '3.9'
TEST_DOCS: 'true'
TEST_DOCSTRINGS: 'true'
CHECK_WARNINGS: 'true'
Expand All @@ -214,7 +214,7 @@ jobs:
pylatest_conda_keras:
DISTRIB: 'conda-latest-keras'
CONDA_CHANNEL: 'conda-forge'
PYTHON_VERSION: '3.8'
PYTHON_VERSION: '3.9'
TEST_DOCS: 'true'
TEST_DOCSTRINGS: 'true'
CHECK_WARNINGS: 'true'
Expand Down Expand Up @@ -301,7 +301,7 @@ jobs:
py38_conda_forge_mkl:
DISTRIB: 'conda'
CONDA_CHANNEL: 'conda-forge'
PYTHON_VERSION: '3.8'
PYTHON_VERSION: '3.10'
CHECK_WARNINGS: 'true'
PYTHON_ARCH: '64'
PYTEST_VERSION: '*'
Expand Down
5 changes: 2 additions & 3 deletions doc/ensemble.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ data set, this classifier will favor the majority classes::
>>> from sklearn.ensemble import BaggingClassifier
>>> from sklearn.tree import DecisionTreeClassifier
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
>>> bc = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
... random_state=0)
>>> bc = BaggingClassifier(DecisionTreeClassifier(), random_state=0)
>>> bc.fit(X_train, y_train) #doctest: +ELLIPSIS
BaggingClassifier(...)
>>> y_pred = bc.predict(X_test)
Expand All @@ -50,7 +49,7 @@ sampling is controlled by the parameter `sampler` or the two parameters
:class:`~imblearn.under_sampling.RandomUnderSampler`::

>>> from imblearn.ensemble import BalancedBaggingClassifier
>>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
>>> bbc = BalancedBaggingClassifier(DecisionTreeClassifier(),
... sampling_strategy='auto',
... replacement=False,
... random_state=0)
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ Compatibility

- :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now supports missing values
and monotonic constraints if scikit-learn >= 1.4 is installed.

- :class:`~imblearn.pipeline.Pipeline` supports metadata routing if scikit-learn >= 1.4
is installed.

- Compatibility with scikit-learn 1.4.
:pr:`1058` by :user:`Guillaume Lemaitre <glemaitre>`.

Deprecations
............

Expand Down
84 changes: 26 additions & 58 deletions imblearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# License: MIT

import copy
import inspect
import numbers
import warnings

Expand All @@ -15,6 +14,7 @@
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble._bagging import _parallel_decision_function
from sklearn.ensemble._base import _partition_estimators
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import parse_version
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -121,30 +121,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
.. versionadded:: 0.8
base_estimator : estimator object, default=None
The base estimator to fit on random subsets of the dataset.
If None, then the base estimator is a decision tree.
.. deprecated:: 0.10
`base_estimator` was renamed to `estimator` in version 0.10 and
will be removed in 0.12.
Attributes
----------
estimator_ : estimator
The base estimator from which the ensemble is grown.
.. versionadded:: 0.10
base_estimator_ : estimator
The base estimator from which the ensemble is grown.
.. deprecated:: 1.2
`base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
removed in 1.4. Use `estimator_` instead. When the minimum version
of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
this attribute will be removed.
n_features_ : int
The number of features when `fit` is performed.
Expand Down Expand Up @@ -266,7 +249,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
"""

# make a deepcopy to not modify the original dictionary
if sklearn_version >= parse_version("1.3"):
if sklearn_version >= parse_version("1.4"):
_parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
else:
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
Expand All @@ -283,6 +266,9 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
"sampler": [HasMethods(["fit_resample"]), None],
}
)
# TODO: remove when minimum supported version of scikit-learn is 1.4
if "base_estimator" in _parameter_constraints:
del _parameter_constraints["base_estimator"]

def __init__(
self,
Expand All @@ -301,18 +287,8 @@ def __init__(
random_state=None,
verbose=0,
sampler=None,
base_estimator="deprecated",
):
# TODO: remove when supporting scikit-learn>=1.2
bagging_classifier_signature = inspect.signature(super().__init__)
estimator_params = {"base_estimator": base_estimator}
if "estimator" in bagging_classifier_signature.parameters:
estimator_params["estimator"] = estimator
else:
self.estimator = estimator

super().__init__(
**estimator_params,
n_estimators=n_estimators,
max_samples=max_samples,
max_features=max_features,
Expand All @@ -324,6 +300,7 @@ def __init__(
random_state=random_state,
verbose=verbose,
)
self.estimator = estimator
self.sampling_strategy = sampling_strategy
self.replacement = replacement
self.sampler = sampler
Expand All @@ -349,42 +326,17 @@ def _validate_y(self, y):
def _validate_estimator(self, default=DecisionTreeClassifier()):
"""Check the estimator and the n_estimator attribute, set the
`estimator_` attribute."""
if self.estimator is not None and (
self.base_estimator not in [None, "deprecated"]
):
raise ValueError(
"Both `estimator` and `base_estimator` were set. Only set `estimator`."
)

if self.estimator is not None:
base_estimator = clone(self.estimator)
elif self.base_estimator not in [None, "deprecated"]:
warnings.warn(
"`base_estimator` was renamed to `estimator` in version 0.10 and "
"will be removed in 0.12.",
FutureWarning,
)
base_estimator = clone(self.base_estimator)
estimator = clone(self.estimator)
else:
base_estimator = clone(default)
estimator = clone(default)

if self.sampler_._sampling_type != "bypass":
self.sampler_.set_params(sampling_strategy=self._sampling_strategy)

self._estimator = Pipeline(
[("sampler", self.sampler_), ("classifier", base_estimator)]
self.estimator_ = Pipeline(
[("sampler", self.sampler_), ("classifier", estimator)]
)
try:
# scikit-learn < 1.2
self.base_estimator_ = self._estimator
except AttributeError:
pass

# TODO: remove when supporting scikit-learn>=1.4
@property
def estimator_(self):
"""Estimator used to grow the ensemble."""
return self._estimator

# TODO: remove when supporting scikit-learn>=1.2
@property
Expand Down Expand Up @@ -483,6 +435,22 @@ def decision_function(self, X):

return decisions

@property
def base_estimator_(self):
    """Attribute for older sklearn version compatibility.

    Returns
    -------
    estimator
        The value of `estimator_` when the installed scikit-learn is older
        than 1.2 and `check_is_fitted` succeeds on this object.

    Raises
    ------
    AttributeError
        If the installed scikit-learn is 1.2 or newer, or if
        `check_is_fitted` reports that this object is not fitted.
    """
    error = AttributeError(
        f"{self.__class__.__name__} object has no attribute 'base_estimator_'."
    )
    if sklearn_version < parse_version("1.2"):
        # The base class requires the attribute to be defined. For
        # scikit-learn >= 1.2, we always raise an error.
        try:
            check_is_fitted(self)
        except NotFittedError as exc:
            # Chain explicitly so the traceback shows the not-fitted state as
            # the direct cause of the missing attribute.
            raise error from exc
        return self.estimator_
    raise error

def _more_tags(self):
tags = super()._more_tags()
tags_key = "_xfail_checks"
Expand Down
Loading

0 comments on commit c7a1838

Please sign in to comment.