Skip to content

Commit

Permalink
Improve speed of AddId module (#36)
Browse files Browse the repository at this point in the history
* Add fast id method

Signed-off-by: Ryan Wolf <[email protected]>

* Add type conversion

Signed-off-by: Ryan Wolf <[email protected]>

* Fix off by one errors in tests

Signed-off-by: Ryan Wolf <[email protected]>

---------

Signed-off-by: Ryan Wolf <[email protected]>
  • Loading branch information
ryantwolf authored Apr 23, 2024
1 parent 0bbc77e commit 9864988
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 12 deletions.
45 changes: 41 additions & 4 deletions nemo_curator/modules/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import dask.dataframe as dd
import numpy as np
from dask import delayed

from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.module_utils import count_digits


class AddId:
def __init__(self, id_field, id_prefix="doc_id", start_index=0) -> None:
def __init__(
self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None
) -> None:
self.id_field = id_field
self.id_prefix = id_prefix
self.start_index = start_index

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
if self.start_index is None:
return self._add_id_fast(dataset)
else:
return self._add_id_ordered(dataset)

def _add_id_fast(self, dataset: DocumentDataset) -> DocumentDataset:
meta = dataset.df.dtypes.to_dict()
meta[self.id_field] = "string"

partition_zero_padding = count_digits(dataset.df.npartitions)
id_df = dataset.df.map_partitions(
self._add_id_fast_partition,
partition_zero_padding,
meta=meta,
)

return DocumentDataset(id_df)

def _add_id_fast_partition(self, partition, global_padding, partition_info=None):
local_padding = count_digits(len(partition))
global_id = partition_info["number"]

id_column = [
f"{self.id_prefix}-{local_id:0{local_padding}d}{global_id:0{global_padding}d}"
for local_id in range(len(partition))
]
partition[self.id_field] = id_column

return partition

def _add_id_ordered(self, dataset: DocumentDataset) -> DocumentDataset:
original_meta = dataset.df.dtypes.to_dict()
original_meta[self.id_field] = "object"
original_meta[self.id_field] = "string"
delayed_dataset = dataset.df.to_delayed()

parition_lengths = [0]
Expand All @@ -38,7 +74,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
delayed_id_dataset = []
for i, partition in enumerate(delayed_dataset):
delayed_id_dataset.append(
delayed(self._add_id_to_partition)(partition, lower_id_bounds[i])
delayed(self._add_id_ordered_partition)(partition, lower_id_bounds[i])
)

id_dataset = DocumentDataset(
Expand All @@ -47,11 +83,12 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:

return id_dataset

def _add_id_to_partition(self, partition, partition_start_id):
def _add_id_ordered_partition(self, partition, partition_start_id):
id_column = [
f"{self.id_prefix}-{int(i + self.start_index):010d}"
for i in range(partition_start_id, len(partition) + partition_start_id)
]
partition[self.id_field] = id_column
partition[self.id_field] = partition[self.id_field].astype("string")

return partition
6 changes: 4 additions & 2 deletions nemo_curator/scripts/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def attach_args(
parser.add_argument(
"--starting-index",
type=int,
default=0,
help="Starting index from which to start indexing the documents",
default=None,
help="If supplied, determines the starting index from which to start "
"indexing the documents. By default, it is unspecified, and uses an id"
" scheme that is fast to calculate and is not guaranteed to be ordered.",
)
parser.add_argument(
"--output-data-dir",
Expand Down
5 changes: 5 additions & 0 deletions nemo_curator/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math


def is_batched(function):
return hasattr(function, "batched") and function.batched


def count_digits(num):
return math.floor(math.log10(num)) + 1
50 changes: 44 additions & 6 deletions tests/test_add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pandas as pd
import pytest

import nemo_curator
import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset


Expand All @@ -41,10 +41,10 @@ def two_partition_dataset():
)


class TestPrepareTaskData:
class TestAddId:
def test_basic_id(self, single_partition_dataset):
id_field = "id"
add_id = nemo_curator.AddId(id_field)
add_id = nc.AddId(id_field, start_index=0)
id_dataset = add_id(single_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -63,7 +63,7 @@ def test_basic_id(self, single_partition_dataset):

def test_two_partitions(self, two_partition_dataset):
id_field = "id"
add_id = nemo_curator.AddId(id_field)
add_id = nc.AddId(id_field, start_index=0)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -83,7 +83,7 @@ def test_two_partitions(self, two_partition_dataset):
def test_id_prefix(self, two_partition_dataset):
id_field = "id"
id_prefix = "my_id"
add_id = nemo_curator.AddId(id_field, id_prefix=id_prefix)
add_id = nc.AddId(id_field, id_prefix=id_prefix, start_index=0)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -103,7 +103,7 @@ def test_id_prefix(self, two_partition_dataset):
def test_start_index(self, two_partition_dataset):
id_field = "id"
start_index = 13
add_id = nemo_curator.AddId(id_field, start_index=start_index)
add_id = nc.AddId(id_field, start_index=start_index)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -119,3 +119,41 @@ def test_start_index(self, two_partition_dataset):
assert all(
expected_ids == actual_ids
), f"Expected: {expected_ids}, got: {actual_ids}"

def test_fast_id_single_partition(self, single_partition_dataset):
id_field = "id"
add_id = nc.AddId(id_field)
id_dataset = add_id(single_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
[
"doc_id-00",
"doc_id-10",
"doc_id-20",
"doc_id-30",
"doc_id-40",
]
)

assert all(
expected_ids == actual_ids
), f"Expected: {expected_ids}, got: {actual_ids}"

def test_fast_id_two_partitions(self, two_partition_dataset):
id_field = "id"
add_id = nc.AddId(id_field)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
[
"doc_id-00",
"doc_id-10",
"doc_id-20",
"doc_id-01",
"doc_id-11",
]
)

assert all(
expected_ids == actual_ids
), f"Expected: {expected_ids}, got: {actual_ids}"

0 comments on commit 9864988

Please sign in to comment.