Skip to content

Commit

Permalink
make style
Browse files Browse the repository at this point in the history
Signed-off-by: Issam Arabi <[email protected]>
  • Loading branch information
issamarabi committed Oct 1, 2023
1 parent c9711b8 commit a2ab86d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
20 changes: 6 additions & 14 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,16 +1251,8 @@ def __init__(self, detr_layer, config):
self.norm_first = True

self.original_layers_mapping = {
"in_proj_weight": [
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight"
],
"in_proj_bias": [
"self_attn.q_proj.bias",
"self_attn.k_proj.bias",
"self_attn.v_proj.bias"
],
"in_proj_weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"],
"in_proj_bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"],
"out_proj_weight": "self_attn.out_proj.weight",
"out_proj_bias": "self_attn.out_proj.bias",
"linear1_weight": "fc1.weight",
Expand All @@ -1272,7 +1264,7 @@ def __init__(self, detr_layer, config):
"norm2_weight": "final_layer_norm.weight",
"norm2_bias": "final_layer_norm.bias",
}

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__):
Expand Down Expand Up @@ -1303,15 +1295,15 @@ def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **
self.linear2_bias,
attention_mask,
)

if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)

else:
raise NotImplementedError(
"Training and Autocast are not implemented for BetterTransformer + Detr. Please open an issue."
)

return (hidden_states,)


Expand Down
15 changes: 13 additions & 2 deletions tests/bettertransformer/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,18 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "detr", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = [
"blip-2",
"clip",
"clip_text_model",
"deit",
"detr",
"vilt",
"vit",
"vit_mae",
"vit_msn",
"yolos",
]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":
Expand Down Expand Up @@ -56,7 +67,7 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc

if model_type == "blip-2":
inputs["decoder_input_ids"] = inputs["input_ids"]

elif model_type == "detr":
# Assuming detr just needs an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
Expand Down

0 comments on commit a2ab86d

Please sign in to comment.