diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index fcdaab6359..1420879cc9 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -1,6 +1,7 @@ import os import numpy as np +from transformers import AutoModel from extensions.openai.errors import ServiceUnavailableError from extensions.openai.utils import debug_msg, float_list_to_base64 @@ -41,7 +42,12 @@ def load_embedding_model(model: str): global embeddings_device, embeddings_model try: print(f"Try embedding model: {model} on {embeddings_device}") - embeddings_model = SentenceTransformer(model, device=embeddings_device) + if 'jina-embeddings' in model: + embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True) # trust_remote_code is needed to use the encode method + embeddings_model = embeddings_model.to(embeddings_device) + else: + embeddings_model = SentenceTransformer(model, device=embeddings_device) + print(f"Loaded embedding model: {model}") except Exception as e: embeddings_model = None