Skip to content

Commit

Permalink
Set TabularTransform to process clean transform in parallel (#1648)
Browse files Browse the repository at this point in the history
- `Clean` transform took long time due to tabular dataset usually have large dataset
- To reduce the time for transform, I made `TabularTransform` to do this process in parallel
- When I set `batch_size` as 100 and `num_workers` as 2, then the total process time reduces in half (about 300s -> 166s)
  • Loading branch information
sooahleex authored Oct 21, 2024
1 parent 9478b9a commit 3b66351
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1603>)
- Optimize path assignment to handle point cloud in JSON without images
(<https://github.com/openvinotoolkit/datumaro/pull/1643>)
- Set TabularTransform to process clean transform in parallel
(<https://github.com/openvinotoolkit/datumaro/pull/1648>)

### Bug fixes
- Fix datumaro format to load visibility information from Points annotations
Expand Down
74 changes: 74 additions & 0 deletions src/datumaro/components/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,80 @@ def __iter__(self):
yield item


class TabularTransform(Transform):
"""A transformation class for processing dataset items in batches with optional parallelism.
This class takes a dataset extractor, batch size, and number of worker threads to process
dataset items. Depending on the number of workers specified, it can process items either
sequentially (single-process) or in parallel (multi-process), making it efficient for
batch transformations.
Parameters:
extractor: The dataset extractor to obtain items from.
batch_size: The batch size for processing items. Default is 1.
num_workers: The number of worker threads to use for parallel processing.
Set to 0 for single-process mode. Default is 0.
"""

def __init__(
self,
extractor: IDataset,
batch_size: int = 1,
num_workers: int = 0,
):
super().__init__(extractor)
self._batch_size = batch_size
if not (isinstance(num_workers, int) and num_workers >= 0):
raise ValueError(
f"num_workers should be a non negative integer, but it is {num_workers}"
)
self._num_workers = num_workers

def __iter__(self) -> Iterator[DatasetItem]:
if self._num_workers == 0:
return self._iter_single_proc()
return self._iter_multi_procs()

def _iter_multi_procs(self):
with ThreadPool(processes=self._num_workers) as pool:

def _producer_gen():
for batch in take_by(self._extractor, self._batch_size):
future = pool.apply_async(
func=self._process_batch,
args=(batch,),
)
yield future

with consumer_generator(producer_generator=_producer_gen()) as consumer_gen:
for future in consumer_gen:
for item in future.get():
yield item

def _iter_single_proc(self) -> Iterator[DatasetItem]:
for batch in take_by(self._extractor, self._batch_size):
for item in self._process_batch(batch=batch):
yield item

def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
"""
Returns a modified copy of the input item.
Avoid changing and returning the input item, because it can lead to
unexpected problems. Use wrap_item() or item.wrap() to simplify copying.
"""

raise NotImplementedError()

def _process_batch(
self,
batch: List[DatasetItem],
) -> List[DatasetItem]:
results = [self.transform_item(item) for item in batch]

return results


class ModelTransform(Transform):
"""A transformation class for applying a model's inference to dataset items.
Expand Down
8 changes: 5 additions & 3 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
UndefinedLabel,
)
from datumaro.components.media import Image, TableRow
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.components.transformer import ItemTransform, TabularTransform, Transform
from datumaro.util import NOTSET, filter_dict, parse_json_file, parse_str_enum_value, take_by
from datumaro.util.annotation_util import find_group_leader, find_instances
from datumaro.util.tabular_util import emoji_pattern
Expand Down Expand Up @@ -1864,7 +1864,7 @@ def transform_item(self, item: DatasetItem):
return self.wrap_item(item, annotations=annotations)


class Clean(ItemTransform):
class Clean(TabularTransform):
"""
A class used to refine the media items in a dataset.|n
|n
Expand All @@ -1883,8 +1883,10 @@ class Clean(ItemTransform):
def __init__(
self,
extractor: IDataset,
batch_size: int = 1,
num_workers: int = 0,
):
super().__init__(extractor)
super().__init__(extractor, batch_size, num_workers)

self._outlier_value = {}
self._missing_value = {}
Expand Down
31 changes: 29 additions & 2 deletions tests/unit/components/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Tuple
from typing import List, Optional, Tuple

import pytest

Expand All @@ -11,7 +11,7 @@
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.launcher import Launcher
from datumaro.components.transformer import ModelTransform
from datumaro.components.transformer import ModelTransform, TabularTransform


class MockLauncher(Launcher):
Expand Down Expand Up @@ -64,3 +64,30 @@ def test_model_transform(
assert item.annotations == [Annotation(id=0), Annotation(id=1)]
else:
assert item.annotations == [Annotation(id=1)]


class TabularTransformTest:
@pytest.fixture
def fxt_dataset(self):
return Dataset.from_iterable(
[DatasetItem(id=f"item_{i}", annotations=[Annotation(id=0)]) for i in range(10)]
)

@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_tabular_transform(self, fxt_dataset, batch_size, num_workers):
class MockTabularTransform(TabularTransform):
def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
# Mock transformation logic
item.annotations.append(Annotation(id=1))
return item

transform = MockTabularTransform(
extractor=fxt_dataset,
batch_size=batch_size,
num_workers=num_workers,
)

for idx, item in enumerate(transform):
assert item.id == f"item_{idx}"
assert item.annotations == [Annotation(id=0), Annotation(id=1)]

0 comments on commit 3b66351

Please sign in to comment.