Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix multiclass onehot inverse #189

Merged
merged 12 commits into from
Feb 15, 2021
14 changes: 5 additions & 9 deletions scikeras/utils/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,7 @@ def inverse_transform(
class_predictions = np.argmax(y, axis=1).reshape(-1, 1)
class_predictions = self._final_encoder.inverse_transform(class_predictions)
elif self._target_type == "multiclass":
# array([0.8, 0.1, 0.1], [.1, .8, .1]) ->
# array(['apple', 'orange'])
# array([0.8, 0.1, 0.1], [.1, .8, .1]) -> array(['apple', 'orange'])
idx = np.argmax(y, axis=-1)
if not is_categorical_crossentropy(self.loss):
class_predictions = idx.reshape(-1, 1)
Expand All @@ -255,13 +254,12 @@ def inverse_transform(
class_predictions[:, idx] = 1
class_predictions = self._final_encoder.inverse_transform(class_predictions)
elif self._target_type == "multiclass-onehot":
# array([.8, .1, .1], [.1, .8, .1]) ->
# array([[1, 0, 0], [0, 1, 0]])
# array([.8, .1, .1], [.1, .8, .1]) -> array([[1, 0, 0], [0, 1, 0]])
idx = np.argmax(y, axis=-1)
class_predictions = np.zeros(y.shape, dtype=int)
class_predictions[:, idx] = 1
class_predictions[np.arange(idx.size), idx] = 1
adriangb marked this conversation as resolved.
Show resolved Hide resolved
elif self._target_type == "multilabel-indicator":
class_predictions = np.around(y)
class_predictions = np.around(y).astype(int, copy=False)
else:
if not return_proba:
raise NotImplementedError(
Expand All @@ -279,9 +277,7 @@ def inverse_transform(

if return_proba:
return np.squeeze(y)
res = np.column_stack(class_predictions).astype(self._y_dtype, copy=False)
res = res.reshape(-1, *self._y_shape[1:])
return res
return class_predictions.reshape(-1, *self._y_shape[1:])

def get_metadata(self) -> Dict[str, Any]:
"""Returns a dictionary of meta-parameters generated when this transformer
Expand Down
9 changes: 6 additions & 3 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ def _check_model_compatibility(self, y: np.ndarray) -> None:
if self.n_outputs_expected_ != len(self.model_.outputs):
raise ValueError(
"Detected a Keras model input of size"
f" {y[0].shape[0]}, but {self.model_} has"
f" {self.model_.outputs} outputs"
f" {self.n_outputs_expected_ }, but {self.model_} has"
f" {len(self.model_.outputs)} outputs"
)
# check that if the user gave us a loss function it ended up in
# the actual model
Expand Down Expand Up @@ -1626,6 +1626,9 @@ def r_squared(y_true, y_pred):
"""A simple Keras implementation of R^2 that can be used as a Keras
loss function.

Note that this returns 1-R^2 so that it can be minimized as a loss
function.
adriangb marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
Expand All @@ -1641,6 +1644,6 @@ def r_squared(y_true, y_pred):
tf.math.squared_difference(y_true, tf.math.reduce_mean(y_true, axis=0)),
axis=0,
)
return tf.math.reduce_mean(
return 1 - tf.math.reduce_mean(
adriangb marked this conversation as resolved.
Show resolved Hide resolved
1 - ss_res / (ss_tot + tf.keras.backend.epsilon()), axis=-1
)
21 changes: 17 additions & 4 deletions tests/multi_output_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,26 @@ def inverse_transform(
) -> np.ndarray:
if self._target_type not in ("multilabel-indicator", "multiclass-multioutput"):
return super().inverse_transform(y, return_proba=return_proba)
if not return_proba and self.split:
y = [np.argmax(y_, axis=1).astype(self._y_dtype, copy=False) for y_ in y]
y = np.squeeze(np.column_stack(y))
if return_proba:
return y
if self._target_type == "multilabel-indicator":
if self.split:
y = np.column_stack(y)
# RandomForestClassifier and sklearn's MultiOutputClassifier always return int64
# for multilabel-indicator
y = y.astype(int)
y = np.around(y).astype(int, copy=False)
else: # multiclass-multioutput
if self.split:
y = np.column_stack(
[
np.argmax(y_, axis=1).astype(self._y_dtype, copy=False)
for y_ in y
]
)
else:
raise NotImplementedError(
"multiclass-multioutput must be handled by a multi-output Model"
)
return y


Expand Down
Loading