-
Notifications
You must be signed in to change notification settings - Fork 817
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into klaijan/ci-text-extraction
- Loading branch information
Showing
4 changed files
with
310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
-c constraints.in | ||
-c base.txt | ||
|
||
huggingface | ||
langchain | ||
sentence_transformers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# | ||
# This file is autogenerated by pip-compile with Python 3.8 | ||
# by the following command: | ||
# | ||
# pip-compile --constraint=requirements/constraints.in requirements/embed-huggingface.in | ||
# | ||
aiohttp==3.8.6 | ||
# via langchain | ||
aiosignal==1.3.1 | ||
# via aiohttp | ||
anyio==3.7.1 | ||
# via | ||
# -c requirements/constraints.in | ||
# langchain | ||
async-timeout==4.0.3 | ||
# via | ||
# aiohttp | ||
# langchain | ||
attrs==23.1.0 | ||
# via aiohttp | ||
certifi==2023.7.22 | ||
# via | ||
# -c requirements/base.txt | ||
# -c requirements/constraints.in | ||
# requests | ||
charset-normalizer==3.3.0 | ||
# via | ||
# -c requirements/base.txt | ||
# aiohttp | ||
# requests | ||
click==8.1.7 | ||
# via | ||
# -c requirements/base.txt | ||
# nltk | ||
dataclasses-json==0.6.1 | ||
# via | ||
# -c requirements/base.txt | ||
# langchain | ||
exceptiongroup==1.1.3 | ||
# via anyio | ||
filelock==3.12.4 | ||
# via | ||
# huggingface-hub | ||
# torch | ||
# transformers | ||
frozenlist==1.4.0 | ||
# via | ||
# aiohttp | ||
# aiosignal | ||
fsspec==2023.9.1 | ||
# via | ||
# -c requirements/constraints.in | ||
# huggingface-hub | ||
# torch | ||
huggingface==0.0.1 | ||
# via -r requirements/embed-huggingface.in | ||
huggingface-hub==0.17.3 | ||
# via | ||
# sentence-transformers | ||
# tokenizers | ||
# transformers | ||
idna==3.4 | ||
# via | ||
# -c requirements/base.txt | ||
# anyio | ||
# requests | ||
# yarl | ||
jinja2==3.1.2 | ||
# via torch | ||
joblib==1.3.2 | ||
# via | ||
# -c requirements/base.txt | ||
# nltk | ||
# scikit-learn | ||
jsonpatch==1.33 | ||
# via langchain | ||
jsonpointer==2.4 | ||
# via jsonpatch | ||
langchain==0.0.317 | ||
# via -r requirements/embed-huggingface.in | ||
langsmith==0.0.46 | ||
# via langchain | ||
markupsafe==2.1.3 | ||
# via jinja2 | ||
marshmallow==3.20.1 | ||
# via | ||
# -c requirements/base.txt | ||
# dataclasses-json | ||
mpmath==1.3.0 | ||
# via sympy | ||
multidict==6.0.4 | ||
# via | ||
# aiohttp | ||
# yarl | ||
mypy-extensions==1.0.0 | ||
# via | ||
# -c requirements/base.txt | ||
# typing-inspect | ||
networkx==3.1 | ||
# via torch | ||
nltk==3.8.1 | ||
# via | ||
# -c requirements/base.txt | ||
# sentence-transformers | ||
numpy==1.24.4 | ||
# via | ||
# -c requirements/base.txt | ||
# -c requirements/constraints.in | ||
# langchain | ||
# scikit-learn | ||
# scipy | ||
# sentence-transformers | ||
# torchvision | ||
# transformers | ||
packaging==23.2 | ||
# via | ||
# -c requirements/base.txt | ||
# huggingface-hub | ||
# marshmallow | ||
# transformers | ||
pillow==10.1.0 | ||
# via torchvision | ||
pydantic==1.10.13 | ||
# via | ||
# -c requirements/constraints.in | ||
# langchain | ||
# langsmith | ||
pyyaml==6.0.1 | ||
# via | ||
# huggingface-hub | ||
# langchain | ||
# transformers | ||
regex==2023.10.3 | ||
# via | ||
# -c requirements/base.txt | ||
# nltk | ||
# transformers | ||
requests==2.31.0 | ||
# via | ||
# -c requirements/base.txt | ||
# huggingface-hub | ||
# langchain | ||
# langsmith | ||
# torchvision | ||
# transformers | ||
safetensors==0.3.2 | ||
# via | ||
# -c requirements/constraints.in | ||
# transformers | ||
scikit-learn==1.3.1 | ||
# via sentence-transformers | ||
scipy==1.10.1 | ||
# via | ||
# -c requirements/constraints.in | ||
# scikit-learn | ||
# sentence-transformers | ||
sentence-transformers==2.2.2 | ||
# via -r requirements/embed-huggingface.in | ||
sentencepiece==0.1.99 | ||
# via sentence-transformers | ||
sniffio==1.3.0 | ||
# via anyio | ||
sqlalchemy==2.0.22 | ||
# via langchain | ||
sympy==1.12 | ||
# via torch | ||
tenacity==8.2.3 | ||
# via langchain | ||
threadpoolctl==3.2.0 | ||
# via scikit-learn | ||
tokenizers==0.14.1 | ||
# via transformers | ||
torch==2.1.0 | ||
# via | ||
# -c requirements/constraints.in | ||
# sentence-transformers | ||
# torchvision | ||
torchvision==0.16.0 | ||
# via sentence-transformers | ||
tqdm==4.66.1 | ||
# via | ||
# -c requirements/base.txt | ||
# huggingface-hub | ||
# nltk | ||
# sentence-transformers | ||
# transformers | ||
transformers==4.34.1 | ||
# via sentence-transformers | ||
typing-extensions==4.8.0 | ||
# via | ||
# -c requirements/base.txt | ||
# huggingface-hub | ||
# pydantic | ||
# sqlalchemy | ||
# torch | ||
# typing-inspect | ||
typing-inspect==0.9.0 | ||
# via | ||
# -c requirements/base.txt | ||
# dataclasses-json | ||
urllib3==1.26.18 | ||
# via | ||
# -c requirements/base.txt | ||
# -c requirements/constraints.in | ||
# requests | ||
yarl==1.9.2 | ||
# via aiohttp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from unstructured.documents.elements import Text | ||
from unstructured.embed.huggingface import HuggingFaceEmbeddingEncoder | ||
|
||
|
||
def test_embed_documents_does_not_break_element_to_dict(mocker):
    """embed_documents must leave elements serializable via to_dict()."""
    # Stub client whose embed_documents yields one fake embedding per element.
    fake_client = mocker.MagicMock()
    fake_client.embed_documents.return_value = [1, 2]

    # Have the encoder use the stub instead of building a real HF client.
    mocker.patch.object(
        HuggingFaceEmbeddingEncoder,
        "get_huggingface_client",
        return_value=fake_client,
    )

    texts = ["This is sentence 1", "This is sentence 2"]
    encoder = HuggingFaceEmbeddingEncoder()
    embedded = encoder.embed_documents(elements=[Text(text) for text in texts])

    assert len(embedded) == 2
    for element, expected in zip(embedded, texts):
        assert element.to_dict()["text"] == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
