Commit 98996ed: PR comments

gante committed Jun 15, 2022 (parent: fd89f8b)
Showing 9 changed files with 47 additions and 45 deletions.
src/transformers/models/bart/modeling_tf_bart.py (32 changes: 17 additions & 15 deletions)
@@ -134,12 +134,10 @@ def call(
         """Input is expected to be of size [bsz x seqlen]."""
         if position_ids is None:
             seq_len = input_shape[1]
-            positions = tf.range(seq_len, delta=1, name="range")
-            positions += past_key_values_length
-        else:
-            positions = position_ids
+            position_ids = tf.range(seq_len, delta=1, name="range")
+            position_ids += past_key_values_length
 
-        return super().call(positions + self.offset)
+        return super().call(position_ids + self.offset)


class TFBartAttention(tf.keras.layers.Layer):
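Note: the hunk above (mirrored in modeling_tf_mbart.py below) folds the `else` branch away and reuses `position_ids` directly. A minimal, self-contained sketch of the resulting fallback logic, with the embedding lookup elided and the function name hypothetical:

import tensorflow as tf

def resolve_position_ids(input_shape, past_key_values_length=0, position_ids=None):
    # Fallback used when no explicit ids are given: a contiguous range over the
    # current input, shifted by the number of cached (already decoded) tokens.
    if position_ids is None:
        seq_len = input_shape[1]
        position_ids = tf.range(seq_len, delta=1, name="range")
        position_ids += past_key_values_length
    return position_ids  # the real layer then looks up `position_ids + self.offset`

print(resolve_position_ids((2, 3)).numpy())                            # [0 1 2]
print(resolve_position_ids((2, 1), past_key_values_length=5).numpy())  # [5]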
@@ -612,7 +610,8 @@ def serving(self, inputs):
             will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
         decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
-            range `[0, config.max_position_embeddings - 1]`.
+            range `[0, config.max_position_embeddings - 1]`. If `past_key_values` is passed, `position_ids` has to be
+            provided.
         head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
             Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
@@ -883,7 +882,8 @@ def call(
             [What are attention masks?](../glossary#attention-mask)
         position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
-            range `[0, config.max_position_embeddings - 1]`.
+            range `[0, config.max_position_embeddings - 1]`. If `past_key_values` is passed, `position_ids` has to
+            be provided.
         encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
             of the decoder.
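Note: the added docstring sentence encodes a real constraint. With a cache, the decoder only sees the newest token(s), so it cannot infer their absolute positions from the visible input length. A small illustration of the failure mode the requirement prevents (toy values):

import tensorflow as tf

cached_length = 5                      # i.e. past_key_values[0][0].shape[2]
next_token_ids = tf.constant([[42]])   # only the new token is fed at step 6

naive_positions = tf.range(next_token_ids.shape[1])                    # [0] -> wrong embedding
correct_positions = tf.range(next_token_ids.shape[1]) + cached_length  # [5] -> must be supplied
print(naive_positions.numpy(), correct_positions.numpy())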
@@ -942,13 +942,11 @@ def call(
         # embed positions
         if position_ids is None:
             if past_key_values is not None:
-                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+                raise ValueError("Make sure to provide the position ids when passing `past_key_values`.")
             positions = self.embed_positions(input_shape, past_key_values_length)
         else:
             positions = self.embed_positions(input_shape, position_ids=position_ids)
 
-        # breakpoint()
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
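Note: a minimal runnable sketch of the guard this hunk settles on, with the embedding lookup elided (function name hypothetical):

import tensorflow as tf

def check_position_args(position_ids=None, past_key_values=None):
    # With a cache but no explicit ids, positions would silently restart at 0,
    # so the decoder fails fast instead.
    if position_ids is None and past_key_values is not None:
        raise ValueError("Make sure to provide the position ids when passing `past_key_values`.")

check_position_args()  # fine: no cache, implicit positions are unambiguous
try:
    check_position_args(past_key_values=((tf.zeros((1, 2, 5, 8)),),))
except ValueError as e:
    print(e)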

@@ -1419,10 +1417,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1447,8 +1445,12 @@ def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current
         batch_size = past[0][0].shape[0]
 
         if not is_past_initialized:
-            # past[0].shape[2] is seq_length of prompt
+            # past[0][0].shape[2] is seq_length of prompt
+            # The padded version of `past` requires only `max_length - 1` steps along the time dimension.
             num_padding_values = max_length - past[0][0].shape[2] - 1
+            # prepare the padding tensor for `tf.pad`.
+            # `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
+            # `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
             padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
 
             new_past = ()
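Note: the new comments describe the `tf.scatter_nd` trick; a runnable check of the padding spec it builds (shapes illustrative):

import tensorflow as tf

max_length, prompt_length = 10, 3
num_padding_values = max_length - prompt_length - 1  # 6: decoding steps left after the prompt

# A (4, 2) paddings matrix for tf.pad: row = dimension, columns = (before, after).
# Only row 2 (the time axis of a rank-4 past tensor) gets right-padding.
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
print(padding_values.numpy())
# [[0 0]
#  [0 0]
#  [0 6]
#  [0 0]]

past_key = tf.zeros((1, 2, prompt_length, 8))  # (batch, heads, seq_len, head_dim), toy sizes
print(tf.pad(past_key, padding_values).shape)  # (1, 2, 9, 8)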
src/transformers/models/blenderbot/modeling_tf_blenderbot.py (8 changes: 4 additions & 4 deletions)

@@ -1399,10 +1399,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -1371,10 +1371,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
src/transformers/models/marian/modeling_tf_marian.py (8 changes: 4 additions & 4 deletions)

@@ -1414,10 +1414,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
src/transformers/models/mbart/modeling_tf_mbart.py (16 changes: 7 additions & 9 deletions)

@@ -134,12 +134,10 @@ def call(
         """Input is expected to be of size [bsz x seqlen]."""
         if position_ids is None:
             seq_len = input_shape[1]
-            positions = tf.range(seq_len, delta=1, name="range")
-            positions += past_key_values_length
-        else:
-            positions = position_ids
+            position_ids = tf.range(seq_len, delta=1, name="range")
+            position_ids += past_key_values_length
 
-        return super().call(positions + self.offset)
+        return super().call(position_ids + self.offset)


# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
@@ -1413,10 +1411,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
src/transformers/models/pegasus/modeling_tf_pegasus.py (8 changes: 4 additions & 4 deletions)

@@ -1423,10 +1423,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
tests/models/bart/test_modeling_tf_bart.py (8 changes: 5 additions & 3 deletions)

@@ -128,7 +128,8 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
         decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
         output_from_no_past = model(
             next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
-        )[0]
+        )
+        output_from_no_past = output_from_no_past[0]
 
         decoder_position_ids = (
             tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
@@ -138,7+139,8 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
             attention_mask=next_attention_mask,
             past_key_values=past_key_values,
             position_ids=decoder_position_ids,
-        )[0]
+        )
+        output_from_past = output_from_past[0]
 
         self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
 
@@ -151,7 +153,7 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
         config.do_sample = False
         config.num_beams = 1
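Note: the `*_xla_generate_fast` tests compare eager and XLA-compiled generation; a hedged, self-contained sketch of that pattern (the tiny config values are illustrative, not the test's own):

import tensorflow as tf
from transformers import BartConfig, TFBartForConditionalGeneration

config = BartConfig(
    vocab_size=99, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=48,
)
config.eos_token_id = None  # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1

model = TFBartForConditionalGeneration(config=config)
input_ids = tf.constant([[5, 6, 7, 8]])

# Greedy XLA generation must match eager generation token for token.
xla_generate = tf.function(model.generate, jit_compile=True)
assert model.generate(input_ids).numpy().tolist() == xla_generate(input_ids).numpy().tolist()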
tests/models/gpt2/test_modeling_tf_gpt2.py (2 changes: 1 addition & 1 deletion)

@@ -295,7 +295,7 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
     def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
         model = TFGPT2LMHeadModel(config=config)
tests/models/t5/test_modeling_tf_t5.py (2 changes: 1 addition & 1 deletion)

@@ -228,7 +228,7 @@ def create_and_check_t5_decoder_model_past_large_inputs(
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
         config.do_sample = False
         config.num_beams = 1
