Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeParrot] Near-deduplication with jaccard similarity #17054

Merged
merged 38 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7842832
deduplication draft
May 2, 2022
c24b32c
update style
May 2, 2022
d0d6fec
update style test
May 5, 2022
d489572
dummy test main
May 5, 2022
2dab7a8
rename modules
May 6, 2022
28c800c
rename functions
May 6, 2022
3ac2967
return extremes in deduplicate_clusters
May 6, 2022
b60d265
update style
May 6, 2022
22e626f
cast str for gzip
May 6, 2022
a036ca2
update doc string
May 7, 2022
6ee984d
time processing
May 7, 2022
32306d2
use dataset map to compute minhash
May 7, 2022
989959a
fill value for short token
May 9, 2022
0bc110c
remove da map method
May 11, 2022
cac2308
update style
May 11, 2022
8280787
use share object to multiprocess
May 14, 2022
15821dc
update style
May 14, 2022
e71a3f4
use f-string and minor fix
liyongsea May 20, 2022
a9312da
Merge branch 'main' into codeparrot_deduplication
liyongsea May 20, 2022
0545149
update style
May 20, 2022
be6faa9
use module parameters
May 20, 2022
9458dd3
change ds_dedup to ds_filter
May 20, 2022
a1ed605
save ds_dedup
May 20, 2022
0478fdc
mv test to script tests
May 22, 2022
eddf6c2
make jaccard threshold a parameter of deduplicate_dataset
May 22, 2022
206dbec
update style
May 22, 2022
4495b19
Merge branch 'main' into codeparrot_deduplication
May 23, 2022
ed5bf2b
add doc strings
May 23, 2022
b5dd2eb
update style
May 23, 2022
9161f60
add doc string for DuplicationIndex
May 23, 2022
c08bebb
save files into data dir
May 23, 2022
dd9bcf7
update readme
May 23, 2022
6431ba5
Update examples/research_projects/codeparrot/README.md
liyongsea May 24, 2022
319206d
make near deduplication optional
May 24, 2022
de1491d
move near deduplication in README
May 24, 2022
192a0d8
Update examples/research_projects/codeparrot/README.md
liyongsea Jun 16, 2022
7f8be34
Merge branch 'main' into codeparrot_deduplication
Jun 16, 2022
6329f1f
use f string
Jun 16, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/research_projects/codeparrot/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The source of the dataset is the GitHub dump available on Google's [BigQuery](ht
The raw dataset contains many duplicates. We deduplicated and filtered the dataset using the heuristics proposed in OpenAI's Codex [paper](https://arxiv.org/abs/2107.03374) and some new ones:

- exact deduplication using each file's hash
- near deduplication using MinHash and Jaccard similarity. MinHash with a Jaccard threshold (default=0.85) is first used to create duplicate clusters. Then these clusters are then reduced to unique files based on the exact Jaccard similarity. See `deduplicate_dataset` in `minhash_deduplication.py` for a detailed description.
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- fraction of alphanumeric characters < 0.25
Expand Down
4 changes: 3 additions & 1 deletion examples/research_projects/codeparrot/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ wandb==0.12.0
tensorboard==2.6.0
torch==1.11.0
huggingface-hub==0.1.0
git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
datasketch==1.5.7
dpu_utils
6 changes: 6 additions & 0 deletions examples/research_projects/codeparrot/scripts/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ class PreprocessingArguments:
default="lvwerra/codeparrot",
metadata={"help": "Name or path to the tokenizer."},
)
near_deduplication: Optional[bool] = field(
default=False, metadata={"help": "If True, near-duplicate samples are removed."}
)
jaccard_threshold: Optional[float] = field(
default=0.85, metadata={"help": "Jaccard threshold for near-duplicate samples."}
)


@dataclass
Expand Down
270 changes: 270 additions & 0 deletions examples/research_projects/codeparrot/scripts/minhash_deduplication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import json
import multiprocessing as mp
import re
from collections import defaultdict
from functools import partial
from typing import Dict, List, Optional, Set, Tuple, Type

from datasets import Dataset
from tqdm import tqdm

from datasketch import MinHash, MinHashLSH
from dpu_utils.utils.iterators import ThreadedIterator


NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256


def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
"""Compute the MinHash of a code snippet."""
if len(tokens) < MIN_NUM_TOKENS:
return None
min_hash = MinHash(num_perm=NUM_PERM)
for token in set(tokens):
min_hash.update(token.encode())
return min_hash


def get_tokens(code: str) -> Set[str]:
"""Tokenize a code snippet."""
return set([t for t in NON_ALPHA.split(code) if len(t.strip()) > 0])


class DuplicationIndex:
def __init__(
self,
*,
duplication_jaccard_threshold: float = 0.85,
):
self._duplication_jaccard_threshold = duplication_jaccard_threshold
self._num_perm = NUM_PERM
self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)

self._duplicate_clusters = defaultdict(set)

