From 3e9d0f7f599439a9fcfebda5d3edff185a4c9829 Mon Sep 17 00:00:00 2001 From: Omar Sanseviero Date: Sat, 12 Mar 2022 13:06:55 +0100 Subject: [PATCH] Change unpacking of TF Bart inputs (#16094) --- .../models/bart/modeling_tf_bart.py | 346 +++++++----------- 1 file changed, 124 insertions(+), 222 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 2b1df1a73586cb..3b7e3a03a56b8b 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -42,8 +42,8 @@ TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -660,6 +660,7 @@ def get_embed_tokens(self): def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + @unpack_inputs def call( self, input_ids=None, @@ -708,80 +709,67 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 
""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) - hidden_states = inputs["inputs_embeds"] + embed_pos + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # check attention mask and invert - if inputs["attention_mask"] is not None: + if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(inputs["attention_mask"]) + attention_mask = _expand_mask(attention_mask) else: attention_mask = None - encoder_states = () if inputs["output_hidden_states"] else None - all_attentions = () if inputs["output_attentions"] else None + encoder_states = () if 
output_hidden_states else None + all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - if inputs["head_mask"] is not None and tf.executing_eagerly(): + if head_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], + shape_list(head_mask)[0], len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.", ) # encoder layers for idx, encoder_layer in enumerate(self.layers): - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer + if training and (dropout_probability < self.layerdrop): # skip the layer continue hidden_states, attn = encoder_layer( hidden_states, attention_mask, - inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + head_mask[idx] if head_mask is not None else None, ) - if inputs["output_attentions"]: + if output_attentions: all_attentions += (attn,) - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return TFBaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions @@ -822,6 +810,7 @@ def get_embed_tokens(self): def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + 
@unpack_inputs def call( self, input_ids=None, @@ -899,45 +888,25 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - past_key_values_length = ( - shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0 - ) + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions positions = self.embed_positions(input_shape, past_key_values_length) - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - hidden_states = inputs["inputs_embeds"] + 
hidden_states = inputs_embeds # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] if input_shape[-1] > 1: @@ -947,72 +916,68 @@ def call( tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None: - combined_attention_mask = combined_attention_mask + _expand_mask( - inputs["attention_mask"], tgt_len=input_shape[-1] - ) + if attention_mask is not None: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: + if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # decoder layers - all_hidden_states = () if inputs["output_hidden_states"] else None - all_self_attns = () if inputs["output_attentions"] else None - all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None - present_key_values = () if inputs["use_cache"] else None + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. 
- for attn_mask in ["head_mask", "cross_attn_head_mask"]: - if inputs[attn_mask] is not None and tf.executing_eagerly(): + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs[attn_mask])[0], + shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): + if training and (dropout_probability < self.layerdrop): continue - past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] - if inputs["cross_attn_head_mask"] is not None - else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, past_key_value=past_key_value, ) - if inputs["use_cache"]: + if use_cache: 
present_key_values += (present_key_value,) - if inputs["output_attentions"]: + if output_attentions: all_self_attns += (layer_self_attn,) - if inputs["encoder_hidden_states"] is not None: + if encoder_hidden_states is not None: all_cross_attns += (layer_cross_attn,) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: return TFBaseModelOutputWithPastAndCrossAttentions( @@ -1062,6 +1027,7 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_embed_tokens(embed_tokens) self.decoder.set_embed_tokens(embed_tokens) + @unpack_inputs def call( self, input_ids=None, @@ -1082,82 +1048,59 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: - inputs["use_cache"] = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False - inputs["output_hidden_states"] = ( - inputs["output_hidden_states"] - if inputs["output_hidden_states"] is not None - else self.config.output_hidden_states + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - if 
inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None: - inputs["decoder_input_ids"] = shift_tokens_right( - inputs["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id + if decoder_input_ids is None and input_ids is not None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id ) - if inputs["encoder_outputs"] is None: - inputs["encoder_outputs"] = self.encoder( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput): - inputs["encoder_outputs"] = TFBaseModelOutput( - last_hidden_state=inputs["encoder_outputs"][0], - hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, - attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif 
not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): - inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() decoder_outputs = self.decoder( - inputs["decoder_input_ids"], - attention_mask=inputs["decoder_attention_mask"], - encoder_hidden_states=inputs["encoder_outputs"][0], - encoder_attention_mask=inputs["attention_mask"], - head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) - if not inputs["return_dict"]: - return decoder_outputs + inputs["encoder_outputs"] + if not return_dict: + return decoder_outputs + encoder_outputs return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, @@ -1165,9 +1108,9 @@ def call( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, - encoder_hidden_states=inputs["encoder_outputs"].hidden_states, - encoder_attentions=inputs["encoder_outputs"].attentions, + 
encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, ) @@ -1197,6 +1140,7 @@ def get_decoder(self): output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) + @unpack_inputs def call( self, input_ids=None, @@ -1217,9 +1161,8 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, + + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -1236,26 +1179,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - - outputs = self.model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - encoder_outputs=inputs["encoder_outputs"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -1322,6 +1245,7 @@ def set_bias(self, value): @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @add_end_docstrings(BART_GENERATION_EXAMPLE) + @unpack_inputs def call( self, input_ids=None, @@ -1352,17 +1276,28 @@ def call( Returns: """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + + if labels is not None: + labels = tf.where( + labels == self.config.pad_token_id, + 
tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), + labels, + ) + use_cache = False + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1370,46 +1305,13 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - - if inputs["labels"] is not None: - inputs["labels"] = tf.where( - inputs["labels"] == self.config.pad_token_id, - tf.cast(tf.fill(shape_list(inputs["labels"]), -100), inputs["labels"].dtype), - inputs["labels"], - ) - inputs["use_cache"] = False - if inputs["decoder_input_ids"] is None: - inputs["decoder_input_ids"] = shift_tokens_right( - inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - encoder_outputs=inputs["encoder_outputs"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - 
return_dict=inputs["return_dict"], - training=inputs["training"], ) lm_logits = self.model.shared(outputs[0], mode="linear") lm_logits = lm_logits + self.final_logits_bias - masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - if not inputs["return_dict"]: + if not return_dict: output = (lm_logits,) + outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return TFSeq2SeqLMOutput(