diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py
index 1714d857952cce..5d90fea8e0dd4c 100644
--- a/src/transformers/models/bart/modeling_tf_bart.py
+++ b/src/transformers/models/bart/modeling_tf_bart.py
@@ -134,12 +134,10 @@ def call(
         """Input is expected to be of size [bsz x seqlen]."""
         if position_ids is None:
             seq_len = input_shape[1]
-            positions = tf.range(seq_len, delta=1, name="range")
-            positions += past_key_values_length
-        else:
-            positions = position_ids
+            position_ids = tf.range(seq_len, delta=1, name="range")
+            position_ids += past_key_values_length
 
-        return super().call(positions + self.offset)
+        return super().call(position_ids + self.offset)
 
 
 class TFBartAttention(tf.keras.layers.Layer):
@@ -612,7 +610,8 @@ def serving(self, inputs):
             will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
         decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
-            range `[0, config.max_position_embeddings - 1]`.
+            range `[0, config.max_position_embeddings - 1]`. If `past_key_values` is passed, `position_ids` has to be
+            provided.
         head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
             Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
 
@@ -883,7 +882,8 @@ def call(
                 [What are attention masks?](../glossary#attention-mask)
             position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
-                range `[0, config.max_position_embeddings - 1]`.
+                range `[0, config.max_position_embeddings - 1]`. If `past_key_values` is passed, `position_ids` has to
+                be provided.
             encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
                 the decoder.
@@ -942,13 +942,11 @@ def call(
         # embed positions
         if position_ids is None:
             if past_key_values is not None:
-                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+                raise ValueError("Make sure to provide the position ids when passing `past_key_values`.")
             positions = self.embed_positions(input_shape, past_key_values_length)
         else:
             positions = self.embed_positions(input_shape, position_ids=position_ids)
 
-        # breakpoint()
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
@@ -1419,10 +1417,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1447,8 +1445,12 @@ def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current
         batch_size = past[0][0].shape[0]
 
         if not is_past_initialized:
-            # past[0].shape[2] is seq_length of prompt
+            # past[0][0].shape[2] is seq_length of prompt
+            # The padded version of `past` requires only `max_length - 1` steps along the time dimension.
             num_padding_values = max_length - past[0][0].shape[2] - 1
+            # prepare the padding tensor for `tf.pad`.
+            # `shape=(4, 2)` because each tensor element in `past` has `rank=4`.
+            # `indices=[[2, 1]]` means the time dimension (dim 2) needs **right**-padding (`1` means padding afterward).
             padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
 
             new_past = ()
diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
index 0314414002b41e..16fdd2995266d9 100644
--- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py
@@ -1399,10 +1399,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
index 0135377491cf0f..a183c94aa7906c 100644
--- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
@@ -1371,10 +1371,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index c368020d45148a..3a99dbab58b7b6 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -1414,10 +1414,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index dd8ca53ef32851..2fe3435b5fc48b 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -134,12 +134,10 @@ def call(
         """Input is expected to be of size [bsz x seqlen]."""
         if position_ids is None:
             seq_len = input_shape[1]
-            positions = tf.range(seq_len, delta=1, name="range")
-            positions += past_key_values_length
-        else:
-            positions = position_ids
+            position_ids = tf.range(seq_len, delta=1, name="range")
+            position_ids += past_key_values_length
 
-        return super().call(positions + self.offset)
+        return super().call(position_ids + self.offset)
 
 
 # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
@@ -1413,10 +1411,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index 4db1ff6a28ec87..6d91a05f52861a 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -1423,10 +1423,10 @@ def prepare_inputs_for_generation(
 
         if decoder_attention_mask is not None:  # xla
             decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
-        elif past is not None:  # non xla + past
-            decoder_position_ids = tf.broadcast_to(past[0][0].shape[2], (decoder_input_ids.shape[0], 1))
-        else:  # non xla + non past
-            decoder_position_ids = tf.broadcast_to(tf.range(decoder_input_ids.shape[1]), decoder_input_ids.shape)
+        elif past is not None:  # no xla + past
+            decoder_position_ids = past[0][0].shape[2]
+        else:  # no xla + no past
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py
index 6ff3bdadff0698..0df55500db37e3 100644
--- a/tests/models/bart/test_modeling_tf_bart.py
+++ b/tests/models/bart/test_modeling_tf_bart.py
@@ -128,7 +128,8 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
         decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
         output_from_no_past = model(
             next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
-        )[0]
+        )
+        output_from_no_past = output_from_no_past[0]
 
         decoder_position_ids = (
             tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
@@ -138,7 +139,8 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
             attention_mask=next_attention_mask,
             past_key_values=past_key_values,
             position_ids=decoder_position_ids,
-        )[0]
+        )
+        output_from_past = output_from_past[0]
 
         self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
 
@@ -151,7 +153,7 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
         config.do_sample = False
         config.num_beams = 1
diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py
index 93b48ce8f29948..efa3f0ac1c0568 100644
--- a/tests/models/gpt2/test_modeling_tf_gpt2.py
+++ b/tests/models/gpt2/test_modeling_tf_gpt2.py
@@ -295,7 +295,7 @@ def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
     def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
 
         model = TFGPT2LMHeadModel(config=config)
diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py
index 5ad746e34fc877..e815fd7ad07a36 100644
--- a/tests/models/t5/test_modeling_tf_t5.py
+++ b/tests/models/t5/test_modeling_tf_t5.py
@@ -228,7 +228,7 @@ def create_and_check_t5_decoder_model_past_large_inputs(
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
 
     def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
-        config.eos_token_id = None
+        config.eos_token_id = None  # Generate until max length
         config.max_length = 10
         config.do_sample = False
         config.num_beams = 1
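
For reference, the sketch below (plain Python, not part of the patch) walks through the three `decoder_position_ids` branches that the `prepare_inputs_for_generation` hunks above converge on; all tensor values are toy inputs invented purely for illustration.

# A minimal, standalone sketch of how the three branches derive `decoder_position_ids`.
import tensorflow as tf

# xla: the decoder attention mask is padded up to max_length, so an exclusive cumulative
# sum recovers each position, and the last column gives the index of the next slot to fill.
decoder_attention_mask = tf.constant([[1, 1, 1, 0, 0]])
xla_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
print(xla_position_ids.numpy())  # [[3]]

# no xla + past: the next position is simply the length of the cached sequence, which the
# models read from `past[0][0].shape[2]`; broadcasting against the inputs handles the batch dim.
past_seq_len = 3  # stands in for past[0][0].shape[2]
no_xla_past_position_ids = past_seq_len

# no xla + no past: the whole prompt is fed at once, so positions are just 0..seq_len-1.
decoder_input_ids = tf.constant([[2, 56, 19, 4]])
no_xla_position_ids = tf.range(decoder_input_ids.shape[1])
print(no_xla_position_ids.numpy())  # [0 1 2 3]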
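
Similarly, the following standalone sketch (again not part of the patch) spells out the `tf.scatter_nd` + `tf.pad` trick documented by the new comments in `_update_model_kwargs_for_xla_generation`; the shapes and `max_length` below are toy values.

# Each `past` tensor is rank 4: (batch, heads, time, head_dim).
import tensorflow as tf

max_length = 8
past_key = tf.ones((2, 4, 3, 16))  # pretend cached key for a prompt of length 3

# XLA needs the cache padded to a static length up front; only `max_length - 1` time steps
# are required, since nothing ever attends to a key/value written at the final position.
num_padding_values = max_length - past_key.shape[2] - 1  # 4

# Build the `paddings` argument for `tf.pad`: shape (4, 2) because the tensor has rank 4,
# and the single scattered entry at index [2, 1] requests `num_padding_values` of padding
# *after* the existing values along the time dimension (dim 2); all other entries stay 0.
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
padded_key = tf.pad(past_key, padding_values)
print(padded_key.shape)  # (2, 4, 7, 16): 3 prompt steps + 4 right-padded steps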