def add(self, code_key: Tuple, min_hash: MinHash) -> None:
"""Add a key to _index (MinHashLSH)
the min_hash is used to query closest matches based on the jaccard_threshold.
The new key is either added to a existing cluster of one close match,
or a new cluster is created. The clusters created in this way, depend on the order of add.

Args:
code_key (Tuple of (index, repo_name, path)):
Theoritically any hasbale key. Here we use a tuple to retrieve the information later.
min_hash: MinHash of the code_key.
"""
close_duplicates = self._index.query(min_hash)
if code_key in self._index.keys:
print(f"Duplicate key {code_key}")
return

self._index.insert(code_key, min_hash)
if len(close_duplicates) > 0:

for base_duplicate in close_duplicates:
if base_duplicate in self._duplicate_clusters:
self._duplicate_clusters[base_duplicate].add(code_key)
break
else:
self._duplicate_clusters[close_duplicates[0]].add(code_key)

def get_duplicate_clusters(self) -> List[List[Dict]]:
"""Export the duplicate clusters.
For each cluster, the first element is the base element of the cluster.
The base element has an estimation jaccard similarity higher than the threshold with all the other elements.

Returns:
duplicate_clusters (List[List[Dict]]):
List of duplicate clusters.
"""
duplicate_clusters = []
for base, duplicates in self._duplicate_clusters.items():
cluster = [base] + list(duplicates)
# reformat the cluster to be a list of dict
cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
loubnabnl marked this conversation as resolved.
Show resolved Hide resolved
duplicate_clusters.append(cluster)
return duplicate_clusters

def save(self, filepath) -> None:
duplicate_clusters = self.get_duplicate_clusters()
with open(filepath, "w") as f:
json.dump(duplicate_clusters, f)


def _compute_min_hash(element):
index, data = element
min_hash = get_min_hash([t for t in NON_ALPHA.split(data["content"]) if len(t.strip()) > 0])
if min_hash is not None:
return (index, data["repo_name"], data["path"]), min_hash


def minhash_iter(dataset_iterator: Type[Dataset]):
with mp.Pool() as pool:
for data in pool.imap_unordered(
_compute_min_hash,
ThreadedIterator(dataset_iterator, max_queue_size=10000),
chunksize=100,
):
if data is not None:
yield data


def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
"""Find duplicate clusters in the dataset in two steps:
1. Compute MinHash for each code snippet. MinHash is a tool for fast jaccard similarity estimation.
This step is computed using an asynchronous multiprocessing pool, minhash_iter
2. Find duplicate clusters. The computed MinHash is added sequentially to the DuplicationIndex.
This step cannot be parallelized. So using asynchronous thread in the previous step helps to speed up the process.
"""
di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
di.add(filename, min_hash)

# Returns a List[Cluster] where Cluster is List[str] with the filenames.
return di.get_duplicate_clusters()


def jaccard_similarity(code1: str, code2: str) -> float:
"""Compute the Jaccard similarity of two code snippets."""
tokens1 = get_tokens(code1)
tokens2 = get_tokens(code2)
return len(tokens1 & tokens2) / len(tokens1 | tokens2)


_shared_dataset = None


def _find_cluster_extremes_shared(cluster, jaccard_threshold):
"""Find a reduced cluster such that each code in the origin cluster is similar to at least one code in the reduced cluster.
Two codes are similar if their Jaccard similarity is above the threshold.

Args:
cluster (List[dict]):
cluster is a list of dict, each dict contains the following keys:
- base_index
- repo_name
- path
This is a typical output of DuplicationIndex.get_duplicate_clusters()
jaccard_threshold (float):
threshold for Jaccard similarity.
Two codes are similar if their Jaccard similarity is above the threshold.

Returns:
extremes (List[dict]):
A reduced representation of the cluster. The field copies is added to each dict.
The copies field indicates the number of similar codes in the cluster for a extreme.
"""
extremes = []
for element1 in cluster:
code1 = _shared_dataset[element1["base_index"]]["content"]
for element2 in extremes:
code2 = _shared_dataset[element2["base_index"]]["content"]
if jaccard_similarity(code1, code2) >= jaccard_threshold:
element2["copies"] += 1
break
else:
element1["copies"] = 1
extremes.append(element1)
return extremes


def find_extremes(cluster_list, dataset, jaccard_threshold):
"""Call the _find_cluster_extremes_shared function in a parallel fashion.

Args:
cluster_list (List[List[Dict]]):
each cluster is a list of dicts with the key base_index,
referring to the index of the base code in the dataset.
dataset (Type[Dataset]):
dataset is used to access the content of the code snippets,
using the base_index from the cluster_list.
dataset is shared between all the processes using a glabal variable (any other way to share the dataset?),
otherwise the multi processing is not speeded up.
jaccard_threshold (float):
the threshold for the jaccard similarity. The default value is 0.85

Returns:
extremes_list (List[Dict]):
Each cluster is reduced to extremes.
See _find_cluster_extremes_shared for the definition of extremes.
"""
global _shared_dataset
_shared_dataset = dataset
extremes_list = []
f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
with mp.Pool() as pool:
for extremes in tqdm(
pool.imap_unordered(
f,
cluster_list,
),
total=len(cluster_list),
):
extremes_list.append(extremes)
return extremes_list


