From 7e94390cdcd18e6ade70f01e186ebffb3525a8e5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 21 Oct 2021 10:42:39 +0200 Subject: [PATCH] MNT update import for keras --- doc/miscellaneous.rst | 2 +- imblearn/keras/_generator.py | 8 ++++---- imblearn/keras/tests/test_generator.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/miscellaneous.rst b/doc/miscellaneous.rst index 43503abfa..4aeb4a2cf 100644 --- a/doc/miscellaneous.rst +++ b/doc/miscellaneous.rst @@ -136,7 +136,7 @@ calling ``fit_generator`` method to train the model. To illustrate, we will define a logistic regression model:: >>> import keras - >>> y = keras.utils.to_categorical(y, 3) + >>> y = keras.utils.np_utils.to_categorical(y, 3) >>> model = keras.Sequential() >>> model.add(keras.layers.Dense(y.shape[1], input_dim=X.shape[1], ... activation='softmax')) diff --git a/imblearn/keras/_generator.py b/imblearn/keras/_generator.py index 4c0707498..8ba1316b8 100644 --- a/imblearn/keras/_generator.py +++ b/imblearn/keras/_generator.py @@ -17,7 +17,7 @@ def import_from_keras(): try: import keras # noqa - return (keras.utils.data_utils.Sequence,), True + return (keras.utils.Sequence,), True except ImportError: return tuple(), False @@ -33,10 +33,10 @@ def import_from_tensforflow(): ParentClassTensorflow, has_keras_tf = import_from_tensforflow() has_keras = has_keras_k or has_keras_tf if has_keras: - if has_keras_tf: - ParentClass = ParentClassTensorflow - else: + if has_keras_k: ParentClass = ParentClassKeras + else: + ParentClass = ParentClassTensorflow else: ParentClass = (object,) return ParentClass, has_keras diff --git a/imblearn/keras/tests/test_generator.py b/imblearn/keras/tests/test_generator.py index 40c10b6a3..071c74364 100644 --- a/imblearn/keras/tests/test_generator.py +++ b/imblearn/keras/tests/test_generator.py @@ -8,7 +8,7 @@ keras = pytest.importorskip("keras") from keras.models import Sequential # noqa: E402 from keras.layers import Dense # noqa: E402 -from keras.utils import to_categorical # noqa: E402 +from keras.utils.np_utils import to_categorical # noqa: E402 from imblearn.datasets import make_imbalance # noqa: E402 from imblearn.under_sampling import ClusterCentroids # noqa: E402