diff --git a/chromadb/test/ef/test_transformer_ef.py b/chromadb/test/ef/test_transformer_ef.py
new file mode 100644
index 000000000000..7199a0eeaa38
--- /dev/null
+++ b/chromadb/test/ef/test_transformer_ef.py
@@ -0,0 +1,28 @@
+import numpy as np
+
+from chromadb.utils.embedding_functions import TransformerEmbeddingFunction
+
+
+def test_transformer_ef_default_model() -> None:
+    """The default MiniLM model yields 384-dimensional embeddings."""
+    ef = TransformerEmbeddingFunction()
+    embedding = ef(["text"])
+    assert len(embedding[0]) == 384
+
+
+def test_transformer_ef_custom_model() -> None:
+    """A custom BERT checkpoint yields one 768-dimensional vector per document."""
+    ef = TransformerEmbeddingFunction(model_name="dbmdz/bert-base-turkish-cased")
+    embedding = ef(["Merhaba dünya", "Bu bir test cümlesidir"])
+    assert embedding is not None
+    assert len(embedding) == 2
+    assert len(embedding[0]) == 768
+
+
+def test_transformer_ef_padding_does_not_change_embedding() -> None:
+    """Batching a short text with a longer one must not alter its embedding."""
+    ef = TransformerEmbeddingFunction()
+    longer = "a considerably longer piece of text that forces padding"
+    alone = ef(["text"])[0]
+    batched = ef(["text", longer])[0]
+    assert np.allclose(alone, batched, atol=1e-4)
diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index f54ab88c42e3..cba9185f1133 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -815,6 +815,79 @@ def __call__(self, input: Documents) -> Embeddings:
         )
 
 
+class TransformerEmbeddingFunction(EmbeddingFunction[Documents]):
+    """Embed documents locally with a Hugging Face ``transformers`` model.
+
+    Token embeddings from the model's last hidden state are mean-pooled
+    (padding excluded) and L2-normalized to produce one vector per document.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+        cache_dir: Optional[str] = None,
+    ):
+        """
+        Load the tokenizer and model for `model_name`.
+
+        Args:
+            model_name (str): Name or path passed to `from_pretrained`.
+            cache_dir (str, optional): Cache directory passed to
+                `from_pretrained` for both the tokenizer and the model.
+
+        Raises:
+            ValueError: If importing or loading raises ImportError.
+        """
+        try:
+            from transformers import AutoModel, AutoTokenizer
+
+            self._torch = importlib.import_module("torch")
+            # Pass cache_dir to the tokenizer too, so both downloads honor it.
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                model_name, cache_dir=cache_dir
+            )
+            self._model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
+        except ImportError as e:
+            # Chain the cause: the underlying ImportError may name a missing
+            # package other than the two mentioned in the message below.
+            raise ValueError(
+                "The transformers and/or pytorch python package is not installed. Please install it with "
+                "`pip install transformers` or `pip install torch`"
+            ) from e
+
+    @staticmethod
+    def _normalize(v: npt.NDArray) -> npt.NDArray:
+        """Scale each row of the 2-D array `v` to unit L2 norm."""
+        norm = np.linalg.norm(v, axis=1)
+        # Avoid division by zero for all-zero rows.
+        norm[norm == 0] = 1e-12
+        return cast(npt.NDArray, v / norm[:, np.newaxis])
+
+    def __call__(self, input: Documents) -> Embeddings:
+        """
+        Embed each document in `input`.
+
+        Args:
+            input (Documents): Texts to embed.
+
+        Returns:
+            Embeddings: One L2-normalized vector per document.
+        """
+        inputs = self._tokenizer(
+            input, padding=True, truncation=True, return_tensors="pt"
+        )
+        with self._torch.no_grad():
+            outputs = self._model(**inputs)
+            hidden = outputs.last_hidden_state
+            # Mask-weighted mean: padding positions must not contribute, or a
+            # document's embedding would depend on what it is batched with.
+            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
+            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+        # _normalize operates on NumPy arrays, so convert the tensor first.
+        embeddings = pooled.cpu().numpy()
+        return [e.tolist() for e in self._normalize(embeddings)]
+
+
 # List of all classes in this module
 _classes = [
     name