Skip to content

Commit

Permalink
Fx support for Deberta-v[1-2], Hubert and LXMERT (huggingface#17539)
Browse files Browse the repository at this point in the history
* Support for deberta and deberta-v2

* Support for LXMert

* Support for Hubert

* Fix for pt1.11

* Trigger CI
  • Loading branch information
michaelbenayoun authored and elusenji committed Jun 12, 2022
1 parent 3c8bb28 commit bd519b3
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 29 deletions.
12 changes: 6 additions & 6 deletions src/transformers/models/deberta/modeling_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))

output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
Expand All @@ -129,7 +129,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))

Expand All @@ -152,7 +152,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None

if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)

if isinstance(local_context, DropoutContext):
if local_context.mask is None:
Expand Down Expand Up @@ -564,7 +564,7 @@ def __init__(self, config):

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
Expand Down Expand Up @@ -652,7 +652,7 @@ def linear(w, b, x):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))

output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
Expand All @@ -132,7 +132,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))

Expand All @@ -157,7 +157,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None

if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)

if isinstance(local_context, DropoutContext):
if local_context.mask is None:
Expand Down Expand Up @@ -638,7 +638,7 @@ def __init__(self, config):

def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

def forward(
Expand Down Expand Up @@ -719,7 +719,7 @@ def forward(
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/lxmert/modeling_lxmert.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def transpose_for_scores(self, x):
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
Expand Down Expand Up @@ -365,7 +365,7 @@ def forward(self, hidden_states, context, attention_mask=None, output_attentions
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
Expand Down Expand Up @@ -1253,7 +1253,7 @@ def forward(
visual_prediction_scores = visual_prediction_scores_dict[key]
visual_loss = visual_loss_fct(
visual_prediction_scores.view(-1, output_dim),
label.view(*label_shape),
label.view(label_shape),
)
if visual_loss.dim() > 1: # Regression Losses
visual_loss = visual_loss.mean(1)
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None

if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)

if isinstance(local_context, DropoutContext):
if local_context.mask is None:
Expand Down Expand Up @@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))

output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
Expand All @@ -557,7 +557,7 @@ def symbolic(g, self, mask, dim):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))

Expand Down Expand Up @@ -711,7 +711,7 @@ def __init__(self, config):

def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

def forward(
Expand Down Expand Up @@ -792,7 +792,7 @@ def forward(
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
Expand Down
63 changes: 61 additions & 2 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
Expand Down Expand Up @@ -72,6 +74,8 @@ def _generate_supported_model_class_names(
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
}

if supported_tasks is None:
Expand All @@ -95,12 +99,16 @@ def _generate_supported_model_class_names(
"blenderbot",
"blenderbot-small",
"clip",
"deberta",
"deberta-v2",
"distilbert",
"electra",
"gpt2",
"gpt_neo",
"gptj",
"hubert",
"layoutlm",
"lxmert",
"m2m_100",
"marian",
"mbart",
Expand All @@ -118,8 +126,8 @@ def _generate_supported_model_class_names(
"trocr",
"vit",
"xglm",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "xlnet",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
]

_REGULAR_SUPPORTED_MODELS = []
Expand Down Expand Up @@ -155,6 +163,10 @@ def torch_nn_layernorm(self, input):
return input


def torch_nn_groupnorm(self, input):
    """Shape override used for `torch.nn.GroupNorm` modules.

    The incoming tensor is handed back untouched, so the recorded output has
    exactly the shape of the input.

    Args:
        self: The module being traced (unused).
        input: The tensor fed to the module.

    Returns:
        The very same `input` object.
    """
    return input


def torch_nn_linear(self, input):
    """Shape override for linear layers: build the output placeholder on the meta device.

    Every dimension of `input` except the last one is kept; the last one is
    replaced by `self.out_features`.

    Args:
        self: Object exposing an `out_features` attribute.
        input: Tensor whose leading dimensions are carried over.

    Returns:
        An uninitialized tensor on the "meta" device with shape
        `input.shape[:-1] + (self.out_features,)`.
    """
    leading_dims = input.shape[:-1]
    output_shape = leading_dims + (self.out_features,)
    return torch.empty(output_shape, device="meta")

Expand Down Expand Up @@ -372,6 +384,27 @@ def torch_nn_conv2d(self, input):
return torch.empty(shape, device="meta")


def torch_squeeze(input, dim=None):
    """Compute the output shape of a squeeze and return a placeholder on the meta device.

    Args:
        input: Tensor to squeeze; only its shape is read.
        dim: Optional integer dimension. When given, that dimension is removed
            only if its size is 1; negative values count from the end. When
            `None`, every dimension of size 1 is removed.

    Returns:
        An uninitialized tensor on the "meta" device with the squeezed shape.
        The dtype of `input` is not propagated (`torch.empty` is called
        without a `dtype`).
    """
    shape = list(input.shape)
    if dim is None:
        # Drop every singleton dimension.
        shape = [size for size in shape if size != 1]
    elif shape:
        # A 0-dim tensor has nothing to squeeze, so it skips this branch and is
        # returned with an unchanged (empty) shape instead of failing on
        # `shape[dim]`.
        if dim < 0:
            dim += len(shape)
        if shape[dim] == 1:
            shape.pop(dim)
    return torch.empty(shape, device="meta")


def torch_tensor_squeeze(self, dim=None):
    """Method-form counterpart of `torch_squeeze`.

    Args:
        self: The tensor the method is invoked on.
        dim: Optional dimension to squeeze, forwarded as-is.

    Returns:
        Whatever `torch_squeeze(self, dim)` returns.
    """
    return torch_squeeze(self, dim=dim)


def torch_unsqueeze(input, dim):
shape = list(input.shape)
if dim < 0:
Expand Down Expand Up @@ -446,6 +479,7 @@ def to_concrete(t):
torch.nn.Embedding: torch_nn_embedding,
torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.GroupNorm: torch_nn_groupnorm,
torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu,
torch.nn.functional.relu: torch_nn_functional_relu,
Expand All @@ -469,6 +503,8 @@ def to_concrete(t):
torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d,
torch.squeeze: torch_squeeze,
torch.Tensor.squeeze: torch_tensor_squeeze,
torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
Expand Down Expand Up @@ -605,7 +641,7 @@ class HFTracer(Tracer):
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]

def __init__(self, autowrap_modules=(math,), autowrap_functions=()):

Expand Down Expand Up @@ -704,8 +740,31 @@ def _generate_dummy_input(
inputs_dict[input_name] = torch.zeros(
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
)
elif "visual_feats" in input_name:
inputs_dict[input_name] = torch.zeros(
shape
+ [
model.config.visual_feat_dim,
],
dtype=torch.float,
device=device,
)
elif "visual_pos" in input_name:
inputs_dict[input_name] = torch.zeros(
shape
+ [
model.config.visual_pos_dim,
],
dtype=torch.float,
device=device,
)
elif "inputs" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
elif "input_values" in input_name:
batch_size, _ = shape
# Generating big sequence length for audio inputs.
seq_length = _generate_random_int(low=10000, high=20000)
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
Expand Down
1 change: 1 addition & 0 deletions tests/models/deberta/test_modeling_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)

fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
Expand Down
1 change: 1 addition & 0 deletions tests/models/deberta_v2/test_modeling_deberta_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)

fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
Expand Down
Loading

0 comments on commit bd519b3

Please sign in to comment.