Skip to content

Commit

Permalink
TF: BART compatible with XLA generation (huggingface#17479)
Browse files Browse the repository at this point in the history
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
  • Loading branch information
gante authored and younesbelkada committed Jun 25, 2022
1 parent 72bf6a4 commit c98ac16
Show file tree
Hide file tree
Showing 18 changed files with 421 additions and 86 deletions.
113 changes: 106 additions & 7 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
Expand Down Expand Up @@ -87,7 +88,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
bsz = input_ids_shape[0]
tgt_len = input_ids_shape[1]
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])

Expand All @@ -99,7 +101,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))


def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
Expand All @@ -123,12 +125,19 @@ def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)

def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
def call(
self,
input_shape: Optional[tf.TensorShape] = None,
past_key_values_length: int = 0,
position_ids: Optional[tf.Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
if position_ids is None:
seq_len = input_shape[1]
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length

positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions + self.offset)
return super().call(position_ids + self.offset)


class TFBartAttention(tf.keras.layers.Layer):
Expand Down Expand Up @@ -599,6 +608,9 @@ def serving(self, inputs):
for denoising pre-training following the paper.
decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
Expand Down Expand Up @@ -838,6 +850,7 @@ def call(
input_ids: Optional[TFModelInputType] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
Expand Down Expand Up @@ -866,6 +879,9 @@ def call(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range `[0, config.max_position_embeddings - 1]`.
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
Expand Down Expand Up @@ -922,7 +938,10 @@ def call(
past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
if position_ids is None:
positions = self.embed_positions(input_shape, past_key_values_length)
else:
positions = self.embed_positions(input_shape, position_ids=position_ids)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
Expand Down Expand Up @@ -1058,6 +1077,7 @@ def call(
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
Expand Down Expand Up @@ -1112,6 +1132,7 @@ def call(
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
Expand Down Expand Up @@ -1173,6 +1194,7 @@ def call(
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
Expand All @@ -1193,6 +1215,7 @@ def call(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
Expand Down Expand Up @@ -1278,6 +1301,7 @@ def call(
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
Expand Down Expand Up @@ -1320,6 +1344,7 @@ def call(
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
Expand Down Expand Up @@ -1375,29 +1400,103 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs
):

# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

if decoder_attention_mask is not None: # xla
decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
elif past is not None: # no xla + past
decoder_position_ids = past[0][0].shape[2]
else: # no xla + no past
decoder_position_ids = tf.range(decoder_input_ids.shape[1])

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]

if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
# The padded version of `past` requires only `max_length - 1` steps along the time dimension.
num_padding_values = max_length - past[0][0].shape[2] - 1
# prepare the padding tensor for `tf.pad`.
# `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
# `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)

# 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor,
# ones for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# correct 5 here
new_past_index = current_pos - 1

new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)

update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)

# set `decoder_attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past

return model_kwargs

def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

Expand Down
Loading

0 comments on commit c98ac16

Please sign in to comment.