test_sentence_transformer.py
"""
Tests general behaviour of the SentenceTransformer class
"""

import tempfile
import unittest
from pathlib import Path

import torch

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling


class TestSentenceTransformer(unittest.TestCase):
    def test_load_with_safetensors(self):
        with tempfile.TemporaryDirectory() as cache_folder:
            safetensors_model = SentenceTransformer(
                "sentence-transformers-testing/stsb-bert-tiny-safetensors",
                cache_folder=cache_folder,
            )

            # Only the safetensors file must be loaded
            pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
            self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.")
            safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
            self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.")

        with tempfile.TemporaryDirectory() as cache_folder:
            transformer = Transformer(
                "sentence-transformers-testing/stsb-bert-tiny-safetensors",
                cache_dir=cache_folder,
                model_args={"use_safetensors": False},
            )
            pooling = Pooling(transformer.get_word_embedding_dimension())
            pytorch_model = SentenceTransformer(modules=[transformer, pooling])

            # Only the pytorch file must be loaded
            pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
            self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.")
            safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
            self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.")

        sentences = ["This is a test sentence", "This is another test sentence"]
        self.assertTrue(
            torch.equal(
                safetensors_model.encode(sentences, convert_to_tensor=True),
                pytorch_model.encode(sentences, convert_to_tensor=True),
            ),
            msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",
        )