
Merge branch 'main' into klaijan/ci-text-extraction
Klaijan authored Oct 19, 2023
2 parents 389f0e0 + 0063574 commit da019d1
Showing 4 changed files with 310 additions and 0 deletions.
6 changes: 6 additions & 0 deletions requirements/embed-huggingface.in
@@ -0,0 +1,6 @@
-c constraints.in
-c base.txt

huggingface
langchain
sentence_transformers
207 changes: 207 additions & 0 deletions requirements/embed-huggingface.txt
@@ -0,0 +1,207 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --constraint=requirements/constraints.in requirements/embed-huggingface.in
#
aiohttp==3.8.6
# via langchain
aiosignal==1.3.1
# via aiohttp
anyio==3.7.1
# via
# -c requirements/constraints.in
# langchain
async-timeout==4.0.3
# via
# aiohttp
# langchain
attrs==23.1.0
# via aiohttp
certifi==2023.7.22
# via
# -c requirements/base.txt
# -c requirements/constraints.in
# requests
charset-normalizer==3.3.0
# via
# -c requirements/base.txt
# aiohttp
# requests
click==8.1.7
# via
# -c requirements/base.txt
# nltk
dataclasses-json==0.6.1
# via
# -c requirements/base.txt
# langchain
exceptiongroup==1.1.3
# via anyio
filelock==3.12.4
# via
# huggingface-hub
# torch
# transformers
frozenlist==1.4.0
# via
# aiohttp
# aiosignal
fsspec==2023.9.1
# via
# -c requirements/constraints.in
# huggingface-hub
# torch
huggingface==0.0.1
# via -r requirements/embed-huggingface.in
huggingface-hub==0.17.3
# via
# sentence-transformers
# tokenizers
# transformers
idna==3.4
# via
# -c requirements/base.txt
# anyio
# requests
# yarl
jinja2==3.1.2
# via torch
joblib==1.3.2
# via
# -c requirements/base.txt
# nltk
# scikit-learn
jsonpatch==1.33
# via langchain
jsonpointer==2.4
# via jsonpatch
langchain==0.0.317
# via -r requirements/embed-huggingface.in
langsmith==0.0.46
# via langchain
markupsafe==2.1.3
# via jinja2
marshmallow==3.20.1
# via
# -c requirements/base.txt
# dataclasses-json
mpmath==1.3.0
# via sympy
multidict==6.0.4
# via
# aiohttp
# yarl
mypy-extensions==1.0.0
# via
# -c requirements/base.txt
# typing-inspect
networkx==3.1
# via torch
nltk==3.8.1
# via
# -c requirements/base.txt
# sentence-transformers
numpy==1.24.4
# via
# -c requirements/base.txt
# -c requirements/constraints.in
# langchain
# scikit-learn
# scipy
# sentence-transformers
# torchvision
# transformers
packaging==23.2
# via
# -c requirements/base.txt
# huggingface-hub
# marshmallow
# transformers
pillow==10.1.0
# via torchvision
pydantic==1.10.13
# via
# -c requirements/constraints.in
# langchain
# langsmith
pyyaml==6.0.1
# via
# huggingface-hub
# langchain
# transformers
regex==2023.10.3
# via
# -c requirements/base.txt
# nltk
# transformers
requests==2.31.0
# via
# -c requirements/base.txt
# huggingface-hub
# langchain
# langsmith
# torchvision
# transformers
safetensors==0.3.2
# via
# -c requirements/constraints.in
# transformers
scikit-learn==1.3.1
# via sentence-transformers
scipy==1.10.1
# via
# -c requirements/constraints.in
# scikit-learn
# sentence-transformers
sentence-transformers==2.2.2
# via -r requirements/embed-huggingface.in
sentencepiece==0.1.99
# via sentence-transformers
sniffio==1.3.0
# via anyio
sqlalchemy==2.0.22
# via langchain
sympy==1.12
# via torch
tenacity==8.2.3
# via langchain
threadpoolctl==3.2.0
# via scikit-learn
tokenizers==0.14.1
# via transformers
torch==2.1.0
# via
# -c requirements/constraints.in
# sentence-transformers
# torchvision
torchvision==0.16.0
# via sentence-transformers
tqdm==4.66.1
# via
# -c requirements/base.txt
# huggingface-hub
# nltk
# sentence-transformers
# transformers
transformers==4.34.1
# via sentence-transformers
typing-extensions==4.8.0
# via
# -c requirements/base.txt
# huggingface-hub
# pydantic
# sqlalchemy
# torch
# typing-inspect
typing-inspect==0.9.0
# via
# -c requirements/base.txt
# dataclasses-json
urllib3==1.26.18
# via
# -c requirements/base.txt
# -c requirements/constraints.in
# requests
yarl==1.9.2
# via aiohttp
23 changes: 23 additions & 0 deletions test_unstructured/embed/test_embed_huggingface.py
@@ -0,0 +1,23 @@
from unstructured.documents.elements import Text
from unstructured.embed.huggingface import HuggingFaceEmbeddingEncoder


def test_embed_documents_does_not_break_element_to_dict(mocker):
    # Mocked client with the desired behavior for embed_documents
    mock_client = mocker.MagicMock()
    mock_client.embed_documents.return_value = [1, 2]

    # Mock get_huggingface_client to return our mock_client
    mocker.patch.object(
        HuggingFaceEmbeddingEncoder,
        "get_huggingface_client",
        return_value=mock_client,
    )

    encoder = HuggingFaceEmbeddingEncoder()
    elements = encoder.embed_documents(
        elements=[Text("This is sentence 1"), Text("This is sentence 2")],
    )
    assert len(elements) == 2
    assert elements[0].to_dict()["text"] == "This is sentence 1"
    assert elements[1].to_dict()["text"] == "This is sentence 2"
74 changes: 74 additions & 0 deletions unstructured/embed/huggingface.py
@@ -0,0 +1,74 @@
from typing import List, Optional

import numpy as np

from unstructured.documents.elements import (
    Element,
)
from unstructured.embed.interfaces import BaseEmbeddingEncoder
from unstructured.ingest.error import EmbeddingEncoderConnectionError
from unstructured.utils import requires_dependencies


class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
    def __init__(
        self,
        model_name: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs: Optional[dict] = {"device": "cpu"},
        encode_kwargs: Optional[dict] = {"normalize_embeddings": False},
        cache_folder: Optional[str] = None,
    ):
        self.model_name = model_name
        self.model_kwargs = model_kwargs
        self.encode_kwargs = encode_kwargs
        self.cache_folder = cache_folder

        self.initialize()

    def initialize(self):
        """Creates a langchain HuggingFace object to embed elements."""
        self.hf = self.get_huggingface_client()

    def num_of_dimensions(self):
        return np.shape(self.examplary_embedding)

    def is_unit_vector(self):
        return np.isclose(np.linalg.norm(self.examplary_embedding), 1.0)

    def embed_query(self, query):
        return self.hf.embed_query(str(query))

    def embed_documents(self, elements: List[Element]) -> List[Element]:
        embeddings = self.hf.embed_documents([str(e) for e in elements])
        elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
        return elements_with_embeddings

    def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
        assert len(elements) == len(embeddings)
        elements_w_embedding = []

        for i, element in enumerate(elements):
            element.embeddings = embeddings[i]
            elements_w_embedding.append(element)
        return elements_w_embedding

    @EmbeddingEncoderConnectionError.wrap
    @requires_dependencies(
        ["langchain", "sentence_transformers"],
        extras="embed-huggingface",
    )
    def get_huggingface_client(self):
        """Creates a langchain HuggingFace Python client to embed elements."""
        if hasattr(self, "hf_client"):
            return self.hf_client

        from langchain.embeddings import HuggingFaceEmbeddings

        hf_client = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            cache_folder=self.cache_folder,
        )
        self.examplary_embedding = hf_client.embed_query("Q")
        return hf_client
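
A minimal usage sketch of the class above, assuming the embed-huggingface extra (huggingface, langchain, sentence_transformers) is installed; the sample texts and printed fields are illustrative only, not part of this commit:

from unstructured.documents.elements import Text
from unstructured.embed.huggingface import HuggingFaceEmbeddingEncoder

# Uses the class defaults: sentence-transformers/all-MiniLM-L6-v2 on CPU.
# Instantiation downloads the model and embeds one exemplary query ("Q")
# so the encoder can report its embedding dimensionality.
encoder = HuggingFaceEmbeddingEncoder()

# Embed a list of elements; each element comes back with .embeddings set.
elements = encoder.embed_documents(
    elements=[Text("This is sentence 1"), Text("This is sentence 2")],
)
for element in elements:
    print(element.to_dict()["text"], len(element.embeddings))

# Embed a single query string and check basic properties of the space.
query_vector = encoder.embed_query("This is a query")
print(len(query_vector), encoder.num_of_dimensions(), encoder.is_unit_vector())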