|
||
from unstructured.documents.elements import ( | ||
Element, | ||
) | ||
from unstructured.embed.interfaces import BaseEmbeddingEncoder | ||
from unstructured.ingest.error import EmbeddingEncoderConnectionError | ||
from unstructured.utils import requires_dependencies | ||
|
||
|
||
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embeds elements via a langchain HuggingFace (sentence-transformers) client."""

    def __init__(
        self,
        model_name: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs: Optional[dict] = None,
        encode_kwargs: Optional[dict] = None,
        cache_folder: Optional[str] = None,
    ):
        """Configure the encoder and eagerly build the embedding client.

        Args:
            model_name: HuggingFace model identifier to load.
            model_kwargs: kwargs forwarded to the model constructor;
                defaults to {"device": "cpu"}.
            encode_kwargs: kwargs forwarded to encode();
                defaults to {"normalize_embeddings": False}.
            cache_folder: optional path where downloaded models are cached.
        """
        # BUG FIX: the defaults were mutable dict literals shared across all
        # instances; any caller mutating them would affect every encoder.
        # They are now created per instance from a None sentinel.
        self.model_name = model_name
        self.model_kwargs = {"device": "cpu"} if model_kwargs is None else model_kwargs
        self.encode_kwargs = (
            {"normalize_embeddings": False} if encode_kwargs is None else encode_kwargs
        )
        self.cache_folder = cache_folder

        self.initialize()

    def initialize(self):
        """Creates a langchain HuggingFace object to embed elements."""
        self.hf = self.get_huggingface_client()

    def num_of_dimensions(self):
        """Shape of the embedding produced for the sample query."""
        return np.shape(self.examplary_embedding)

    def is_unit_vector(self):
        """True if the sample embedding has (approximately) unit L2 norm."""
        return np.isclose(np.linalg.norm(self.examplary_embedding), 1.0)

    def embed_query(self, query):
        """Embed a single query; any object is str()-coerced first."""
        return self.hf.embed_query(str(query))

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        """Embed each element's text and attach the vectors to the elements."""
        embeddings = self.hf.embed_documents([str(e) for e in elements])
        return self._add_embeddings_to_elements(elements, embeddings)

    def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
        """Attach embeddings[i] to elements[i] in place and return the elements."""
        assert len(elements) == len(embeddings)
        # NOTE: the original accumulated a parallel `elements_w_embedding` list
        # and then returned `elements` anyway — the dead list is removed.
        for element, embedding in zip(elements, embeddings):
            element.embeddings = embedding
        return elements

    @EmbeddingEncoderConnectionError.wrap
    @requires_dependencies(
        ["langchain", "sentence_transformers"],
        extras="embed-huggingface",
    )
    def get_huggingface_client(self):
        """Creates (and caches) a langchain HuggingFace client to embed elements."""
        if hasattr(self, "hf_client"):
            return self.hf_client

        from langchain.embeddings import HuggingFaceEmbeddings

        hf_client = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            cache_folder=self.cache_folder,
        )
        # Embed a sample query so num_of_dimensions()/is_unit_vector() work.
        self.examplary_embedding = hf_client.embed_query("Q")
        # BUG FIX: the original never assigned self.hf_client, so the
        # hasattr() cache check above could never hit and a fresh client
        # (plus a sample embedding) was rebuilt on every call.
        self.hf_client = hf_client
        return hf_client