FIX InstanceHardnessThreshold accepts classifier included in a Pipeline (#1049)

Co-authored-by: Guillaume Lemaitre <[email protected]>
gmogol and glemaitre committed Mar 31, 2024
1 parent 5570b40 commit 27b8d6a
Showing 3 changed files with 34 additions and 2 deletions.
15 changes: 15 additions & 0 deletions doc/whats_new/v0.12.rst
@@ -1,5 +1,20 @@
 .. _changes_0_12:

+Version 0.12.1
+==============
+
+**In progress**
+
+Changelog
+---------
+
+Bug fixes
+.........
+
+- Fix a bug in :class:`~imblearn.under_sampling.InstanceHardnessThreshold` where
+  `estimator` could not be a :class:`~sklearn.pipeline.Pipeline` object.
+  :pr:`1049` by :user:`Gonenc Mogol <gmogol>`.
+
 Version 0.12.0
 ==============

@@ -10,7 +10,7 @@
 from collections import Counter

 import numpy as np
-from sklearn.base import ClassifierMixin, clone
+from sklearn.base import clone, is_classifier
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.ensemble._base import _set_random_states
 from sklearn.model_selection import StratifiedKFold, cross_val_predict
@@ -140,7 +140,7 @@ def _validate_estimator(self, random_state):

         if (
             self.estimator is not None
-            and isinstance(self.estimator, ClassifierMixin)
+            and is_classifier(self.estimator)
             and hasattr(self.estimator, "predict_proba")
         ):
             self.estimator_ = clone(self.estimator)
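
The reason for swapping the check can be seen in isolation. A `Pipeline` does not inherit from `ClassifierMixin`, so the old `isinstance` test rejected it even when its final step is a classifier, whereas `is_classifier` looks at the estimator type that the pipeline reports on behalf of its last step. A minimal sketch follows; `StandardScaler` and `LogisticRegression` are arbitrary stand-ins chosen for illustration and do not appear in this commit.

```python
from sklearn.base import ClassifierMixin, is_classifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Illustrative estimators only (assumption): any classifier that
# exposes predict_proba could be used as the final step instead.
bare = LogisticRegression()
pipe = make_pipeline(StandardScaler(), LogisticRegression())

# Old check: based on class inheritance, so the pipeline is rejected.
old_check_bare = isinstance(bare, ClassifierMixin)
old_check_pipe = isinstance(pipe, ClassifierMixin)

# New check: based on the reported estimator type, so both are accepted.
new_check_bare = is_classifier(bare)
new_check_pipe = is_classifier(pipe)

# The second condition in the validation is unchanged: the pipeline
# exposes predict_proba because its final step provides it.
has_proba = hasattr(pipe, "predict_proba")

print(old_check_bare, old_check_pipe, new_check_bare, new_check_pipe, has_proba)
```

With the old condition the pipeline would have failed the first test and never reached the `clone` call; with the new one it passes both tests.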
@@ -6,6 +6,7 @@
 import numpy as np
 from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
 from sklearn.naive_bayes import GaussianNB as NB
+from sklearn.pipeline import make_pipeline
 from sklearn.utils._testing import assert_array_equal

 from imblearn.under_sampling import InstanceHardnessThreshold
@@ -93,3 +94,19 @@ def test_iht_fit_resample_default_estimator():
     assert isinstance(iht.estimator_, RandomForestClassifier)
     assert X_resampled.shape == (12, 2)
     assert y_resampled.shape == (12,)
+
+
+def test_iht_estimator_pipeline():
+    """Check that we can pass a pipeline containing a classifier.
+
+    Checking if we have a classifier should not be based on inheriting from
+    `ClassifierMixin`.
+
+    Non-regression test for:
+    https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1049
+    """
+    model = make_pipeline(GradientBoostingClassifier(random_state=RND_SEED))
+    iht = InstanceHardnessThreshold(estimator=model, random_state=RND_SEED)
+    X_resampled, y_resampled = iht.fit_resample(X, Y)
+    assert X_resampled.shape == (12, 2)
+    assert y_resampled.shape == (12,)
