Skip to content

Commit

Permalink
remove extra terms from softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
dakshvar22 committed Dec 21, 2020
1 parent 5c3870f commit 8cff4ec
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions rasa/utils/tensorflow/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,9 +853,7 @@ def _loss_softmax(
) -> tf.Tensor:
"""Define softmax loss."""

softmax_logits = tf.concat(
[sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li], axis=-1
)
softmax_logits = tf.concat([sim_pos, sim_neg_il, sim_neg_li], axis=-1)

sigmoid_logits = tf.concat(
[sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li], axis=-1
Expand All @@ -864,7 +862,7 @@ def _loss_softmax(
# create label_ids for softmax
softmax_label_ids = tf.zeros_like(softmax_logits[..., 0], tf.int32)

sigmoid_label_ids = tf.concat(
sigmoid_labels = tf.concat(
[
tf.expand_dims(tf.ones_like(sigmoid_logits[..., 0], tf.float32), -1),
tf.zeros_like(sigmoid_logits[..., 1:], tf.float32),
Expand All @@ -876,7 +874,7 @@ def _loss_softmax(
labels=softmax_label_ids, logits=softmax_logits
)
sigmoid_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=sigmoid_label_ids, logits=sigmoid_logits
labels=sigmoid_labels, logits=sigmoid_logits
)

loss = softmax_loss + tf.reduce_mean(sigmoid_loss, axis=-1)
Expand Down

0 comments on commit 8cff4ec

Please sign in to comment.