Component/generate embeddings (#520)
Besides:
- ignore vscode settings
- mypy retry

Co-authored-by: Philippe Moussalli <[email protected]>
1 parent 5cc8c24, commit 6b20f46
Showing 12 changed files with 627 additions and 8 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,9 @@ dist/ | |
# JetBrains IDE | ||
.idea/ | ||
|
||
#VSCode | ||
.vscode/ | ||
|
||
# Unit test reports | ||
TEST*.xml | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
FROM --platform=linux/amd64 python:3.8-slim as base | ||
|
||
# System dependencies | ||
RUN apt-get update && \ | ||
    apt-get upgrade -y && \ | ||
    apt-get install git -y | ||
|
||
# Install requirements | ||
COPY requirements.txt / | ||
RUN pip3 install --no-cache-dir -r requirements.txt | ||
|
||
# Install Fondant | ||
# This is split from other requirements to leverage caching | ||
ARG FONDANT_VERSION=main | ||
RUN pip3 install fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION} | ||
# Set the working directory to the component folder | ||
WORKDIR /component/src | ||
|
||
# Copy over src-files | ||
COPY src/ . | ||
|
||
ENTRYPOINT ["fondant", "execute", "main"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Generate embeddings | ||
|
||
### Description | ||
Component that generates embeddings of text passages. | ||
|
||
### Inputs / outputs | ||
|
||
**This component consumes:** | ||
|
||
- text | ||
  - data: string | ||
|
||
**This component produces:** | ||
|
||
- text | ||
  - data: string | ||
  - embedding: list<item: float> | ||
|
||
### Arguments | ||
|
||
The component takes the following arguments to alter its behavior: | ||
|
||
| argument | type | description | default | | ||
| -------- | ---- | ----------- | ------- | | ||
| model_provider | str | The provider of the model - corresponding to langchain embedding classes. Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai. | huggingface | | ||
| model | str | The model to generate embeddings from. Choose an available model name to pass to the model provider's langchain embedding class. | all-MiniLM-L6-v2 | | ||
| api_keys | dict | The API keys for the model provider, which are written to environment variables. Pass only the keys required by the model provider, or pass all keys you will ever need for convenience. The dictionary keys are used as the environment variable names, so name them as the model provider expects. | / | | ||
|
||
### Usage | ||
|
||
You can add this component to your pipeline using the following code: | ||
|
||
```python | ||
from fondant.pipeline import ComponentOp | ||
|
||
|
||
generate_embeddings_op = ComponentOp.from_registry( | ||
    name="generate_embeddings", | ||
    arguments={ | ||
        # Add arguments | ||
        # "model_provider": "huggingface", | ||
        # "model": "all-MiniLM-L6-v2", | ||
        # "api_keys": {}, | ||
    } | ||
) | ||
pipeline.add_op(generate_embeddings_op, dependencies=[...])  # Add previous component as dependency | ||
``` | ||
|
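Review note: the README asks callers to name the `api_keys` dictionary keys so that the provider can use them, but does not list the names. The sketch below shows one way to build that argument. The helper `build_api_keys` and the mapping `PROVIDER_ENV_VARS` are hypothetical and not part of this commit, and the environment variable names are assumptions based on what the langchain embedding classes conventionally read; check them against each provider's documentation.

```python
from typing import Dict, Optional

# Assumed mapping from the component's provider names to the environment
# variable each provider's langchain class conventionally looks up.
# These names are NOT taken from the commit; verify before relying on them.
PROVIDER_ENV_VARS: Dict[str, Optional[str]] = {
    "aleph_alpha": "ALEPH_ALPHA_API_KEY",
    "cohere": "COHERE_API_KEY",
    "huggingface": None,  # local sentence-transformers model, no key needed
    "openai": "OPENAI_API_KEY",
}


def build_api_keys(model_provider: str, secret: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical helper: build the `api_keys` argument for one provider."""
    if model_provider not in PROVIDER_ENV_VARS:
        raise ValueError(f"Unknown provider {model_provider}")
    env_var = PROVIDER_ENV_VARS[model_provider]
    if env_var is None:
        return {}  # nothing to export for providers that run locally
    if not secret:
        raise ValueError(f"{model_provider} requires a value for {env_var}")
    return {env_var: secret}


# Example: arguments for the component, using a placeholder secret.
arguments = {
    "model_provider": "openai",
    "model": "text-embedding-ada-002",
    "api_keys": build_api_keys("openai", "sk-placeholder"),
}
```

Keeping the mapping in one place makes a misspelled key fail early instead of surfacing later as an authentication error inside the component.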
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
name: Generate embeddings | ||
description: Component that generates embeddings of text passages. | ||
image: generate_embeddings:latest | ||
|
||
consumes: | ||
  text: | ||
    fields: | ||
      data: | ||
        type: string | ||
|
||
produces: | ||
  text: | ||
    fields: | ||
      data: | ||
        type: string | ||
      embedding: | ||
        type: array | ||
        items: | ||
          type: float32 | ||
|
||
args: | ||
  model_provider: | ||
    description: | | ||
      The provider of the model - corresponding to langchain embedding classes. | ||
      Currently the following providers are supported: aleph_alpha, cohere, huggingface, openai. | ||
    type: str | ||
    default: huggingface | ||
  model: | ||
    description: | | ||
      The model to generate embeddings from. | ||
      Choose an available model name to pass to the model provider's langchain embedding class. | ||
    type: str | ||
    default: all-MiniLM-L6-v2 | ||
  api_keys: | ||
    description: | | ||
      The API keys for the model provider, which are written to environment variables. | ||
      Pass only the keys required by the model provider, or pass all keys you will ever need for convenience. | ||
      The dictionary keys are used as the environment variable names, so name them as the model provider expects. | ||
    type: dict | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
aleph_alpha_client==3.5.1 | ||
cohere==4.27 | ||
langchain==0.0.313 | ||
openai==0.28.1 | ||
pandas==1.5.0 | ||
retry==0.9.2 | ||
sentence-transformers==2.2.2 | ||
tiktoken==0.5.1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import logging | ||
|
||
import pandas as pd | ||
from fondant.component import PandasTransformComponent | ||
from langchain.embeddings import ( | ||
    AlephAlphaAsymmetricSemanticEmbedding, | ||
    CohereEmbeddings, | ||
    HuggingFaceEmbeddings, | ||
    OpenAIEmbeddings, | ||
) | ||
from retry import retry | ||
from utils import to_env_vars | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GenerateEmbeddingsComponent(PandasTransformComponent): | ||
    def __init__( | ||
        self, | ||
        *_, | ||
        model_provider: str, | ||
        model: str, | ||
        api_keys: dict, | ||
    ): | ||
        self.model_provider = model_provider | ||
        self.model = model | ||
|
||
        to_env_vars(api_keys) | ||
|
||
    def get_embedding_model(self, model_provider, model: str): | ||
        # contains a first selection of embedding models | ||
        if model_provider == "aleph_alpha": | ||
            return AlephAlphaAsymmetricSemanticEmbedding(model=model) | ||
        if model_provider == "cohere": | ||
            return CohereEmbeddings(model=model) | ||
        if model_provider == "huggingface": | ||
            return HuggingFaceEmbeddings(model_name=model) | ||
        if model_provider == "openai": | ||
            return OpenAIEmbeddings(model=model) | ||
        msg = f"Unknown provider {model_provider}" | ||
        raise ValueError(msg) | ||
|
||
    @retry()  # no arguments: retry indefinitely on any exception, e.g. when the API call limit is reached | ||
    def get_embeddings_vectors(self, embedding_model, texts): | ||
        return embedding_model.embed_documents(texts.tolist()) | ||
|
||
    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: | ||
        embedding_model = self.get_embedding_model(self.model_provider, self.model) | ||
        dataframe[("text", "embedding")] = self.get_embeddings_vectors( | ||
            embedding_model, | ||
            dataframe[("text", "data")], | ||
        ) | ||
        return dataframe |
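Review note: `@retry()` without arguments keeps retrying on every exception with no pause, so a permanent failure such as an invalid API key would loop forever. The sketch below illustrates a bounded alternative with exponential backoff using only the standard library. The decorator `retry_with_backoff`, its parameters, and the `flaky_embed` stand-in are hypothetical names for illustration, not code from this commit.

```python
import functools
import time


def retry_with_backoff(tries=5, delay=1.0, backoff=2.0,
                       exceptions=(Exception,), sleep=time.sleep):
    """Hypothetical decorator: retry up to `tries` times, multiplying the wait by `backoff`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, tries + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == tries:
                        raise  # give up and surface the last error
                    sleep(wait)
                    wait *= backoff
        return wrapper
    return decorator


# Demonstration with a stand-in for an embedding call that fails twice.
calls = {"count": 0}
waits = []  # records the requested waits instead of actually sleeping


@retry_with_backoff(tries=4, delay=0.5, sleep=waits.append)
def flaky_embed(texts):
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("rate limit reached")
    return [[0.0] for _ in texts]


result = flaky_embed(["a", "b"])
```

The `retry` package pinned in `requirements.txt` also accepts `tries`, `delay`, and `backoff` keyword arguments, so a similar bound could likely be expressed directly on the existing decorator.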
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import os | ||
|
||
|
||
def to_env_vars(api_keys: dict): | ||
    for key, value in api_keys.items(): | ||
        os.environ[key] = value |
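Review note: `os.environ` only accepts string keys and values, so `to_env_vars` raises a `TypeError` partway through if any entry is not a string, leaving earlier keys already exported. A minimal sketch of a stricter variant follows; `to_env_vars_checked` is a hypothetical name, and validating everything before writing is a design choice of this sketch rather than something the commit does.

```python
import os
from typing import Mapping


def to_env_vars_checked(api_keys: Mapping[str, str]) -> None:
    """Hypothetical variant of to_env_vars that validates before exporting."""
    # First pass: reject non-string entries without touching the environment.
    for key, value in api_keys.items():
        if not isinstance(key, str) or not isinstance(value, str):
            msg = f"api_keys entries must map str to str, got {key!r}: {type(value).__name__}"
            raise TypeError(msg)
    # Second pass: export every key as an environment variable.
    for key, value in api_keys.items():
        os.environ[key] = value


# Example with a placeholder variable name and value.
to_env_vars_checked({"EXAMPLE_PROVIDER_API_KEY": "dummy"})
```

The error message reports the key and the value's type only, so a secret is never echoed into logs.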
53 changes: 53 additions & 0 deletions
components/generate_embeddings/tests/generate_embeddings_test.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""Unit test for generate embeddings component.""" | ||
import json | ||
from math import isclose | ||
|
||
import pandas as pd | ||
|
||
from components.generate_embeddings.src.main import GenerateEmbeddingsComponent | ||
|
||
|
||
def embeddings_close(a, b): | ||
    return all(isclose(x, y, abs_tol=1e-5) for x, y in zip(a, b)) | ||
|
||
|
||
def test_run_component_test(): | ||
    """Test generate embeddings component.""" | ||
    with open("lorem_300.txt", encoding="utf-8") as f: | ||
        lorem_300 = f.read() | ||
    with open("lorem_400.txt", encoding="utf-8") as f: | ||
        lorem_400 = f.read() | ||
|
||
    # Given: Dataframe with text | ||
    data = [ | ||
        {"data": "Hello World!!"}, | ||
        {"data": lorem_300}, | ||
        {"data": lorem_400}, | ||
    ] | ||
|
||
    DATA_LENGTH = 3 | ||
|
||
    dataframe = pd.concat({"text": pd.DataFrame(data)}, axis=1, names=["text", "data"]) | ||
|
||
    component = GenerateEmbeddingsComponent( | ||
        model_provider="huggingface", | ||
        model="all-MiniLM-L6-v2", | ||
        api_keys={}, | ||
    ) | ||
|
||
    dataframe = component.transform(dataframe=dataframe) | ||
|
||
    with open("hello_world_embedding.txt", encoding="utf-8") as f: | ||
        hello_world_embedding = f.read() | ||
        hello_world_embedding = json.loads(hello_world_embedding) | ||
|
||
    # Then: right embeddings are generated | ||
    assert len(dataframe) == DATA_LENGTH | ||
    assert embeddings_close( | ||
        dataframe.iloc[0]["text"]["embedding"], | ||
        hello_world_embedding, | ||
    ) | ||
    # Then: too long text is truncated and thus embeddings are the same | ||
    assert ( | ||
        dataframe.iloc[1]["text"]["embedding"] == dataframe.iloc[2]["text"]["embedding"] | ||
    ) |
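Review note: `embeddings_close` relies on `zip`, which stops at the shorter input, so an embedding compared against a truncated or empty reference vector would still pass. The sketch below adds a length check; `embeddings_close_strict` is a hypothetical name and is not part of this commit.

```python
from math import isclose
from typing import Sequence


def embeddings_close_strict(a: Sequence[float], b: Sequence[float],
                            abs_tol: float = 1e-5) -> bool:
    """Hypothetical stricter comparison: lengths must match before comparing values."""
    if len(a) != len(b):
        return False  # zip alone would silently ignore the extra elements
    return all(isclose(x, y, abs_tol=abs_tol) for x, y in zip(a, b))


# Small differences within the tolerance still count as equal.
same = embeddings_close_strict([0.1, 0.2], [0.1, 0.2000001])
# A shorter reference vector is rejected instead of passing by accident.
truncated = embeddings_close_strict([0.1], [0.1, 0.2])
```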