forked from ml6team/fondant
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LLM pipeline] MinHash generation for deduplication (ml6team#295)
This component generates MinHashes of text. The MinHash similarity will be used to determine duplicated text passages. --------- Co-authored-by: NielsRogge <[email protected]> Co-authored-by: Robbe Sneyders <[email protected]>
- Loading branch information
1 parent
5a7dce6
commit d4b9775
Showing
5 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
FROM --platform=linux/amd64 python:3.8-slim

# System dependencies
# git is needed because Fondant is pip-installed from a git URL further down.
# The apt package lists are removed in the same layer so they do not end up
# in the final image.
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install git -y && \
    rm -rf /var/lib/apt/lists/*

# Install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
# The requirement is quoted so the shell never treats "[aws,azure,gcp]" as a
# glob pattern; ${FONDANT_VERSION} is still expanded inside double quotes.
RUN pip3 install --no-cache-dir "fondant[aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}"

# Set the working directory to the component folder
WORKDIR /component/src

# Copy over src-files
COPY src/ .

ENTRYPOINT ["python", "main.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Component specification for the MinHash generator.
name: MinHash generator
description: A component that generates minhashes of text.
image: ghcr.io/ml6team/minhash_generator:latest

# Input: the `data` field of the `text` subset, i.e. the raw text to hash.
consumes:
  text:
    fields:
      data:
        type: string

# Output: a `minhash` field on the `text` subset, an array of unsigned
# 64-bit integers per row.
produces:
  text:
    fields:
      minhash:
        type: array
        items:
          type: uint64
# Arguments handed to the component's `setup` method.
args:
  shingle_ngram_size:
    description: Define size of ngram used for the shingle generation
    type: int
    default: 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
datasketch==1.5.9 | ||
nltk==3.8.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""A component that generates minhashes of text.""" | ||
import logging | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from datasketch import MinHash | ||
from fondant.component import PandasTransformComponent | ||
from nltk.util import ngrams | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def create_shingles(text: str, shingle_ngram_size: int = 3) -> list:
    """Create the word shingles that will be used for the hash generation.

    Args:
        text: Text to shingle; it is split on whitespace into words.
        shingle_ngram_size: Number of consecutive words per shingle.
            Defaults to 3, the value that used to be hard-coded here.

    Returns:
        List of word tuples, one tuple per shingle. The list is empty when the
        text has fewer than `shingle_ngram_size` words.
    """
    words = text.split()
    return list(ngrams(words, shingle_ngram_size))


def compute_minhash(shingles: list) -> np.ndarray:
    """Calculate the minhash of a list of shingles.

    Args:
        shingles: Word tuples as returned by `create_shingles`.

    Returns:
        The hash values of a `datasketch.MinHash` (default configuration)
        that was updated with every shingle.
    """
    minhash = MinHash()

    # Each shingle is re-joined into a single string and fed to the MinHash as
    # UTF-8 bytes.
    # NOTE: an empty shingle list leaves the MinHash untouched, so every text
    # without shingles ends up with the same hash values.
    for shingle in shingles:
        minhash.update(" ".join(shingle).encode("utf-8"))

    return minhash.hashvalues


class MinHashGeneratorComponent(PandasTransformComponent):
    """Component that generates minhashes of text."""

    # Fallback used when `transform` is called before `setup`; same value as
    # the default declared for this argument in the component specification.
    shingle_ngram_size: int = 3

    def setup(self, *, shingle_ngram_size: int):
        """Setup component.

        Args:
            shingle_ngram_size: Defines size of ngram used for the shingle generation.

        Raises:
            ValueError: If `shingle_ngram_size` is smaller than 1.
        """
        if shingle_ngram_size < 1:
            msg = f"shingle_ngram_size must be at least 1, got {shingle_ngram_size}"
            raise ValueError(msg)
        self.shingle_ngram_size = shingle_ngram_size

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Generate minhash values of text.

        Args:
            dataframe: Pandas dataframe with a ("text", "data") column.

        Returns:
            The same dataframe with ("text", "shingles") and
            ("text", "minhash") columns added.
        """
        # Forward the configured ngram size; previously the argument was
        # stored by `setup` but shingles were always built with size 3.
        dataframe[("text", "shingles")] = dataframe[("text", "data")].apply(
            create_shingles,
            shingle_ngram_size=self.shingle_ngram_size,
        )
        dataframe[("text", "minhash")] = dataframe[("text", "shingles")].apply(
            compute_minhash,
        )

        return dataframe
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build the component through `from_args` and run it.
    MinHashGeneratorComponent.from_args().run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
"""Unit test for minhash generation component.""" | ||
import pandas as pd | ||
from fondant.component_spec import ComponentSpec | ||
|
||
from components.minhash_generator.src.main import MinHashGeneratorComponent | ||
|
||
|
||
def test_run_component_test():
    """Test MinHash generation."""
    # Given: Dataframe with text, containing one duplicate
    data = [
        {"data": "This is my first sentence"},
        {"data": "This is my first sentence"},
        {"data": "This is a different sentence"},
    ]

    dataframe = pd.concat({"text": pd.DataFrame(data)}, axis=1, names=["text", "data"])

    # When: The MinHash generator component processes the dataframe
    spec = ComponentSpec.from_file("../fondant_component.yaml")

    component = MinHashGeneratorComponent(
        spec,
        input_manifest_path="./dummy_input_manifest.json",
        output_manifest_path="./dummy_input_manifest.json",
        metadata={},
        user_arguments={},
    )

    # `transform` is called directly here, so `setup` is invoked explicitly
    # with the default ngram size from the component specification.
    component.setup(shingle_ngram_size=3)
    dataframe = component.transform(dataframe=dataframe)

    # Then: dataframe contains minhashes for each entry
    first = dataframe.loc[0]["text"]["minhash"]
    duplicate = dataframe.loc[1]["text"]["minhash"]
    different = dataframe.loc[2]["text"]["minhash"]

    # Identical texts must agree on EVERY hash value; the previous `any(...)`
    # passed as soon as a single position matched.
    assert (first == duplicate).all()
    # The third text shares no 3-word shingle with the first, so no hash
    # position is expected to match.
    assert not (first == different).any()