Rearrange attention output to have batch dimension on the 0th axis (#8591)

* transpose first and second axis

* add changelog

* typo

* fix tests
dakshvar22 authored May 12, 2021
1 parent 68ab1d0 commit 03b3236
Showing 3 changed files with 14 additions and 6 deletions.
3 changes: 3 additions & 0 deletions changelog/8591.misc.md
@@ -0,0 +1,3 @@
TensorFlow models now return the batch dimension on the first axis and the number of layers on the second axis of the output array associated with the `attention_weights` key.

Previously, the expected shape of the output array was `(num_layers, batch_size, num_heads, length, length)`. Now, the expected shape is `(batch_size, num_layers, num_heads, length, length)`.
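As a sketch (not part of the commit), the relationship between the two layouts is a single `tf.transpose` that swaps the first two axes; the dimension sizes below are made up purely for illustration:

```python
import tensorflow as tf

# Hypothetical sizes, chosen only to make the shapes visible.
num_layers, batch_size, num_heads, length = 2, 4, 8, 16

# Old layout: layers stacked on the first axis.
old = tf.random.uniform((num_layers, batch_size, num_heads, length, length))

# New layout: batch on the first axis, layers on the second.
new = tf.transpose(old, perm=(1, 0, 2, 3, 4))
assert new.shape == (batch_size, num_layers, num_heads, length, length)
```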
9 changes: 7 additions & 2 deletions rasa/utils/tensorflow/transformer.py
@@ -630,6 +630,11 @@ def call(
  # a whole stack of unnormalized layer outputs.
  x = self._layer_norm(x)  # (batch_size, length, units)

+ # Keep the batch dimension on the first axis
+ attention_weights_as_output = tf.transpose(
+     tf.stack(layer_attention_weights), (1, 0, 2, 3, 4)
+ )
+
  # (batch_size, length, units),
- # (num_layers, batch_size, num_heads, length, length)
- return x, tf.stack(layer_attention_weights)
+ # (batch_size, num_layers, num_heads, length, length)
+ return x, attention_weights_as_output
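Downstream code that indexes into the returned weights has to swap its axes to match. A hypothetical consumer (the `attention_weights` below is a stand-in for the second return value of `call`; all sizes are arbitrary):

```python
import tensorflow as tf

# Stand-in for the transformer's attention output; sizes are arbitrary.
batch_size, num_layers, num_heads, length = 4, 2, 8, 16
attention_weights = tf.random.uniform(
    (batch_size, num_layers, num_heads, length, length)
)

layer_idx = 0
# Before this commit, layers sat on axis 0 and this would have been
# attention_weights[layer_idx]; now the layer axis is second.
layer_weights = attention_weights[:, layer_idx]
assert layer_weights.shape == (batch_size, num_heads, length, length)
```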
8 changes: 4 additions & 4 deletions tests/utils/tensorflow/test_rasa_layers.py
@@ -367,8 +367,8 @@ def test_layer_gives_correct_output_units(
      [batch_size, max_seq_length + 1, units_1],
      [batch_size, max_seq_length + 1, 1],
      [
-         num_transformer_layers,
          batch_size,
+         num_transformer_layers,
          num_transformer_heads,
          max_seq_length + 1,
          max_seq_length + 1,
@@ -381,8 +381,8 @@
      [0,],
      [0,],
      [
-         num_transformer_layers,
          batch_size,
+         num_transformer_layers,
          num_transformer_heads,
          max_seq_length + 1,
          max_seq_length + 1,
@@ -440,8 +440,8 @@ def test_layer_gives_correct_output_units(
      [batch_size, max_seq_length, 2],
      [batch_size, max_seq_length, 1],
      [
-         num_transformer_layers,
          batch_size,
+         num_transformer_layers,
          num_transformer_heads,
          max_seq_length,
          max_seq_length,
@@ -454,8 +454,8 @@
      [0,],
      [0,],
      [
-         num_transformer_layers,
          batch_size,
+         num_transformer_layers,
          num_transformer_heads,
          max_seq_length,
          max_seq_length,
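The test changes above only reorder the first two entries of the expected shape. A minimal, self-contained version of the kind of assertion these tests perform (the names are placeholders for the actual test fixtures, not the real ones):

```python
import numpy as np

batch_size, num_transformer_layers = 4, 2
num_transformer_heads, max_seq_length = 8, 10

# Stand-in for the attention output under test.
attention_weights = np.zeros(
    (batch_size, num_transformer_layers, num_transformer_heads,
     max_seq_length, max_seq_length)
)

expected_shape = (
    batch_size,
    num_transformer_layers,
    num_transformer_heads,
    max_seq_length,
    max_seq_length,
)
assert attention_weights.shape == expected_shape
```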
