diff --git a/requirements/embed-huggingface.in b/requirements/embed-huggingface.in
new file mode 100644
index 0000000000..813cae9225
--- /dev/null
+++ b/requirements/embed-huggingface.in
@@ -0,0 +1,6 @@
+-c constraints.in
+-c base.txt
+
+huggingface
+langchain
+sentence_transformers
\ No newline at end of file
diff --git a/requirements/embed-huggingface.txt b/requirements/embed-huggingface.txt
new file mode 100644
index 0000000000..217d0e9dec
--- /dev/null
+++ b/requirements/embed-huggingface.txt
@@ -0,0 +1,207 @@
+#
+# This file is autogenerated by pip-compile with Python 3.8
+# by the following command:
+#
+#    pip-compile --constraint=requirements/constraints.in requirements/embed-huggingface.in
+#
+aiohttp==3.8.6
+    # via langchain
+aiosignal==1.3.1
+    # via aiohttp
+anyio==3.7.1
+    # via
+    #   -c requirements/constraints.in
+    #   langchain
+async-timeout==4.0.3
+    # via
+    #   aiohttp
+    #   langchain
+attrs==23.1.0
+    # via aiohttp
+certifi==2023.7.22
+    # via
+    #   -c requirements/base.txt
+    #   -c requirements/constraints.in
+    #   requests
+charset-normalizer==3.3.0
+    # via
+    #   -c requirements/base.txt
+    #   aiohttp
+    #   requests
+click==8.1.7
+    # via
+    #   -c requirements/base.txt
+    #   nltk
+dataclasses-json==0.6.1
+    # via
+    #   -c requirements/base.txt
+    #   langchain
+exceptiongroup==1.1.3
+    # via anyio
+filelock==3.12.4
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+frozenlist==1.4.0
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec==2023.9.1
+    # via
+    #   -c requirements/constraints.in
+    #   huggingface-hub
+    #   torch
+huggingface==0.0.1
+    # via -r requirements/embed-huggingface.in
+huggingface-hub==0.17.3
+    # via
+    #   sentence-transformers
+    #   tokenizers
+    #   transformers
+idna==3.4
+    # via
+    #   -c requirements/base.txt
+    #   anyio
+    #   requests
+    #   yarl
+jinja2==3.1.2
+    # via torch
+joblib==1.3.2
+    # via
+    #   -c requirements/base.txt
+    #   nltk
+    #   scikit-learn
+jsonpatch==1.33
+    # via langchain
+jsonpointer==2.4
+    # via jsonpatch
+langchain==0.0.317
+    # via -r requirements/embed-huggingface.in
+langsmith==0.0.46
+    # via langchain
+markupsafe==2.1.3
+    # via jinja2
+marshmallow==3.20.1
+    # via
+    #   -c requirements/base.txt
+    #   dataclasses-json
+mpmath==1.3.0
+    # via sympy
+multidict==6.0.4
+    # via
+    #   aiohttp
+    #   yarl
+mypy-extensions==1.0.0
+    # via
+    #   -c requirements/base.txt
+    #   typing-inspect
+networkx==3.1
+    # via torch
+nltk==3.8.1
+    # via
+    #   -c requirements/base.txt
+    #   sentence-transformers
+numpy==1.24.4
+    # via
+    #   -c requirements/base.txt
+    #   -c requirements/constraints.in
+    #   langchain
+    #   scikit-learn
+    #   scipy
+    #   sentence-transformers
+    #   torchvision
+    #   transformers
+packaging==23.2
+    # via
+    #   -c requirements/base.txt
+    #   huggingface-hub
+    #   marshmallow
+    #   transformers
+pillow==10.1.0
+    # via torchvision
+pydantic==1.10.13
+    # via
+    #   -c requirements/constraints.in
+    #   langchain
+    #   langsmith
+pyyaml==6.0.1
+    # via
+    #   huggingface-hub
+    #   langchain
+    #   transformers
+regex==2023.10.3
+    # via
+    #   -c requirements/base.txt
+    #   nltk
+    #   transformers
+requests==2.31.0
+    # via
+    #   -c requirements/base.txt
+    #   huggingface-hub
+    #   langchain
+    #   langsmith
+    #   torchvision
+    #   transformers
+safetensors==0.3.2
+    # via
+    #   -c requirements/constraints.in
+    #   transformers
+scikit-learn==1.3.1
+    # via sentence-transformers
+scipy==1.10.1
+    # via
+    #   -c requirements/constraints.in
+    #   scikit-learn
+    #   sentence-transformers
+sentence-transformers==2.2.2
+    # via -r requirements/embed-huggingface.in
+sentencepiece==0.1.99
+    # via sentence-transformers
+sniffio==1.3.0
+    # via anyio
+sqlalchemy==2.0.22
+    # via langchain
+sympy==1.12
+    # via torch
+tenacity==8.2.3
+    # via langchain
+threadpoolctl==3.2.0
+    # via scikit-learn
+tokenizers==0.14.1
+    # via transformers
+torch==2.1.0
+    # via
+    #   -c requirements/constraints.in
+    #   sentence-transformers
+    #   torchvision
+torchvision==0.16.0
+    # via sentence-transformers
+tqdm==4.66.1
+    # via
+    #   -c requirements/base.txt
+    #   huggingface-hub
+    #   nltk
+    #   sentence-transformers
+    #   transformers
+transformers==4.34.1
+    # via sentence-transformers
+typing-extensions==4.8.0
+    # via
+    #   -c requirements/base.txt
+    #   huggingface-hub
+    #   pydantic
+    #   sqlalchemy
+    #   torch
+    #   typing-inspect
+typing-inspect==0.9.0
+    # via
+    #   -c requirements/base.txt
+    #   dataclasses-json
+urllib3==1.26.18
+    # via
+    #   -c requirements/base.txt
+    #   -c requirements/constraints.in
+    #   requests
+yarl==1.9.2
+    # via aiohttp
diff --git a/test_unstructured/embed/test_embed_huggingface.py b/test_unstructured/embed/test_embed_huggingface.py
new file mode 100644
index 0000000000..655178ccd6
--- /dev/null
+++ b/test_unstructured/embed/test_embed_huggingface.py
@@ -0,0 +1,23 @@
+from unstructured.documents.elements import Text
+from unstructured.embed.huggingface import HuggingFaceEmbeddingEncoder
+
+
+def test_embed_documents_does_not_break_element_to_dict(mocker):
+    # Mocked client with the desired behavior for embed_documents
+    mock_client = mocker.MagicMock()
+    mock_client.embed_documents.return_value = [1, 2]
+
+    # Mock get_huggingface_client to return our mock_client
+    mocker.patch.object(
+        HuggingFaceEmbeddingEncoder,
+        "get_huggingface_client",
+        return_value=mock_client,
+    )
+
+    encoder = HuggingFaceEmbeddingEncoder()
+    elements = encoder.embed_documents(
+        elements=[Text("This is sentence 1"), Text("This is sentence 2")],
+    )
+    assert len(elements) == 2
+    assert elements[0].to_dict()["text"] == "This is sentence 1"
+    assert elements[1].to_dict()["text"] == "This is sentence 2"
diff --git a/unstructured/embed/huggingface.py b/unstructured/embed/huggingface.py
new file mode 100644
index 0000000000..fa75fb4008
--- /dev/null
+++ b/unstructured/embed/huggingface.py
@@ -0,0 +1,75 @@
+from typing import List, Optional
+
+import numpy as np
+
+from unstructured.documents.elements import (
+    Element,
+)
+from unstructured.embed.interfaces import BaseEmbeddingEncoder
+from unstructured.ingest.error import EmbeddingEncoderConnectionError
+from unstructured.utils import requires_dependencies
+
+
+class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
+    def __init__(
+        self,
+        model_name: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2",
+        model_kwargs: Optional[dict] = {"device": "cpu"},
+        encode_kwargs: Optional[dict] = {"normalize_embeddings": False},
+        cache_folder: Optional[str] = None,
+    ):
+        self.model_name = model_name
+        self.model_kwargs = model_kwargs
+        self.encode_kwargs = encode_kwargs
+        self.cache_folder = cache_folder
+
+        self.initialize()
+
+    def initialize(self):
+        """Creates a langchain HuggingFace object to embed elements."""
+        self.hf = self.get_huggingface_client()
+
+    def num_of_dimensions(self):
+        return np.shape(self.examplary_embedding)
+
+    def is_unit_vector(self):
+        return np.isclose(np.linalg.norm(self.examplary_embedding), 1.0)
+
+    def embed_query(self, query):
+        return self.hf.embed_query(str(query))
+
+    def embed_documents(self, elements: List[Element]) -> List[Element]:
+        embeddings = self.hf.embed_documents([str(e) for e in elements])
+        elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
+        return elements_with_embeddings
+
+    def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
+        assert len(elements) == len(embeddings)
+        elements_w_embedding = []
+
+        for i, element in enumerate(elements):
+            element.embeddings = embeddings[i]
+            elements_w_embedding.append(element)
+        return elements_w_embedding
+
+    @EmbeddingEncoderConnectionError.wrap
+    @requires_dependencies(
+        ["langchain", "sentence_transformers"],
+        extras="embed-huggingface",
+    )
+    def get_huggingface_client(self):
+        """Creates a langchain Huggingface python client to embed elements."""
+        if hasattr(self, "hf_client"):
+            return self.hf_client
+
+        from langchain.embeddings import HuggingFaceEmbeddings
+
+        hf_client = HuggingFaceEmbeddings(
+            model_name=self.model_name,
+            model_kwargs=self.model_kwargs,
+            encode_kwargs=self.encode_kwargs,
+            cache_folder=self.cache_folder,
+        )
+        self.examplary_embedding = hf_client.embed_query("Q")
+        self.hf_client = hf_client
+        return self.hf_client
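For context, here is a minimal usage sketch of the new encoder (not part of the diff itself). It assumes the package is installed with the `embed-huggingface` extra so the `langchain` and `sentence_transformers` pins above resolve, and it only exercises the public methods added in `unstructured/embed/huggingface.py`; the model name and kwargs shown are just the defaults from `__init__`.

```python
from unstructured.documents.elements import Text
from unstructured.embed.huggingface import HuggingFaceEmbeddingEncoder

# Instantiating the encoder builds the langchain HuggingFaceEmbeddings client
# and embeds a probe query ("Q"), which populates examplary_embedding.
encoder = HuggingFaceEmbeddingEncoder(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # default model
    model_kwargs={"device": "cpu"},
)

# embed_documents() attaches a vector to each element and returns the elements.
elements = encoder.embed_documents(
    elements=[Text("This is sentence 1"), Text("This is sentence 2")],
)
for element in elements:
    print(element.text, len(element.embeddings))

print(encoder.num_of_dimensions())  # shape of the probe embedding, e.g. (384,)
print(encoder.is_unit_vector())     # whether the probe embedding has norm ~1.0
```

Note that the first run downloads the sentence-transformers model weights from the Hugging Face Hub, so it needs network access and takes noticeably longer than subsequent runs.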