diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst
index ec4d61d26..50275af0a 100644
--- a/docs/source/torchdata.datapipes.iter.rst
+++ b/docs/source/torchdata.datapipes.iter.rst
@@ -167,6 +167,7 @@ These DataPipes apply a given function to each element in the DataPipe.
     BatchMapper
     FlatMapper
     Mapper
+    ThreadPoolMapper

 Other DataPipes
 -------------------------
diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index 9ee73ba44..d0df5caf4 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -12,6 +12,7 @@ import warnings

 from collections import defaultdict
+from functools import partial
 from typing import Dict

 import expecttest
@@ -20,6 +21,7 @@ import torchdata

 from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
+from torch.testing._internal.common_utils import suppress_warnings
 from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
 from torchdata.datapipes.iter import (
@@ -80,12 +82,12 @@ def _convert_to_tensor(data):


 async def _async_mul_ten(x):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
     return x * 10


 async def _async_x_mul_y(x, y):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
     return x * y

@@ -289,7 +291,7 @@ def odd_even_bug(i: int) -> int:
         self.assertEqual(len(source_dp), len(result_dp))

     def test_prefetcher_iterdatapipe(self) -> None:
-        source_dp = IterableWrapper(range(50000))
+        source_dp = IterableWrapper(range(5000))
         prefetched_dp = source_dp.prefetch(10)
         # check if early termination resets child thread properly
         for _, _ in zip(range(100), prefetched_dp):
@@ -1619,6 +1621,273 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_
             self.assertEqual(v1, exp)
             self.assertEqual(v2, exp)

+    def test_threadpool_map(self):
+        target_length = 30
+        input_dp = IterableWrapper(range(target_length))
+        input_dp_parallel = IterableWrapper(range(target_length))
+
+        def fn(item, dtype=torch.float, *, sum=False):
+            data = torch.tensor(item, dtype=dtype)
+            return data if not sum else data.sum()
+
+        # Functional Test: apply to each element correctly
+        map_dp = input_dp.threadpool_map(fn)
+        self.assertEqual(target_length, len(map_dp))
+        for x, y in zip(map_dp, range(target_length)):
+            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
+
+        # Functional Test: works with partial function
+        map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True))
+        for x, y in zip(map_dp, range(target_length)):
+            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
+
+        # __len__ Test: inherits length from source DataPipe
+        self.assertEqual(target_length, len(map_dp))
+
+        input_dp_nl = IDP_NoLen(range(target_length))
+        map_dp_nl = input_dp_nl.threadpool_map(lambda x: x)
+        for x, y in zip(map_dp_nl, range(target_length)):
+            self.assertEqual(x, torch.tensor(y, dtype=torch.float))
+
+        # __len__ Test: inherits length from source DataPipe - raises error when invalid
+        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
+            len(map_dp_nl)
+
+        # Test: two independent ThreadPoolExecutors running at the same time
+        map_dp_parallel = input_dp_parallel.threadpool_map(fn)
+        for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
+            self.assertEqual(x, torch.tensor(z, dtype=torch.float))
+            self.assertEqual(y, torch.tensor(z, dtype=torch.float))
+
+        # Reset Test: DataPipe resets properly
+        n_elements_before_reset = 5
+        res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
+        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
+        self.assertEqual(list(range(target_length)), res_after_reset)
+
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_threadpool_map_tuple_list_with_col_iterdatapipe(self):
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        def fn_n1_def(d0, d1=1):
+            return d0 + d1
+
+        def fn_n1_kwargs(d0, d1, **kwargs):
+            return d0 + d1
+
+        def fn_n1_pos(d0, d1, *args):
+            return d0 + d1
+
+        def fn_n1_sep_pos(d0, *args, d1):
+            return d0 + d1
+
+        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
+            return d0 + d1
+
+        p_fn_n1 = partial(fn_n1, d1=1)
+        p_fn_cmplx = partial(fn_cmplx, d2=2)
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
+            for constr in (list, tuple):
+                datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
+                if ref_fn is None:
+                    with self.assertRaises(error):
+                        res_dp = datapipe.threadpool_map(fn, input_col, output_col)
+                        list(res_dp)
+                else:
+                    res_dp = datapipe.threadpool_map(fn, input_col, output_col)
+                    ref_dp = datapipe.map(ref_fn)
+                    if constr is list:
+                        ref_dp = ref_dp.map(list)
+                    self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
+                    # Reset
+                    self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")
+
+        _helper(lambda data: data, fn_n1_def, 0, 1)
+        _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
+        _helper(lambda data: data, p_fn_n1, 0, 1)
+        _helper(lambda data: data, p_fn_cmplx, 0, 1)
+        _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
+        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
+        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
+        # The index of input column is out of range
+        _helper(None, fn_1n, 3, error=IndexError)
+        # Unmatched input columns with fn arguments
+        _helper(None, fn_n1, 1, error=ValueError)
+        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
+        _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
+        _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
+        _helper(None, fn_cmplx, 0, 1, ValueError)
+        _helper(None, fn_n1_pos, 1, error=ValueError)
+        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
+        _helper(None, p_fn_n1, [0, 1], error=ValueError)
+        _helper(None, fn_1n, [1, 2], error=ValueError)
+        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
+        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
+        # Fn has keyword-only arguments
+        _helper(None, fn_n1_kwargs, 1, error=ValueError)
+        _helper(None, fn_cmplx, [0, 1], 2, ValueError)
+
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
+        _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
+
+        # output_col can only be specified when input_col is not None
+        _helper(None, fn_n1, None, 1, error=ValueError)
+        # output_col can only be single-element list or tuple
+        _helper(None, fn_n1, None, [0, 1], error=ValueError)
+        # Single-element list as output_col
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
+        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
+        # The index of output column is out of range
+        _helper(None, fn_1n, 1, 3, error=IndexError)
+        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
+        _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
+
+        # Appending the output at the end
+        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
+        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
+        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
+        _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
+
+        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
+        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
+        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
+
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_threadpool_map_dict_with_col_iterdatapipe(self):
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        def fn_n1_def(d0, d1=1):
+            return d0 + d1
+
+        p_fn_n1 = partial(fn_n1, d1=1)
+
+        def fn_n1_pos(d0, d1, *args):
+            return d0 + d1
+
+        def fn_n1_kwargs(d0, d1, **kwargs):
+            return d0 + d1
+
+        def fn_kwonly(*, d0, d1):
+            return d0 + d1
+
+        def fn_has_nondefault_kwonly(d0, *, d1):
+            return d0 + d1
+
+        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
+            return d0 + d1
+
+        p_fn_cmplx = partial(fn_cmplx, d2=2)
+
+        # Prevent modification in-place to support resetting
+        def _dict_update(data, newdata, remove_idx=None):
+            _data = dict(data)
+            _data.update(newdata)
+            if remove_idx:
+                for idx in remove_idx:
+                    del _data[idx]
+            return _data
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
+            datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
+            if ref_fn is None:
+                with self.assertRaises(error):
+                    res_dp = datapipe.threadpool_map(fn, input_col, output_col)
+                    list(res_dp)
+            else:
+                res_dp = datapipe.threadpool_map(fn, input_col, output_col)
+                ref_dp = datapipe.map(ref_fn)
+                self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
+                # Reset
+                self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")
+
+        _helper(lambda data: data, fn_n1_def, "x", "y")
+        _helper(lambda data: data, p_fn_n1, "x", "y")
+        _helper(lambda data: data, p_fn_cmplx, "x", "y")
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")
+
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")
+
+        _helper(None, fn_n1_pos, "x", error=ValueError)
+        _helper(None, fn_n1_kwargs, "x", error=ValueError)
+        # non-default kw-only args
+        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
+        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
+        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
+        _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
+        # The key of input column is not in dict
+        _helper(None, fn_1n, "a", error=KeyError)
+        # Unmatched input columns with fn arguments
+        _helper(None, fn_n1, "y", error=ValueError)
+        _helper(None, fn_1n, ["x", "y"], error=ValueError)
+        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
+        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
+        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
+        _helper(
+            lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
+            fn_nn,
+            ["z", "y"],
+        )
+
+        # output_col can only be specified when input_col is not None
+        _helper(None, fn_n1, None, "x", error=ValueError)
+        # output_col can only be single-element list or tuple
+        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
+        # Single-element list as output_col
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
+        _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
+        _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
+        _helper(
+            lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
+            fn_nn,
+            ["y", "z"],
+            "x",
+        )
+
+        # Adding new key to dict for the output
+        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
+        _helper(
+            lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
+            fn_nn,
+            ["y", "z"],
+            "a",
+        )
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_serialization.py b/test/test_serialization.py
index a837a3948..b891ed4f8 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -297,6 +297,7 @@ def test_serializable(self):
             (iterdp.TarArchiveLoader, None, (), {}),
             # TODO(594): Add serialization tests for optional DataPipe
             # (iterdp.TFRecordLoader, None, (), {}),
+            (iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}),
             (iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
             (iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}),
             (iterdp.XzFileLoader, None, (), {}),
@@ -366,6 +367,7 @@ def test_serializable_with_dill(self):
             (iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}),
             (iterdp.OnDiskCacheHolder, (lambda x: x,), {}),
             (iterdp.ParagraphAggregator, (lambda x: x,), {}),
+            (iterdp.ThreadPoolMapper, (lambda x: x,), {}),
         ]
         # Skipping value comparison for these DataPipes
         dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator}
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index f2d564055..22f038f45 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -72,6 +72,7 @@
     FlatMapperIterDataPipe as FlatMapper,
     FlattenIterDataPipe as Flattener,
     SliceIterDataPipe as Slicer,
+    ThreadPoolMapperIterDataPipe as ThreadPoolMapper,
 )
 from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
 from torchdata.datapipes.iter.util.cacheholder import (
@@ -213,6 +214,7 @@
     "StreamReader",
     "TFRecordLoader",
     "TarArchiveLoader",
+    "ThreadPoolMapper",
     "UnBatcher",
     "UnZipper",
     "WebDataset",
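Note: the ``input_col``/``output_col`` routing exercised by the new tests above follows the same
semantics as ``Mapper``. A minimal sketch of what the assertions expect, for orientation only
(not part of the patch; assumes torchdata with this change installed):

    from torchdata.datapipes.iter import IterableWrapper

    def add(a, b):
        return a + b

    rows = IterableWrapper([(0, 1, 2), (3, 4, 5)])

    # input_col=[0, 2] feeds columns 0 and 2 to `add`; with output_col=None,
    # the result replaces the left-most input column and the other input
    # column is removed.
    print(list(rows.threadpool_map(add, input_col=[0, 2])))
    # [(2, 1), (8, 4)]

    # output_col=-1 appends the result instead of replacing a column.
    print(list(rows.threadpool_map(add, input_col=[0, 2], output_col=-1)))
    # [(0, 1, 2, 2), (3, 4, 5, 8)]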
diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py
index 6cbc1266e..23c58a011 100644
--- a/torchdata/datapipes/iter/transform/callable.py
+++ b/torchdata/datapipes/iter/transform/callable.py
@@ -7,6 +7,8 @@
 import asyncio
 import inspect
 import warnings
+from collections import deque
+from concurrent import futures

 from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
@@ -532,15 +534,19 @@ class BatchAsyncMapperIterDataPipe(IterDataPipe):
         async_fn: The coroutine function to be applied to each batch of data
         batch_size: The size of batch to be aggregated from ``source_datapipe``
         input_col: Index or indices of data which ``fn`` is applied, such as:
+
             - ``None`` as default to apply ``fn`` to the data directly.
             - Integer(s) is used for list/tuple.
             - Key(s) is used for dict.
+
         output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
             only when ``input_col`` is not ``None``
+
             - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
               multiple indices, the left-most one is used, and other indices will be removed.
             - Integer is used for list/tuple. ``-1`` represents to append result at the end.
             - Key is used for dict. New key is acceptable.
+
         max_concurrency: Maximum concurrency to call async functions. (Default value: 32)

     Example:
@@ -585,3 +591,194 @@ def __new__(
         dp = _BatchAsyncMapperIterDataPipe(dp, async_fn, input_col, output_col, max_concurrency)
         dp = dp.flatmap()
         return dp
+
+
+@functional_datapipe("threadpool_map")
+class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]):
+    r"""
+    Applies a function over each item from the source DataPipe concurrently
+    using ``ThreadPoolExecutor`` (functional name: ``threadpool_map``).
+    The function can be any regular Python function or partial object. Lambda
+    functions are not recommended as they are not supported by pickle.
+
+    Args:
+        source_datapipe: Source IterDataPipe
+        fn: Function being applied over each item
+        input_col: Index or indices of data which ``fn`` is applied, such as:
+
+            - ``None`` as default to apply ``fn`` to the data directly.
+            - Integer(s) is used for list/tuple.
+            - Key(s) is used for dict.
+
+        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
+            only when ``input_col`` is not ``None``
+
+            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
+              multiple indices, the left-most one is used, and other indices will be removed.
+            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
+            - Key is used for dict. New key is acceptable.
+
+        scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 128)
+        max_workers: Maximum number of threads to execute function calls
+        **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor``
+
+    Note:
+        For more information about ``max_workers`` and additional arguments for the ``ThreadPoolExecutor``,
+        please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor
+
+    Note:
+        For optimal use of all threads, ``scheduled_tasks`` > ``max_workers`` is strongly recommended. The higher
+        the variance of the execution time of ``fn``, the higher ``scheduled_tasks`` needs to be to avoid
+        threads sitting idle while waiting for the next result (results are returned in input order).
+
+        However, setting ``scheduled_tasks`` too high delays the first yielded element, since ``next`` is
+        called ``scheduled_tasks`` times on ``source_datapipe`` before the first result is returned.
+
+        We encourage you to try out different values of ``max_workers`` and ``scheduled_tasks``
+        in search of the optimal values for your use case.
+
+    Example:
+
+    .. testsetup::
+
+        from torchdata.datapipes.iter import IterableWrapper
+        import requests
+        import time
+        from unittest.mock import MagicMock
+
+        requests.get = MagicMock()
+        urls = []
+
+    .. testcode::
+
+        # fetching HTML from remote
+        def fetch_html(url: str, **kwargs):
+            r = requests.get(url, **kwargs)
+            r.raise_for_status()
+            return r.content
+
+        dp = IterableWrapper(urls)
+        dp = dp.threadpool_map(fetch_html, max_workers=16)
+
+    .. testcode::
+
+        def mul_ten(x):
+            time.sleep(0.1)
+            return x * 10
+
+        dp = IterableWrapper([(i, i) for i in range(50)])
+        dp = dp.threadpool_map(mul_ten, input_col=1)
+        print(list(dp))
+
+    .. testoutput::
+
+        [(0, 0), (1, 10), (2, 20), (3, 30), ...]
+
+    .. testcode::
+
+        dp = IterableWrapper([(i, i) for i in range(50)])
+        dp = dp.threadpool_map(mul_ten, input_col=1, output_col=-1)
+        print(list(dp))
+
+    .. testoutput::
+
+        [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
+
+    """
+
+    datapipe: IterDataPipe
+    fn: Callable
+
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe,
+        fn: Callable,
+        input_col=None,
+        output_col=None,
+        scheduled_tasks: int = 128,
+        max_workers: Optional[int] = None,
+        **threadpool_kwargs,
+    ) -> None:
+        super().__init__()
+        self.datapipe = source_datapipe
+
+        _check_unpickable_fn(fn)
+        self.fn = fn  # type: ignore[assignment]
+
+        if scheduled_tasks <= 0:
+            raise ValueError("'scheduled_tasks' is required to be a positive integer.")
+        self.scheduled_tasks = scheduled_tasks
+        if max_workers is not None and max_workers <= 0:
+            raise ValueError("'max_workers' is required to be a positive integer.")
+        self.max_workers = max_workers
+        self.threadpool_kwargs = threadpool_kwargs
+
+        self.input_col = input_col
+        if input_col is None and output_col is not None:
+            raise ValueError("`output_col` must be None when `input_col` is None.")
+        if isinstance(output_col, (list, tuple)):
+            if len(output_col) > 1:
+                raise ValueError("`output_col` must be a single-element list or tuple")
+            output_col = output_col[0]
+        self.output_col = output_col
+        validate_input_col(fn, input_col)
+
+    def _apply_fn(self, data):
+        if self.input_col is None and self.output_col is None:
+            return self.fn(data)
+
+        if self.input_col is None:
+            res = self.fn(data)
+        elif isinstance(self.input_col, (list, tuple)):
+            args = tuple(data[col] for col in self.input_col)
+            res = self.fn(*args)
+        else:
+            res = self.fn(data[self.input_col])
+
+        # Copy tuple to list and run in-place modification because tuple is immutable.
+        if isinstance(data, tuple):
+            t_flag = True
+            data = list(data)
+        else:
+            t_flag = False
+
+        if self.output_col is None:
+            if isinstance(self.input_col, (list, tuple)):
+                data[self.input_col[0]] = res
+                for idx in sorted(self.input_col[1:], reverse=True):
+                    del data[idx]
+            else:
+                data[self.input_col] = res
+        else:
+            if self.output_col == -1:
+                data.append(res)
+            else:
+                data[self.output_col] = res
+
+        # Convert list back to tuple
+        return tuple(data) if t_flag else data
+
+    def __iter__(self) -> Iterator[T_co]:
+        with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor:
+            futures_deque: deque = deque()
+            has_next = True
+            itr = iter(self.datapipe)
+            for _ in range(self.scheduled_tasks):
+                try:
+                    futures_deque.append(executor.submit(self._apply_fn, next(itr)))
+                except StopIteration:
+                    has_next = False
+                    break
+
+            while len(futures_deque) > 0:
+                if has_next:
+                    try:
+                        futures_deque.append(executor.submit(self._apply_fn, next(itr)))
+                    except StopIteration:
+                        has_next = False
+                yield futures_deque.popleft().result()
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            return len(self.datapipe)
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
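Note: ``__iter__`` above implements a bounded, order-preserving scheduling loop: it primes the executor
with up to ``scheduled_tasks`` submissions, then tops up one task per yielded result, popping futures in
FIFO order so outputs stay aligned with inputs. The same pattern, distilled out of the DataPipe plumbing
(a sketch, not part of the patch; ``ordered_threadpool_map`` is a hypothetical standalone helper):

    from collections import deque
    from concurrent.futures import ThreadPoolExecutor

    def ordered_threadpool_map(iterable, fn, scheduled_tasks=128, max_workers=None):
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            pending: deque = deque()
            itr = iter(iterable)
            has_next = True
            # Phase 1: prime the pipeline with up to `scheduled_tasks` submissions.
            for _ in range(scheduled_tasks):
                try:
                    pending.append(executor.submit(fn, next(itr)))
                except StopIteration:
                    has_next = False
                    break
            # Phase 2: yield the oldest result, topping up one new task each time,
            # so at most `scheduled_tasks` items are ever in flight.
            while pending:
                if has_next:
                    try:
                        pending.append(executor.submit(fn, next(itr)))
                    except StopIteration:
                        has_next = False
                yield pending.popleft().result()

    # e.g. list(ordered_threadpool_map(range(8), lambda x: x * 10, scheduled_tasks=4))

Because ``popleft().result()`` blocks on the oldest future, a single slow item stalls the output stream
even while later items finish; this is why the docstring recommends ``scheduled_tasks`` comfortably
larger than ``max_workers`` when the latency of ``fn`` has high variance.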