Skip to content

Commit

Permalink
feat: Transformer-based embedding function
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Mar 1, 2024
1 parent e1ad5f9 commit 27e280f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
15 changes: 15 additions & 0 deletions chromadb/test/ef/test_transformer_ef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from chromadb.utils.embedding_functions import TransformerEmbeddingFunction


def test_transformer_ef_default_mdoel():
    """Check that the default model yields 384-dimensional embeddings."""
    embed = TransformerEmbeddingFunction()
    vectors = embed(["text"])
    first_vector = vectors[0]
    assert len(first_vector) == 384


def test_transformer_ef_custom_model():
    """Check that an explicitly named model embeds every input document."""
    documents = ["Merhaba dünya", "Bu bir test cümlesidir"]
    embed = TransformerEmbeddingFunction(model_name="dbmdz/bert-base-turkish-cased")
    vectors = embed(documents)
    # One vector per document, each with the 768 dimensions expected here.
    assert vectors is not None
    assert len(vectors) == 2
    assert len(vectors[0]) == 768
34 changes: 34 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,40 @@ def __call__(self, input: Documents) -> Embeddings:
)


class TransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embed documents with a Hugging Face ``transformers`` encoder model.

    Each document is tokenized, run through the model, mean-pooled over its
    real (non-padding) tokens, and L2-normalized.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        cache_dir: Optional[str] = None,
    ):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            cache_dir: Optional cache directory, forwarded to both the
                tokenizer and the model loaders.

        Raises:
            ValueError: If ``transformers`` or ``torch`` cannot be imported.
        """
        # Keep the try body limited to the imports so that an ImportError
        # raised while loading a specific model is not misreported as a
        # missing transformers/torch installation.
        try:
            from transformers import AutoModel, AutoTokenizer

            self._torch = importlib.import_module("torch")
        except ImportError as e:
            raise ValueError(
                "The transformers and/or pytorch python package is not installed. Please install it with "
                "`pip install transformers` or `pip install torch`"
            ) from e

        # cache_dir must reach both loaders; otherwise the tokenizer files
        # would ignore the caller's chosen location.
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        self._model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)

    @staticmethod
    def _normalize(v: npt.NDArray) -> npt.NDArray:
        """Return ``v`` with each row scaled to unit L2 norm.

        Rows whose norm is exactly zero are divided by 1e-12 instead, which
        avoids a division by zero and leaves them as zeros.
        """
        norm = np.linalg.norm(v, axis=1)
        norm[norm == 0] = 1e-12
        return cast(npt.NDArray, v / norm[:, np.newaxis])

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return one normalized vector per document.

        Args:
            input: The documents to embed.

        Returns:
            A list of embeddings (lists of floats), in input order.
        """
        inputs = self._tokenizer(
            input,
            padding=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        with self._torch.no_grad():
            outputs = self._model(**inputs)
            hidden = outputs.last_hidden_state
            # padding=True pads every document to the longest one in the
            # batch. Weight by the attention mask so padding positions do
            # not dilute the mean and make a document's embedding depend on
            # its batch neighbours.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled = summed / counts
        # _normalize works on NumPy arrays; convert explicitly rather than
        # relying on implicit tensor/ndarray interop.
        vectors = pooled.cpu().numpy()
        return [e.tolist() for e in self._normalize(vectors)]


# List of all classes in this module
_classes = [
name
Expand Down

0 comments on commit 27e280f

Please sign in to comment.