Add unpack_inputs decorator to ViT model #16102

Merged 1 commit on Mar 12, 2022
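This PR replaces the boilerplate `input_processing(...)` call at the top of each `call` method in the TF ViT model with the `@unpack_inputs` decorator, so the method bodies can use their arguments (`pixel_values`, `head_mask`, `return_dict`, ...) directly instead of reading them out of an `inputs` dict. For orientation, here is a minimal, hypothetical sketch of what an `unpack_inputs`-style decorator can look like; the real implementation lives in `modeling_tf_utils.py` (the same module the diff imports it from) and additionally resolves config defaults and legacy argument names, which are elided here:

```python
import functools
import inspect


def unpack_inputs(func):
    """Simplified sketch (NOT the real implementation): bind the caller's
    arguments against the signature of `call`, normalize them once, and
    re-invoke `call` with plain keyword arguments, so the body can use
    `pixel_values` directly instead of `inputs["pixel_values"]`."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        # Anything captured by `**kwargs` sits under the "kwargs" key; splat it back.
        extra = unpacked.pop("kwargs", {})
        # The real decorator also fills in config defaults (e.g. `output_attentions`)
        # and handles legacy aliases such as `input_ids` -> `pixel_values`; elided here.
        return func(self, **unpacked, **extra)

    return wrapper
```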
91 changes: 24 additions & 67 deletions src/transformers/models/vit/modeling_tf_vit.py
@@ -30,8 +30,8 @@
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class PreTrainedModel
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,55 +489,40 @@ def call(
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )

-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers

         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]

         return TFBaseModelOutputWithPooling(
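With `@unpack_inputs` applied, the body of `TFViTMainLayer.call` above reads plain local arguments; `head_mask`, for instance, is now rebound as a local list of `None`s instead of being mutated inside the `inputs` dict. A tiny self-contained check that positional and keyword calling styles still resolve identically under the simplified decorator sketch from the top of this page (`ToyLayer` is hypothetical, not part of this PR):

```python
import functools
import inspect


def unpack_inputs(func):  # same simplified sketch as above, repeated for runnability
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(func).bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        extra = unpacked.pop("kwargs", {})
        return func(self, **unpacked, **extra)

    return wrapper


class ToyLayer:
    @unpack_inputs
    def call(self, pixel_values=None, training=False, **kwargs):
        # The body sees plain locals, mirroring the refactored TFViTMainLayer.call.
        return pixel_values, training


toy = ToyLayer()
# Positional and keyword styles resolve to the same unpacked arguments.
assert toy.call(3) == toy.call(pixel_values=3) == (3, False)
```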
@@ -659,6 +645,7 @@ def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs)

         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@ def call(
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -773,6 +745,7 @@ def __init__(self, config: ViTConfig, *inputs, **kwargs):
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@ def call(
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

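End-to-end usage is unchanged by this refactor. A hedged sketch for the classification head (assumes network access, the public `google/vit-base-patch16-224` checkpoint, and label id 281, ImageNet's "tabby cat", chosen arbitrarily; none of these appear in the diff itself):

```python
import requests
import tensorflow as tf
from PIL import Image
from transformers import TFViTForImageClassification, ViTFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="tf")
# `labels` is now read directly by `call` (no `inputs["labels"]` lookup);
# passing it makes the model return a loss alongside the logits.
outputs = model(**inputs, labels=tf.constant([281]))
print(float(outputs.loss), outputs.logits.shape)  # scalar loss, logits of shape (1, 1000)
```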