Add RoCBert support for Bettertransformer #542

Merged
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx

```diff
@@ -43,6 +43,7 @@ The list of supported model below:
 - [M2M100](https://arxiv.org/abs/2010.11125)
 - [RemBERT](https://arxiv.org/abs/2010.12821)
 - [RoBERTa](https://arxiv.org/abs/1907.11692)
+- [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf)
 - [Splinter](https://arxiv.org/abs/2101.00438)
 - [Tapas](https://arxiv.org/abs/2211.06550)
 - [ViLT](https://arxiv.org/abs/2102.03334)
```
1 change: 1 addition & 0 deletions optimum/bettertransformer/models/__init__.py

```diff
@@ -49,6 +49,7 @@ class BetterTransformerManager:
         "mbart": ("MBartEncoderLayer", MBartEncoderLayerBetterTransformer),
         "rembert": ("RemBertLayer", BertLayerBetterTransformer),
         "roberta": ("RobertaLayer", BertLayerBetterTransformer),
+        "roc_bert": ("RoCBertLayer", BertLayerBetterTransformer),
         "splinter": ("SplinterLayer", BertLayerBetterTransformer),
         "tapas": ("TapasLayer", BertLayerBetterTransformer),
         "vilt": ("ViltLayer", ViltLayerBetterTransformer),
```
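This one-line mapping is all the registration needed: it tells `BetterTransformerManager` that when a model of type `roc_bert` is converted, every module named `RoCBertLayer` should be swapped for `BertLayerBetterTransformer` (RoCBert's encoder layer is BERT-shaped, so the existing conversion is reused). A minimal sketch of the resulting user-facing flow, using the same tiny checkpoint as the tests below:

```python
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Load a RoCBert checkpoint; the tiny test model keeps the example fast.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-RoCBertModel")

# transform() looks up the model type ("roc_bert") in the mapping above and
# replaces each RoCBertLayer with its BetterTransformer equivalent.
bt_model = BetterTransformer.transform(model, keep_original_model=False)
```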
8 changes: 8 additions & 0 deletions tests/bettertransformer/test_bettertransformer_encoder.py

```diff
@@ -255,6 +255,14 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self):
         self.check_accelerate_compatibility_cpu_gpu(keep_original_model=False, max_memory=max_memory)
 
 
+class BetterTransformersRoCBertTest(BetterTransformersEncoderTest):
+    all_models_to_test = ["hf-internal-testing/tiny-random-RoCBertModel"]
+
+    # unrelated issue with torch.amp.autocast with rocbert (expected scalar type BFloat16 but found Float)
+    def test_raise_autocast(self):
+        pass
+
+
 class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
     r"""
     Full testing suite of the `BetterTransformers` integration into Hugging Face
```
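One note on the new test class: overriding `test_raise_autocast` with a bare `pass` silently neutralizes the inherited test. A sketch of a more explicit alternative (not part of this PR) would use the standard-library `unittest.skip` decorator, so the reason is surfaced in test reports instead of the test counting as a pass:

```python
import unittest

class BetterTransformersRoCBertTest(BetterTransformersEncoderTest):
    all_models_to_test = ["hf-internal-testing/tiny-random-RoCBertModel"]

    # The runner reports this as skipped, with the reason, rather than passed.
    @unittest.skip("unrelated torch.amp.autocast issue with RoCBert (expected scalar type BFloat16 but found Float)")
    def test_raise_autocast(self):
        pass
```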