def deduplicate_dataset(
dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
"""Deduplicate the dataset using minhash and jaccard similarity.
This function first generate duplicate clusters, then each cluster
is reduced to the extremes that are similar to the other elements in the cluster.
Codes are called similar if their Jaccard similarity is greater than jaccard_threshold (0.85 default).

Args:
dataset (Type[Dataset]):
The dataset to deduplicate.
jaccard_threshold (float, default=0.85):
jaccard threshold to determine if two codes are similar

Returns:
ds_dedup (Type[Dataset]):
The deduplicated dataset.
duplicate_clusters (List[List[Dict]]):
The list of duplicate clusters.
Each cluster is a list of dicts with the following keys:
- base_index : int
The index of the code in the original dataset.
- repo_name : str
- path : str
- copies : int
The number of copies of the code in the cluster. (find_cluster_extremes)
- is_extreme : bool
Whether the code is an extreme in the cluster.
All the codes in the cluster are removed from the dataset except the extremes.

Example:
>>> from datasets import load_dataset
>>> from minhash_deduplication import deduplicate_dataset
>>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
>>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
"""
duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
duplicate_indices = set(x["base_index"] for cluster in duplicate_clusters for x in cluster)
extreme_dict = {}
extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
for extremes in extremes_clusters:
for element in extremes:
extreme_dict[element["base_index"]] = element
remove_indices = duplicate_indices - set(extreme_dict.keys())
ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

# update duplicate_clusters
for cluster in duplicate_clusters:
for element in cluster:
element["is_extreme"] = element["base_index"] in extreme_dict
if element["is_extreme"]:
element["copies"] = extreme_dict[element["base_index"]]["copies"]

print(f"Original dataset size: {len(dataset)}")
print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
print(f"Files in duplicate cluster: {len(duplicate_indices)}")
print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
print(f"Filtered dataset size: {len(ds_filter)}")

return ds_filter, duplicate_clusters
28 changes: 24 additions & 4 deletions examples/research_projects/codeparrot/scripts/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import gzip
import hashlib
import json
import multiprocessing
import os
import shutil
import time
from pathlib import Path

import numpy as np
from datasets import load_dataset

from arguments import PreprocessingArguments
from minhash_deduplication import deduplicate_dataset
from transformers import AutoTokenizer, HfArgumentParser


Expand Down Expand Up @@ -146,7 +149,7 @@ def filter(example, uniques, args):
def compress_file(file_path):
"""Compress a file with g-zip."""
with open(file_path, "rb") as f_in:
with gzip.open(file_path + ".gz", "wb", compresslevel=6) as f_out:
with gzip.open(str(file_path) + ".gz", "wb", compresslevel=6) as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(file_path)

Expand Down Expand Up @@ -179,12 +182,29 @@ def compress_file(file_path):
print(f"Time to filter dataset: {time.time()-t_start:.2f}")
print(f"Size of filtered dataset: {len(ds_filter)}")

# Deduplicate with minhash and jaccard similarity
if args.near_deduplication:
t_start = time.time()
ds_filter, duplicate_clusters = deduplicate_dataset(ds_filter, args.jaccard_threshold)
print(f"Time to deduplicate dataset: {time.time()-t_start:.2f}")
print(f"Size of deduplicate dataset: {len(ds_filter)}")

# Save data in batches of samples_per_file
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True)

# save duplicate_clusters in the output_dir as artifacts
# not sure it is the right place the save it
if args.near_deduplication:
with open(output_dir / "duplicate_clusters.json", "w") as f:
json.dump(duplicate_clusters, f)

data_dir = output_dir / "data"
data_dir.mkdir(exist_ok=True)

t_start = time.time()
for file_number, index in enumerate(range(0, len(ds_filter), args.samples_per_file)):
file_path = f"{args.output_dir}/file-{file_number+1:012}.json"
file_path = str(data_dir / f"file-{file_number+1:012}.json")
end_index = min(len(ds_filter), index + args.samples_per_file)
ds_filter.select(list(range(index, end_index))).to_json(file_path)
compress_file(file_path)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest import TestCase

from datasets import Dataset

from minhash_deduplication import deduplicate_dataset, make_duplicate_clusters


def get_dataset():
data_dict = {
"repo_name": ["test_repo1", "test_repo2", "test_repo3"],
"path": ["test_1.py", "test_2.py", "unit_test.py"],
"content": ["a " * 20, "a " * 30, "b " * 7],
}
dataset = Dataset.from_dict(data_dict)
return dataset


class MakeDuplicateClustersTest(TestCase):
def test_make_duplicate_clusters(self):
ds = get_dataset()
duplicate_clusters = make_duplicate_clusters(ds, 0.85)
self.assertEqual(len(duplicate_clusters[0]), 2)

def test_deduplicate_dataset(self):
ds = get_dataset()
ds_filter, duplicate_clusters = deduplicate_dataset(ds)
self.assertEqual(len(ds_filter), 2)
print(duplicate_clusters)
self.assertEqual(duplicate_clusters[0][0]["copies"], 2)
self.assertEqual(duplicate_clusters[0][0]["is_extreme"], True)