From d246290df943c7ca750e5e7997501b1fb412e59b Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 27 Mar 2024 16:44:20 -0700 Subject: [PATCH 1/2] Add batched decorator Signed-off-by: Ryan Wolf --- ...bleInformationIdentificationAndRemoval.rst | 8 +-- examples/classifier_filtering.py | 4 +- examples/find_pii_and_deidentify.py | 6 +- nemo_curator/filters/__init__.py | 7 +-- nemo_curator/filters/classifier_filter.py | 59 +++++-------------- nemo_curator/modifiers/pii_modifier.py | 10 ++-- nemo_curator/modules/filter.py | 19 +++--- nemo_curator/modules/modify.py | 6 +- nemo_curator/modules/task.py | 2 +- nemo_curator/pii/recognizers/__init__.py | 0 .../scripts/find_pii_and_deidentify.py | 6 +- nemo_curator/utils/decorators.py | 24 ++++++++ nemo_curator/utils/module_utils.py | 17 ++++++ tests/test_filters.py | 9 +-- tutorials/tinystories/main.py | 5 +- 15 files changed, 93 insertions(+), 89 deletions(-) create mode 100644 nemo_curator/pii/recognizers/__init__.py create mode 100644 nemo_curator/utils/decorators.py create mode 100644 nemo_curator/utils/module_utils.py diff --git a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst index 3f2ebcc67..6c19a2ea2 100644 --- a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst +++ b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst @@ -51,9 +51,9 @@ You could read, de-identify the dataset, and write it to an output directory usi from nemo_curator.utils.distributed_utils import read_data, write_to_disk, get_client from nemo_curator.utils.file_utils import get_batched_files from nemo_curator.modules.modify import Modify - from nemo_curator.modifiers.pii_modifier import PiiModifierBatched + from nemo_curator.modifiers.pii_modifier import PiiModifier - modifier = PiiModifierBatched( + modifier = PiiModifier( language="en", supported_entities=["PERSON", "EMAIL_ADDRESS"], anonymize_action="replace", @@ -70,7 +70,7 @@ You could read, de-identify the dataset, and write it to an output directory usi dataset = DocumentDataset(source_data) print(f"Dataset has {source_data.npartitions} partitions") - modify = Modify(modifier, batched=True) + modify = Modify(modifier) modified_dataset = modify(dataset) write_to_disk(modified_dataset.df, "output_directory", @@ -84,7 +84,7 @@ Let's walk through this code line by line. * ``for file_names in get_batched_files`` retrieves a batch of 32 documents from the `book_dataset` * ``source_data = read_data(file_names, file_type="jsonl", backend='pandas', add_filename=True)`` reads the data from all the files using Dask using Pandas as the backend. The ``add_filename`` argument ensures that the output files have the same filename as the input files. * ``dataset = DocumentDataset(source_data)`` creates an instance of ``DocumentDataset`` using the batch files. ``DocumentDataset`` is the standard format for text datasets in NeMo Curator. -* ``modify = Modify(modifier, batched=True)`` creates an instance of the ``Modify`` class. This class can take any modifier as an argument +* ``modify = Modify(modifier)`` creates an instance of the ``Modify`` class. This class can take any modifier as an argument * ``modified_dataset = modify(dataset)`` modifies the data in the dataset by performing the PII de-identification based upon the passed parameters. * ``write_to_disk(modified_dataset.df ....`` writes the de-identified documents to disk. diff --git a/examples/classifier_filtering.py b/examples/classifier_filtering.py index dbe24c45e..4e44be978 100644 --- a/examples/classifier_filtering.py +++ b/examples/classifier_filtering.py @@ -19,7 +19,7 @@ import nemo_curator as nc from nemo_curator.datasets import DocumentDataset -from nemo_curator.filters import BatchedFastTextQualityFilter +from nemo_curator.filters import FastTextQualityFilter from nemo_curator.modifiers import FastTextLabelModifier from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under @@ -85,7 +85,7 @@ def main(args): # Filter data target_dataset = load_dataset(low_quality_data_path) filter_pipeline = nc.ScoreFilter( - BatchedFastTextQualityFilter(model_path), + FastTextQualityFilter(model_path), score_field="quality_score", batched=True, score_type=float, diff --git a/examples/find_pii_and_deidentify.py b/examples/find_pii_and_deidentify.py index 9cb6d9bd5..e633dd3a9 100644 --- a/examples/find_pii_and_deidentify.py +++ b/examples/find_pii_and_deidentify.py @@ -18,7 +18,7 @@ import pandas as pd from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers.pii_modifier import PiiModifierBatched +from nemo_curator.modifiers.pii_modifier import PiiModifier from nemo_curator.modules.modify import Modify from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.script_utils import add_distributed_args @@ -35,7 +35,7 @@ def console_script(): dd = dask.dataframe.from_pandas(dataframe, npartitions=1) dataset = DocumentDataset(dd) - modifier = PiiModifierBatched( + modifier = PiiModifier( log_dir="./logs", batch_size=2000, language="en", @@ -43,7 +43,7 @@ def console_script(): anonymize_action="replace", ) - modify = Modify(modifier, batched=True) + modify = Modify(modifier) modified_dataset = modify(dataset) modified_dataset.df.to_json("output_files/*.jsonl", lines=True, orient="records") diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py index 488df5ee9..4eb800992 100644 --- a/nemo_curator/filters/__init__.py +++ b/nemo_curator/filters/__init__.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .classifier_filter import ( - BatchedFastTextQualityFilter, - FastTextLangId, - FastTextQualityFilter, -) +from .classifier_filter import FastTextLangId, FastTextQualityFilter from .code import ( AlphaFilter, GeneralCommentToCodeFilter, @@ -54,7 +50,6 @@ ) __all__ = [ - "BatchedFastTextQualityFilter", "DocumentFilter", "import_filter", "FastTextLangId", diff --git a/nemo_curator/filters/classifier_filter.py b/nemo_curator/filters/classifier_filter.py index 43c5bf9ef..f32e2ff57 100644 --- a/nemo_curator/filters/classifier_filter.py +++ b/nemo_curator/filters/classifier_filter.py @@ -17,6 +17,7 @@ import pandas as pd from nemo_curator.filters.doc_filter import DocumentFilter +from nemo_curator.utils.decorators import batched from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker @@ -34,42 +35,7 @@ def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42): self._seed = np.random.seed(seed) self._name = "fasttext_quality_filter" - def score_document(self, text): - text = text.replace("\n", " ").replace("__label__", " ") - model_attr = f"{self._name}_{self._model_path}" - # Workers don't exist during type inference - try: - model = load_object_on_worker(model_attr, self._load_model, {}) - except NoWorkerError: - return 1.0 - pred = model.predict(text) - document_score = pred[1][0] - if pred[0][0] != self._label: - document_score = 1 - document_score - - return document_score - - def keep_document(self, score): - return np.random.pareto(self._alpha) > 1 - score - - def _load_model(self): - return fasttext.load_model(self._model_path) - - -class BatchedFastTextQualityFilter(DocumentFilter): - - def __init__(self, model_path=None, label="__label__hq", alpha=3, seed=42): - if model_path is None: - raise ValueError( - "Must provide a valid path to a FastText model " - "to compute document scores with this filter" - ) - self._model_path = model_path - self._label = label - self._alpha = alpha - self._seed = np.random.seed(seed) - self._name = "fasttext_quality_filter" - + @batched def score_document(self, df): model_attr = f"{self._name}_{self._model_path}" try: @@ -88,6 +54,7 @@ def _score_document(text): return df.apply(_score_document) + @batched def keep_document(self, df): return np.random.pareto(self._alpha, size=len(df)) > 1 - df @@ -108,19 +75,23 @@ def __init__(self, model_path=None, min_langid_score=0.3): self._cutoff = min_langid_score self._name = "lang_id" - def score_document(self, text): - pp = text.strip().replace("\n", " ") - + @batched + def score_document(self, df): model_attr = f"{self._name}_{self._model_path}" try: model = load_object_on_worker(model_attr, self._load_model, {}) except NoWorkerError: - return [1.0, "N/A"] - label, score = model.predict(pp, k=1) - score = score[0] - lang_code = label[0][-2:].upper() + return pd.Series([[1.0, "N/A"] for _ in range(len(df))]) - return [score, lang_code] + def _score_document(text): + pp = text.strip().replace("\n", " ") + label, score = model.predict(pp, k=1) + score = score[0] + lang_code = label[0][-2:].upper() + + return [score, lang_code] + + return df.apply(_score_document) def keep_document(self, score): return score[0] >= self._cutoff diff --git a/nemo_curator/modifiers/pii_modifier.py b/nemo_curator/modifiers/pii_modifier.py index 4a6ef37cf..23c713fbf 100644 --- a/nemo_curator/modifiers/pii_modifier.py +++ b/nemo_curator/modifiers/pii_modifier.py @@ -18,14 +18,15 @@ from nemo_curator.modifiers import DocumentModifier from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE +from nemo_curator.utils.decorators import batched from nemo_curator.utils.distributed_utils import load_object_on_worker -__all__ = ["PiiModifierBatched"] +__all__ = ["PiiModifier"] DEFAULT_BATCH_SIZE = 2000 -class PiiModifierBatched(DocumentModifier): +class PiiModifier(DocumentModifier): """ This class is the entry point to using the PII de-identification module on documents stored as CSV, JSONL or other formats. It works with the `Modify` functionality as shown below: @@ -34,13 +35,13 @@ class PiiModifierBatched(DocumentModifier): dd = dask.dataframe.from_pandas(dataframe, npartitions=1) dataset = DocumentDataset(dd) - modifier = PiiModifierBatched( + modifier = PiiModifier( batch_size=2000, language='en', supported_entities=['PERSON', "EMAIL_ADDRESS"], anonymize_action='replace') - modify = Modify(modifier, batched=True) + modify = Modify(modifier) modified_dataset = modify(dataset) modified_dataset.df.to_json('output_files/*.jsonl', lines=True, orient='records') @@ -65,6 +66,7 @@ def __init__( self.batch_size = batch_size self.device = device + @batched def modify_document(self, text: pd.Series, partition_info: Dict = None): import logging diff --git a/nemo_curator/modules/filter.py b/nemo_curator/modules/filter.py index 4c30de0d9..07f8cb634 100644 --- a/nemo_curator/modules/filter.py +++ b/nemo_curator/modules/filter.py @@ -15,12 +15,11 @@ from dask.typing import no_default from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.module_utils import is_batched class Score: - def __init__( - self, score_fn, score_field, text_field="text", batched=False, score_type=None - ): + def __init__(self, score_fn, score_field, text_field="text", score_type=None): """ Args: score_fn: The score function that takes in a document string and outputs a score for the document @@ -30,7 +29,6 @@ def __init__( self.score_fn = score_fn self.score_field = score_field self.text_field = text_field - self.batched = batched self.score_type = score_type def __call__(self, dataset): @@ -40,7 +38,7 @@ def __call__(self, dataset): else: meta = no_default - if self.batched: + if is_batched(self.score_fn): dataset.df[self.score_field] = dataset.df[self.text_field].map_partitions( self.score_fn, meta=meta ) @@ -53,7 +51,7 @@ def __call__(self, dataset): class Filter: - def __init__(self, filter_fn, filter_field, invert=False, batched=False): + def __init__(self, filter_fn, filter_field, invert=False): """ Args: filter_fn: A function that returns True if the document is to be kept @@ -63,10 +61,9 @@ def __init__(self, filter_fn, filter_field, invert=False, batched=False): self.filter_fn = filter_fn self.filter_field = filter_field self.invert = invert - self.batched = batched def __call__(self, dataset): - if self.batched: + if is_batched(self.filter_fn): bool_mask = dataset.df[self.filter_field].map_partitions( self.filter_fn, meta=(None, bool) ) @@ -89,7 +86,6 @@ def __init__( score_field=None, score_type=None, invert=False, - batched=False, ): """ Args: @@ -100,7 +96,6 @@ def __init__( self.score_field = score_field self.score_type = score_type self.invert = invert - self.batched = batched def __call__(self, dataset): # Set the metadata for the function calls if provided @@ -109,7 +104,7 @@ def __call__(self, dataset): else: meta = no_default - if self.batched: + if is_batched(self.filter_obj.score_document): scores = dataset.df[self.text_field].map_partitions( self.filter_obj.score_document, meta=meta ) @@ -121,7 +116,7 @@ def __call__(self, dataset): if self.score_field is not None: dataset.df[self.score_field] = scores - if self.batched: + if is_batched(self.filter_obj.keep_document): bool_mask = scores.map_partitions( self.filter_obj.keep_document, meta=(None, bool) ) diff --git a/nemo_curator/modules/modify.py b/nemo_curator/modules/modify.py index 24dd47833..1307ab177 100644 --- a/nemo_curator/modules/modify.py +++ b/nemo_curator/modules/modify.py @@ -14,16 +14,16 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.modifiers import DocumentModifier +from nemo_curator.utils.module_utils import is_batched class Modify: - def __init__(self, modifier: DocumentModifier, text_field="text", batched=False): + def __init__(self, modifier: DocumentModifier, text_field="text"): self.modifier = modifier self.text_field = text_field - self.batched = batched def __call__(self, dataset: DocumentDataset) -> DocumentDataset: - if self.batched: + if is_batched(self.modifier.modify_document): dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions( self.modifier.modify_document, meta=(None, str) ) diff --git a/nemo_curator/modules/task.py b/nemo_curator/modules/task.py index 443679c2d..a7d9ae722 100644 --- a/nemo_curator/modules/task.py +++ b/nemo_curator/modules/task.py @@ -51,7 +51,7 @@ def __init__( tasks = [tasks] self.tasks = tasks self.text_field = text_field - self.max_ngram_size = 13 + self.max_ngram_size = max_ngram_size self.max_matches = max_matches self.min_document_length = min_document_length self.remove_char_each_side = remove_char_each_side diff --git a/nemo_curator/pii/recognizers/__init__.py b/nemo_curator/pii/recognizers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nemo_curator/scripts/find_pii_and_deidentify.py b/nemo_curator/scripts/find_pii_and_deidentify.py index d55a1117b..fedc59e52 100644 --- a/nemo_curator/scripts/find_pii_and_deidentify.py +++ b/nemo_curator/scripts/find_pii_and_deidentify.py @@ -18,7 +18,7 @@ from pathlib import Path from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers.pii_modifier import PiiModifierBatched +from nemo_curator.modifiers.pii_modifier import PiiModifier from nemo_curator.modules.modify import Modify # from nemo_curator.pii.algorithm import DEFAULT_LANGUAGE @@ -43,7 +43,7 @@ def main(args): args.supported_entities.split(",") if args.supported_entities else None ) - modifier = PiiModifierBatched( + modifier = PiiModifier( language=args.language, supported_entities=supported_entities, anonymize_action=args.anonymize_action, @@ -68,7 +68,7 @@ def main(args): dataset = DocumentDataset(source_data) logging.debug(f"Dataset has {source_data.npartitions} partitions") - modify = Modify(modifier, batched=True) + modify = Modify(modifier) modified_dataset = modify(dataset) write_to_disk( modified_dataset.df, diff --git a/nemo_curator/utils/decorators.py b/nemo_curator/utils/decorators.py new file mode 100644 index 000000000..5592c7ebe --- /dev/null +++ b/nemo_curator/utils/decorators.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def batched(function): + """ + Marks a function as accepting a pandas series of elements instead of a single element + + Args: + function: The function that accepts a batch of elements + """ + function.batched = True + return function diff --git a/nemo_curator/utils/module_utils.py b/nemo_curator/utils/module_utils.py new file mode 100644 index 000000000..dc4a693d2 --- /dev/null +++ b/nemo_curator/utils/module_utils.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def is_batched(function): + return hasattr(function, "batched") and function.batched diff --git a/tests/test_filters.py b/tests/test_filters.py index bd6f0e638..11bf57388 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -53,6 +53,7 @@ XMLHeaderFilter, ) from nemo_curator.modules import Filter, Score, ScoreFilter, Sequential +from nemo_curator.utils.decorators import batched class LetterCountFilter(DocumentFilter): @@ -82,9 +83,11 @@ def __init__(self, min_length=5, max_length=10): self.min_length = min_length self.max_length = max_length + @batched def score_document(self, df): return df.str.len() + @batched def keep_document(self, scores): min_threshold = self.min_length <= scores max_threshold = scores <= self.max_length @@ -200,7 +203,7 @@ def test_sequential_filter(self, letter_count_data): def test_batch_score_filter(self, letter_count_data): length_filter = BatchedLengthFilter(min_length=8, max_length=11) - filter_step = ScoreFilter(length_filter, text_field="documents", batched=True) + filter_step = ScoreFilter(length_filter, text_field="documents") filtered_data = filter_step(letter_count_data) expected_indices = [1, 2] @@ -216,7 +219,6 @@ def test_batch_score(self, letter_count_data): length_filter.score_document, text_field="documents", score_field=score_field, - batched=True, ) scored_data = score_step(letter_count_data) @@ -233,10 +235,9 @@ def test_batch_filter(self, letter_count_data): length_filter.score_document, text_field="documents", score_field=score_field, - batched=True, ) scored_data = score_step(letter_count_data) - filter_step = Filter(length_filter.keep_document, score_field, batched=True) + filter_step = Filter(length_filter.keep_document, score_field) filtered_data = filter_step(scored_data) expected_indices = [1, 2] diff --git a/tutorials/tinystories/main.py b/tutorials/tinystories/main.py index 13df07120..fa4470c35 100644 --- a/tutorials/tinystories/main.py +++ b/tutorials/tinystories/main.py @@ -25,7 +25,7 @@ from nemo_curator import ScoreFilter, Sequential from nemo_curator.datasets import DocumentDataset from nemo_curator.filters import RepeatingTopNGramsFilter, WordCountFilter -from nemo_curator.modifiers.pii_modifier import PiiModifierBatched +from nemo_curator.modifiers.pii_modifier import PiiModifier from nemo_curator.modifiers.unicode_reformatter import UnicodeReformatter from nemo_curator.modules import ExactDuplicates from nemo_curator.modules.modify import Modify @@ -128,12 +128,11 @@ def redact_pii(dataset: DocumentDataset) -> DocumentDataset: DocumentDataset: The redacted dataset with PII replaced by a generic value. """ redactor = Modify( - PiiModifierBatched( + PiiModifier( supported_entities=["PERSON"], anonymize_action="replace", device="cpu", ), - batched=True, ) return redactor(dataset) From 0df26b2cb82afd18b49c74e1152869bfde45c304 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 2 Apr 2024 11:31:25 -0700 Subject: [PATCH 2/2] Fix stray legacy batched statements Signed-off-by: Ryan Wolf --- .../PersonalIdentifiableInformationIdentificationAndRemoval.rst | 2 +- examples/classifier_filtering.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst index 6c19a2ea2..81f17d0e7 100644 --- a/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst +++ b/docs/user-guide/PersonalIdentifiableInformationIdentificationAndRemoval.rst @@ -80,7 +80,7 @@ You could read, de-identify the dataset, and write it to an output directory usi Let's walk through this code line by line. -* ``modifier = PiiModifierBatched`` creates an instance of ``PiiModifierBatched`` class that is responsible for PII de-identification +* ``modifier = PiiModifier`` creates an instance of ``PiiModifier`` class that is responsible for PII de-identification * ``for file_names in get_batched_files`` retrieves a batch of 32 documents from the `book_dataset` * ``source_data = read_data(file_names, file_type="jsonl", backend='pandas', add_filename=True)`` reads the data from all the files using Dask using Pandas as the backend. The ``add_filename`` argument ensures that the output files have the same filename as the input files. * ``dataset = DocumentDataset(source_data)`` creates an instance of ``DocumentDataset`` using the batch files. ``DocumentDataset`` is the standard format for text datasets in NeMo Curator. diff --git a/examples/classifier_filtering.py b/examples/classifier_filtering.py index 4e44be978..df1f4197c 100644 --- a/examples/classifier_filtering.py +++ b/examples/classifier_filtering.py @@ -87,7 +87,6 @@ def main(args): filter_pipeline = nc.ScoreFilter( FastTextQualityFilter(model_path), score_field="quality_score", - batched=True, score_type=float, ) filtered_dataset = filter_pipeline(target_dataset)