From 841620684b75ce63918e8e9dfecdd3b46394bbc1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Gustavo=20A=2E=20Amorim?=
Date: Sat, 12 Mar 2022 12:05:13 -0300
Subject: [PATCH] apply unpack_input decorator to ViT model (#16102)

---
 .../models/vit/modeling_tf_vit.py             | 91 +++++--------------
 1 file changed, 24 insertions(+), 67 deletions(-)

diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py
index 9a7025c662d71e..9818cf29d137d2 100644
--- a/src/transformers/models/vit/modeling_tf_vit.py
+++ b/src/transformers/models/vit/modeling_tf_vit.py
@@ -30,8 +30,8 @@
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class PreTrainedModel
         """
         raise NotImplementedError
 
+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,29 +489,14 @@ def call(
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )
 
         # Prepare head mask if needed
@@ -518,25 +504,25 @@ def call(
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers
 
         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
 
         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
         return TFBaseModelOutputWithPooling(
@@ -659,6 +645,7 @@ def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs)
 
         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@ def call(
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
 
         return outputs
@@ -773,6 +745,7 @@ def __init__(self, config: ViTConfig, *inputs, **kwargs):
             name="classifier",
         )
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@ def call(
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
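
Note: the patch above replaces the per-method input_processing(...) boilerplate
with the @unpack_inputs decorator from transformers.modeling_tf_utils, so each
call body reads its named parameters (pixel_values, head_mask, return_dict, ...)
directly instead of going through an inputs dict. The following is a minimal,
illustrative sketch of the idea behind such a decorator, not the actual
transformers implementation: unpack_inputs_sketch, DummyConfig, and DummyModel
are hypothetical names, and the real decorator additionally handles positional
arguments, packed tensor/dict inputs, and the legacy input_ids alias.

import functools
import inspect


def unpack_inputs_sketch(func):
    """Resolve `None` boolean flags from `self.config` before invoking `func`."""
    # Hypothetical, simplified stand-in for transformers' unpack_inputs.
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        # Fall back to the model config for unset flags, as the removed
        # input_processing(...) calls used to do.
        for name in ("output_attentions", "output_hidden_states", "return_dict"):
            if name in bound.arguments and bound.arguments[name] is None:
                bound.arguments[name] = getattr(self.config, name, None)
        return func(*bound.args, **bound.kwargs)

    return wrapper


class DummyConfig:  # hypothetical stand-in for ViTConfig
    output_attentions = False
    output_hidden_states = False
    return_dict = True


class DummyModel:  # hypothetical stand-in for a TF model layer
    config = DummyConfig()

    @unpack_inputs_sketch
    def call(self, pixel_values=None, return_dict=None):
        # The body can use `return_dict` directly, mirroring the patch above.
        return {"pixel_values": pixel_values} if return_dict else (pixel_values,)


print(DummyModel().call(pixel_values=[1.0]))  # -> {'pixel_values': [1.0]}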