Skip to content

Commit

Permalink
Fix keras.ops.softmax for the tensorflow backend (#19300)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel authored Mar 13, 2024
1 parent df705d4 commit 6a266b8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
7 changes: 5 additions & 2 deletions keras/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,13 @@ def softmax(x, axis=-1):
array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64)
"""
if isinstance(axis, int) and backend.shape(x)[axis] == 1:
# Don't use `backend.shape` since TensorFlow returns
# symbolic tensors for unknown shape which can trigger
# an error in TensorFlow graph execution.
if isinstance(axis, int) and x.shape[axis] == 1:
warnings.warn(
f"You are using a softmax over axis {axis} "
f"of a tensor of shape {backend.shape(x)}. This axis "
f"of a tensor of shape {x.shape}. This axis "
"has size 1. The softmax operation will always return "
"the value 1, which is likely not what you intended. "
"Did you mean to use a sigmoid instead?"
Expand Down
18 changes: 18 additions & 0 deletions keras/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import pytest
from absl.testing import parameterized

import keras
from keras import backend
from keras import layers
from keras import losses
from keras import models
from keras import ops
from keras import testing
from keras.backend.common import standardize_dtype
from keras.backend.common.keras_tensor import KerasTensor
Expand Down Expand Up @@ -84,6 +86,22 @@ def test_softmax(self):
self.assertEqual(knn.softmax(x, axis=1).shape, (None, 2, 3))
self.assertEqual(knn.softmax(x, axis=-1).shape, (None, 2, 3))

def test_softmax_in_graph(self):
class SoftmaxLayer(keras.Layer):
def call(self, x):
return ops.softmax(x, axis=-1)

class Model(keras.Model):
def __init__(self):
x = keras.Input(shape=(None,))
y = SoftmaxLayer()(x)
super().__init__(inputs=x, outputs=y)

# Make sure Keras is able to compile the model graph
model = Model()
x = ops.array([[1.0, 2.0, 3.0, 4.0]])
model.predict(x)

def test_log_softmax(self):
x = KerasTensor([None, 2, 3])
self.assertEqual(knn.log_softmax(x).shape, (None, 2, 3))
Expand Down

0 comments on commit 6a266b8

Please sign in to comment.