Skip to content

Commit

Permalink
training.py _weighted_masked_objective fix crash when weights is None (
Browse files Browse the repository at this point in the history
…#7068)

* training.py _weighted_masked_objective fix crash when weights is None

* unit test _weighted_masked_objective function
  • Loading branch information
ahundt authored and fchollet committed Jun 21, 2017
1 parent f430de1 commit 219d6ee
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
9 changes: 4 additions & 5 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,12 @@ def weighted(y_true, y_pred, weights, mask=None):
# to the number of unmasked samples.
score_array /= K.mean(mask)

# reduce score_array to same ndim as weight array
ndim = K.ndim(score_array)
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))

# apply sample weighting
if weights is not None:
# reduce score_array to same ndim as weight array
ndim = K.ndim(score_array)
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
score_array *= weights
score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
return K.mean(score_array)
Expand Down
16 changes: 15 additions & 1 deletion tests/keras/engine/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from keras.layers import Dense, Dropout
from keras.engine.topology import Input
from keras.engine.training import Model, _check_loss_and_target_compatibility
from keras.engine.training import Model
from keras.engine.training import _check_loss_and_target_compatibility
from keras.engine.training import _weighted_masked_objective
from keras.models import Sequential
from keras import backend as K
from keras.utils import Sequence
Expand All @@ -27,6 +29,18 @@ def __getitem__(self, idx):
np.random.random((self.batch_size, 3))]


@keras_test
def test_weighted_masked_objective():
a = Input(shape=(3,), name='input_a')

# weighted_masked_objective
def mask_dummy(y_true=None, y_pred=None, weight=None):
return K.placeholder(y_true.shape)

weighted_function = _weighted_masked_objective(K.categorical_crossentropy)
weighted_function(a, a, None)


@keras_test
def test_model_methods():
a = Input(shape=(3,), name='input_a')
Expand Down

0 comments on commit 219d6ee

Please sign in to comment.