
apply unpack_input decorator to ViT model (#16102)
johnnv1 authored Mar 12, 2022
1 parent 62b05b6 commit 8416206
Showing 1 changed file with 24 additions and 67 deletions.
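
The change swaps the per-method `input_processing(...)` boilerplate for the `unpack_inputs` decorator from `...modeling_tf_utils`. The decorator normalises whatever the caller passed (positional tensors, keyword arguments, or a dict) and resolves boolean flags such as `output_attentions` and `return_dict` against the config before the method body runs, so `call` can use its named arguments (`pixel_values`, `head_mask`, ...) directly instead of reading everything back out of an `inputs` dict.

A minimal sketch of the pattern follows. It is an illustration only, under the assumption that a dict passed as the first argument should be treated as keyword arguments; the real `unpack_inputs` in `transformers` handles more cases (config-backed booleans, the various TF input formats). The name `unpack_inputs_sketch` is hypothetical.

import functools


def unpack_inputs_sketch(call):
    """Illustrative decorator: normalise inputs into keyword arguments."""

    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # A dict as the first positional argument (e.g. {"pixel_values": ...})
        # is unpacked into keyword arguments for the wrapped `call`.
        if args and isinstance(args[0], dict):
            kwargs = {**args[0], **kwargs}
            args = args[1:]
        return call(self, *args, **kwargs)

    return wrapper

With the decorator in place, the `inputs = input_processing(...)` call and the `inputs["..."]` indirection inside each `call` method become unnecessary, which is what the diff below removes.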
91 changes: 24 additions & 67 deletions src/transformers/models/vit/modeling_tf_vit.py
@@ -30,8 +30,8 @@
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class PreTrainedModel
         """
         raise NotImplementedError
 
+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,55 +489,40 @@ def call(
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
 
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers
 
         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
 
         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
         return TFBaseModelOutputWithPooling(
@@ -659,6 +645,7 @@ def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs)
 
         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@ def call(
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
 
         return outputs
@@ -773,6 +745,7 @@ def __init__(self, config: ViTConfig, *inputs, **kwargs):
             name="classifier",
         )
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@ def call(
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
         )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
-        )
         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
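Nothing about the public calling convention changes; the refactor only removes boilerplate inside the model code. Under the assumption of a standard `transformers` + TensorFlow install (the checkpoint name below is just a commonly used ViT checkpoint; any compatible one works), a quick sanity check of the usual input styles would look like this:

import tensorflow as tf
from transformers import TFViTModel

model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# Dummy batch in channels-first layout, the shape expected for pixel_values.
pixel_values = tf.random.uniform((1, 3, 224, 224))

# Positional, keyword and dict inputs should all keep working: the decorator
# (like input_processing before it) resolves them before the call body runs.
out_positional = model(pixel_values)
out_keyword = model(pixel_values=pixel_values)
out_dict = model({"pixel_values": pixel_values})

print(out_keyword.last_hidden_state.shape)  # (1, 197, 768) for ViT-base at 224x224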
