From 39dc25ab736a15197229a1b57e474b3ec9aab4e4 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Tue, 19 Mar 2024 13:53:19 +0800
Subject: [PATCH 1/2] Add TransformersBgeEmbeddings class in
 bigdl.llm.langchain.embeddings

---
 .../src/bigdl/llm/langchain/embeddings/__init__.py    |  5 +++--
 .../langchain/embeddings/transformersembeddings.py    | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
index d001919c976..e6ec52acf8d 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index c52a8adf285..b62b2c8178c 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 
 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -181,3 +182,14 @@ def embed_query(self, text: str) -> List[float]:
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
\ No newline at end of file

From cee27c0c6bf1ce371ac157324747afb10f761218 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Tue, 19 Mar 2024 13:56:21 +0800
Subject: [PATCH 2/2] Small fixes

---
 .../bigdl/llm/langchain/embeddings/transformersembeddings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index b62b2c8178c..9c69f4744c3 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -192,4 +192,4 @@ def embed(self, text: str, **kwargs):
         input_ids = input_ids.to(self.model.device)
         embeddings = self.model(input_ids, return_dict=False)[0].cpu()
         embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
-        return embeddings[0]
\ No newline at end of file
+        return embeddings[0]
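
Usage sketch for the new TransformersBgeEmbeddings class (not part of the patch itself): it assumes the from_model_id constructor inherited from TransformersEmbeddings, and the model id below is only illustrative.

    from bigdl.llm.langchain.embeddings import TransformersBgeEmbeddings

    # Build the embedding model via the constructor inherited from
    # TransformersEmbeddings (assumed here); the BGE model id is illustrative.
    bge = TransformersBgeEmbeddings.from_model_id(
        model_id="BAAI/bge-small-en-v1.5",
        model_kwargs={"trust_remote_code": True},
    )

    # embed_query() strips newlines and calls the overridden embed(), which
    # returns the L2-normalized [CLS] token embedding as a List[float].
    vector = bge.embed_query("What does BigDL-LLM provide?")
    print(len(vector))  # embedding dimension of the chosen BGE model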