Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
Signed-off-by: Shogo Hida <[email protected]>
  • Loading branch information
shogohida committed Jan 8, 2023
1 parent 6c09f51 commit 61ff473
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tests/bettertransformer/test_bettertransformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
"hf-internal-testing/tiny-random-MarkupLMModel",
"hf-internal-testing/tiny-random-rembert",
"hf-internal-testing/tiny-random-RobertaModel",
"hf-internal-testing/tiny-random-RoCBertModel",
"hf-internal-testing/tiny-random-SplinterModel",
"hf-internal-testing/tiny-random-TapasModel",
"hf-internal-testing/tiny-random-RoCBertModel",
"hf-internal-testing/tiny-xlm-roberta",
"ybelkada/random-tiny-BertGenerationModel",
]
Expand Down Expand Up @@ -256,6 +256,29 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self):
self.check_accelerate_compatibility_cpu_gpu(keep_original_model=False, max_memory=max_memory)


class BetterTransformersRoCBertTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Full testing suite of the `BetterTransformers` integration into Hugging Face
`transformers` ecosystem. Check the docstring of each test to understand the
purpose of each test. Basically we test:
- if the conversion dictionnary is consistent, ie if the converted model exists
in HuggingFace `transformers` library.
- if the converted model produces the same logits as the original model.
- if the converted model is faster than the original model.
"""
all_models_to_test = ALL_ENCODER_MODELS_TO_TEST

def tearDown(self):
gc.collect()

def prepare_inputs_for_class(self, model_id=None):
input_dict = {
"input_ids": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]),
"attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]),
}
return input_dict


class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Full testing suite of the `BetterTransformers` integration into Hugging Face
Expand Down

0 comments on commit 61ff473

Please sign in to comment.