diff --git a/tests/keras/losses_test.py b/tests/keras/losses_test.py index 2f44dc4f654..c2ca3d3d84f 100644 --- a/tests/keras/losses_test.py +++ b/tests/keras/losses_test.py @@ -57,5 +57,14 @@ def test_categorical_hinge(): assert np.isclose(expected_loss, np.mean(loss)) +def test_sparse_categorical_crossentropy(): + y_pred = K.variable(np.array([[0.3, 0.6, 0.1], + [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([1, 2])) + expected_loss = - (np.log(0.6) + np.log(0.7)) / 2 + loss = K.eval(losses.sparse_categorical_crossentropy(y_true, y_pred)) + assert np.isclose(expected_loss, np.mean(loss)) + + if __name__ == '__main__': pytest.main([__file__])