From ae2355f73d33babbab4acc36cc29189025629842 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Mon, 27 Feb 2023 14:01:00 +0100 Subject: [PATCH 01/20] Add ThreadPoolMapper --- docs/source/torchdata.datapipes.iter.rst | 1 + test/test_iterdatapipe.py | 40 ++++ tools/gen_pyi.py | 1 + torchdata/datapipes/iter/__init__.py | 1 + .../datapipes/iter/transform/callable.py | 177 ++++++++++++++++++ 5 files changed, 220 insertions(+) diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst index ec4d61d26..706934886 100644 --- a/docs/source/torchdata.datapipes.iter.rst +++ b/docs/source/torchdata.datapipes.iter.rst @@ -165,6 +165,7 @@ These DataPipes apply the a given function to each element in the DataPipe. BatchAsyncMapper BatchMapper + BatchThreadPoolMapper FlatMapper Mapper diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 9ee73ba44..e8f624732 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -12,6 +12,7 @@ import warnings from collections import defaultdict +from functools import partial from typing import Dict import expecttest @@ -1619,6 +1620,45 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_ self.assertEqual(v1, exp) self.assertEqual(v2, exp) + def test_threadpool_map_batches(self): + batch_size = 16 + target_length = 30 + input_dp = IterableWrapper(range(target_length)) + + def fn(item, dtype=torch.float, *, sum=False): + data = torch.tensor(item, dtype=dtype) + return data if not sum else data.sum() + + # Functional Test: apply to each element correctly + map_dp = input_dp.thread_map_batches(fn, batch_size) + # self.assertEqual(target_length, len(map_dp)) + for x, y in zip(map_dp, range(target_length)): + self.assertEqual(x, torch.tensor(y, dtype=torch.float)) + + # Functional Test: works with partial function + map_dp = input_dp.thread_map_batches(partial(fn, dtype=torch.int, sum=True), batch_size) + for x, y in zip(map_dp, range(target_length)): + self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) + + # __len__ Test: inherits length from source DataPipe + # this doesn't work atm + # self.assertEqual(target_length, len(map_dp)) + + input_dp_nl = IDP_NoLen(range(target_length)) + map_dp_nl = input_dp_nl.thread_map_batches((lambda x: x), batch_size) + for x, y in zip(map_dp_nl, range(target_length)): + self.assertEqual(x, torch.tensor(y, dtype=torch.float)) + + # __len__ Test: inherits length from source DataPipe - raises error when invalid + # with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): + # len(map_dp_nl) + + # Reset Test: DataPipe resets properly + n_elements_before_reset = 5 + res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset) + self.assertEqual(list(range(n_elements_before_reset)), res_before_reset) + self.assertEqual(list(range(target_length)), res_after_reset) + if __name__ == "__main__": unittest.main() diff --git a/tools/gen_pyi.py b/tools/gen_pyi.py index 7e63ce126..29fc54ce9 100644 --- a/tools/gen_pyi.py +++ b/tools/gen_pyi.py @@ -77,6 +77,7 @@ def gen_pyi() -> None: "read_from_xz": "IterDataPipe", "read_from_zip": "IterDataPipe", "round_robin_demux": "List[IterDataPipe]", + "thread_map_batches": "IterDataPipe", "to_map_datapipe": "MapDataPipe", "unzip": "List[IterDataPipe]", } diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index f2d564055..3a6708c0a 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -68,6 +68,7 @@ from torchdata.datapipes.iter.transform.callable import ( BatchAsyncMapperIterDataPipe as BatchAsyncMapper, BatchMapperIterDataPipe as BatchMapper, + BatchThreadPoolMapperIterDataPipe as BatchThreadPoolMapper, DropperIterDataPipe as Dropper, FlatMapperIterDataPipe as FlatMapper, FlattenIterDataPipe as Flattener, diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 6cbc1266e..c48548146 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -7,6 +7,7 @@ import asyncio import inspect import warnings +from concurrent import futures from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union @@ -585,3 +586,179 @@ def __new__( dp = _BatchAsyncMapperIterDataPipe(dp, async_fn, input_col, output_col, max_concurrency) dp = dp.flatmap() return dp + + +class _BatchThreadPoolMapperIterDataPipe(IterDataPipe): + datapipe: IterDataPipe + fn: Callable + + def __init__( + self, + source_datapipe: IterDataPipe, + fn: Callable, + input_col=None, + output_col=None, + max_workers: Optional[int] = None, + **threadpool_kwargs, + ): + self.source_datapipe = source_datapipe + self.fn = fn # type: ignore[assignment] + if input_col is None and output_col is not None: + raise ValueError("`output_col` must be None when `input_col` is None.") + self.input_col = input_col + if isinstance(output_col, (list, tuple)): + if len(output_col) > 1: + raise ValueError("`output_col` must be a single-element list or tuple") + output_col = output_col[0] + self.output_col = output_col + self.max_workers = max_workers + self.threadpool_kwargs = threadpool_kwargs + + def __iter__(self): + with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: + for batch in self.source_datapipe: + prepared_batch = self.preparebatch(batch) + results = executor.map(self.fn, prepared_batch) + return_batch = self.merge_batch_with_result(batch, results) + yield return_batch + + def preparebatch(self, batch): + if self.input_col is None: + return batch + + prepared_batch = [] + for data in batch: + if isinstance(self.input_col, (list, tuple)): + args = tuple(batch[col] for col in self.input_col) + prepared_batch.append(args) + else: + prepared_batch.append(data[self.input_col]) + return prepared_batch + + def merge_batch_with_result(self, orig_batch, results): + if self.input_col is None: + return results + + new_batch = [] + for data, res in zip(orig_batch, results): + t_flag = isinstance(data, tuple) + if t_flag: + data = list(data) + + if self.output_col is None: + if isinstance(self.input_col, (list, tuple)): + data[self.input_col[0]] = res + for idx in sorted(self.input_col[1:], reverse=True): + del data[idx] + else: + data[self.input_col] = res + elif self.output_col == -1: + data.append(res) + else: + data[self.output_col] = res + + if t_flag: + data = tuple(data) + + new_batch.append(data) + return new_batch + + +@functional_datapipe("thread_map_batches") +class BatchThreadPoolMapperIterDataPipe(IterDataPipe): + r""" + Combines elements from the source DataPipe to batches and applies a function + over each element within the batch concurrently using ``ThreadPoolExecutor``, then flattens the output to a + single, unnested IterDataPipe (functional name: ``thread_map_batches``). + + Args: + source_datapipe: Source IterDataPipe + fn: The function to be applied to each element within batch of data + batch_size: The size of batch to be aggregated from ``source_datapipe`` + input_col: Index or indices of data which ``fn`` is applied, such as: + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified + only when ``input_col`` is not ``None`` + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with + multiple indices, the left-most one is used, and other indices will be removed. + - Integer is used for list/tuple. ``-1`` represents to append result at the end. + - Key is used for dict. New key is acceptable. + max_workers: Maximum num of threads to execute function calls. (Default value: None) + + Note: + For more information about ``max_workers`` and please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor + + Example: + + .. testsetup:: + + from torchdata.datapipes.iter import IterableWrapper + import requests + import time + from unittest.mock import MagicMock + + requests.get = MagicMock() + urls = [] + + .. testcode:: + + def mul_ten(x): + time.sleep(0.1) + return x * 10 + dp = IterableWrapper(range(50)) + dp = dp.thread_map_batches(mul_ten, 16) + print(list(dp)) + + .. testoutput:: + + [0, 10, 20, 30, ...] + + .. testcode:: + + dp = IterableWrapper([(i, i) for i in range(50)]) + dp = dp.thread_map_batches(mul_ten, 16, input_col=1) + print(list(dp)) + + .. testoutput:: + + [(0, 0), (1, 10), (2, 20), (3, 30), ...] + + .. testcode:: + + dp = IterableWrapper([(i, i) for i in range(50)]) + dp = dp.thread_map_batches(mul_ten, 16, input_col=1, output_col=-1) + print(list(dp)) + + .. testoutput:: + + [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...] + + .. testcode:: + + # fetching html from remote + def fetch_html(url: str, **kwargs): + r = requests.get(url, **kwargs) + r.raise_for_status() + return r.content + dp = IterableWrapper(urls) + dp = dp.thread_map_batches(fetch_html, 16) + + """ + + def __new__( + cls, + source_datapipe, + fn: Callable, + batch_size: int, + input_col=None, + output_col=None, + max_workers: Optional[int] = None, + **threadpool_kwargs, + ): + dp = source_datapipe.batch(batch_size) + dp = _BatchThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) + dp = dp.flatmap() + + return dp From 873796d7a85859443f678705378226bd8ea48656 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Tue, 28 Feb 2023 13:27:39 +0100 Subject: [PATCH 02/20] Add Tests and extract merge-with-result-function --- test/test_iterdatapipe.py | 244 +++++++++++++++++- .../datapipes/iter/transform/callable.py | 99 +++---- 2 files changed, 281 insertions(+), 62 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index e8f624732..ecbb510aa 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -21,6 +21,7 @@ import torchdata from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls +from torch.testing._internal.common_utils import suppress_warnings from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration from torchdata.datapipes.iter import ( @@ -81,12 +82,12 @@ def _convert_to_tensor(data): async def _async_mul_ten(x): - await asyncio.sleep(1) + await asyncio.sleep(0.1) return x * 10 async def _async_x_mul_y(x, y): - await asyncio.sleep(1) + await asyncio.sleep(0.1) return x * y @@ -290,7 +291,7 @@ def odd_even_bug(i: int) -> int: self.assertEqual(len(source_dp), len(result_dp)) def test_prefetcher_iterdatapipe(self) -> None: - source_dp = IterableWrapper(range(50000)) + source_dp = IterableWrapper(range(5000)) prefetched_dp = source_dp.prefetch(10) # check if early termination resets child thread properly for _, _ in zip(range(100), prefetched_dp): @@ -1624,6 +1625,7 @@ def test_threadpool_map_batches(self): batch_size = 16 target_length = 30 input_dp = IterableWrapper(range(target_length)) + input_dp_parallel = IterableWrapper(range(target_length)) def fn(item, dtype=torch.float, *, sum=False): data = torch.tensor(item, dtype=dtype) @@ -1640,8 +1642,8 @@ def fn(item, dtype=torch.float, *, sum=False): for x, y in zip(map_dp, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) - # __len__ Test: inherits length from source DataPipe # this doesn't work atm + # __len__ Test: inherits length from source DataPipe # self.assertEqual(target_length, len(map_dp)) input_dp_nl = IDP_NoLen(range(target_length)) @@ -1649,16 +1651,250 @@ def fn(item, dtype=torch.float, *, sum=False): for x, y in zip(map_dp_nl, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.float)) + # this doesn't work atm # __len__ Test: inherits length from source DataPipe - raises error when invalid # with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): # len(map_dp_nl) + # Test: two independent ThreadPoolExecutors running at the same time + map_dp_parallel = input_dp_parallel.thread_map_batches(fn, batch_size) + for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)): + self.assertEqual(x, torch.tensor(z, dtype=torch.float)) + self.assertEqual(y, torch.tensor(z, dtype=torch.float)) + # Reset Test: DataPipe resets properly n_elements_before_reset = 5 res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset) self.assertEqual(list(range(n_elements_before_reset)), res_before_reset) self.assertEqual(list(range(target_length)), res_after_reset) + @suppress_warnings # Suppress warning for lambda fn + def test_threadpool_map_batches_tuple_list_with_col_iterdatapipe(self): + batch_size = 3 + + def fn_11(d): + return -d + + def fn_1n(d): + return -d, d + + def fn_n1(d0, d1): + return d0 + d1 + + def fn_nn(d0, d1): + return -d0, -d1, d0 + d1 + + def fn_n1_def(d0, d1=1): + return d0 + d1 + + def fn_n1_kwargs(d0, d1, **kwargs): + return d0 + d1 + + def fn_n1_pos(d0, d1, *args): + return d0 + d1 + + def fn_n1_sep_pos(d0, *args, d1): + return d0 + d1 + + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): + return d0 + d1 + + p_fn_n1 = partial(fn_n1, d1=1) + p_fn_cmplx = partial(fn_cmplx, d2=2) + + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): + for constr in (list, tuple): + datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) + if ref_fn is None: + with self.assertRaises(error): + res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + list(res_dp) + else: + res_dp = datapipe.map(fn, input_col, output_col) + ref_dp = datapipe.map(ref_fn) + if constr is list: + ref_dp = ref_dp.map(list) + self.assertEqual(list(res_dp), list(ref_dp)) + # Reset + self.assertEqual(list(res_dp), list(ref_dp)) + + _helper(lambda data: data, fn_n1_def, 0, 1) + _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2) + _helper(lambda data: data, p_fn_n1, 0, 1) + _helper(lambda data: data, p_fn_cmplx, 0, 1) + _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2) + _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2]) + + # Replacing with one input column and default output column + _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) + _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1) + # The index of input column is out of range + _helper(None, fn_1n, 3, error=IndexError) + # Unmatched input columns with fn arguments + _helper(None, fn_n1, 1, error=ValueError) + _helper(None, fn_n1, [0, 1, 2], error=ValueError) + _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError) + _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError) + _helper(None, fn_cmplx, 0, 1, ValueError) + _helper(None, fn_n1_pos, 1, error=ValueError) + _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError) + _helper(None, p_fn_n1, [0, 1], error=ValueError) + _helper(None, fn_1n, [1, 2], error=ValueError) + # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError) + _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError) + # Fn has keyword-only arguments + _helper(None, fn_n1_kwargs, 1, error=ValueError) + _helper(None, fn_cmplx, [0, 1], 2, ValueError) + + # Replacing with multiple input columns and default output column (the left-most input column) + _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) + _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1]) + + # output_col can only be specified when input_col is not None + _helper(None, fn_n1, None, 1, error=ValueError) + # output_col can only be single-element list or tuple + _helper(None, fn_n1, None, [0, 1], error=ValueError) + # Single-element list as output_col + _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) + # Replacing with one input column and single specified output column + _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) + _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) + # The index of output column is out of range + _helper(None, fn_1n, 1, 3, error=IndexError) + _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) + _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0) + + # Appending the output at the end + _helper(lambda data: (*data, -data[1]), fn_11, 1, -1) + _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1) + _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1) + _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1) + + # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected + _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0) + _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) + + @suppress_warnings # Suppress warning for lambda fn + def test_threadpool_map_batches_dict_with_col_iterdatapipe(self): + batch_size = 3 + + def fn_11(d): + return -d + + def fn_1n(d): + return -d, d + + def fn_n1(d0, d1): + return d0 + d1 + + def fn_nn(d0, d1): + return -d0, -d1, d0 + d1 + + def fn_n1_def(d0, d1=1): + return d0 + d1 + + p_fn_n1 = partial(fn_n1, d1=1) + + def fn_n1_pos(d0, d1, *args): + return d0 + d1 + + def fn_n1_kwargs(d0, d1, **kwargs): + return d0 + d1 + + def fn_kwonly(*, d0, d1): + return d0 + d1 + + def fn_has_nondefault_kwonly(d0, *, d1): + return d0 + d1 + + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): + return d0 + d1 + + p_fn_cmplx = partial(fn_cmplx, d2=2) + + # Prevent modification in-place to support resetting + def _dict_update(data, newdata, remove_idx=None): + _data = dict(data) + _data.update(newdata) + if remove_idx: + for idx in remove_idx: + del _data[idx] + return _data + + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): + datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}]) + if ref_fn is None: + with self.assertRaises(error): + res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + list(res_dp) + else: + res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + ref_dp = datapipe.map(ref_fn) + self.assertEqual(list(res_dp), list(ref_dp)) + # Reset + self.assertEqual(list(res_dp), list(ref_dp)) + + _helper(lambda data: data, fn_n1_def, "x", "y") + _helper(lambda data: data, p_fn_n1, "x", "y") + _helper(lambda data: data, p_fn_cmplx, "x", "y") + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z") + + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z") + + _helper(None, fn_n1_pos, "x", error=ValueError) + _helper(None, fn_n1_kwargs, "x", error=ValueError) + # non-default kw-only args + _helper(None, fn_kwonly, ["x", "y"], error=ValueError) + _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError) + _helper(None, fn_cmplx, ["x", "y"], error=ValueError) + + # Replacing with one input column and default output column + _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") + _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") + # The key of input column is not in dict + _helper(None, fn_1n, "a", error=KeyError) + # Unmatched input columns with fn arguments + _helper(None, fn_n1, "y", error=ValueError) + _helper(None, fn_1n, ["x", "y"], error=ValueError) + _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError) + _helper(None, p_fn_n1, ["x", "y"], error=ValueError) + _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError) + # Replacing with multiple input columns and default output column (the left-most input column) + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) + _helper( + lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), + fn_nn, + ["z", "y"], + ) + + # output_col can only be specified when input_col is not None + _helper(None, fn_n1, None, "x", error=ValueError) + # output_col can only be single-element list or tuple + _helper(None, fn_n1, None, ["x", "y"], error=ValueError) + # Single-element list as output_col + _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) + # Replacing with one input column and single specified output column + _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x") + _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z") + _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y") + _helper( + lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), + fn_nn, + ["y", "z"], + "x", + ) + + # Adding new key to dict for the output + _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a") + _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a") + _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a") + _helper( + lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), + fn_nn, + ["y", "z"], + "a", + ) + if __name__ == "__main__": unittest.main() diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index c48548146..357604654 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -27,6 +27,35 @@ def _no_op_fn(*args): return args +def _merge_batch_with_result(orig_batch, results, input_col, output_col): + if input_col is None: + return results + + new_batch = [] + for data, res in zip(orig_batch, results): + t_flag = isinstance(data, tuple) + if t_flag: + data = list(data) + + if output_col is None: + if isinstance(input_col, (list, tuple)): + data[input_col[0]] = res + for idx in sorted(input_col[1:], reverse=True): + del data[idx] + else: + data[input_col] = res + elif output_col == -1: + data.append(res) + else: + data[output_col] = res + + if t_flag: + data = tuple(data) + + new_batch.append(data) + return new_batch + + @functional_datapipe("map_batches") class BatchMapperIterDataPipe(IterDataPipe[T_co]): r""" @@ -495,30 +524,7 @@ async def controlled_async_fn(async_fn, *data): else: coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col])) results = await asyncio.gather(*coroutines) - - new_batch = [] - for data, res in zip(batch, results): - t_flag = isinstance(data, tuple) - if t_flag: - data = list(data) - - if self.output_col is None: - if isinstance(self.input_col, (list, tuple)): - data[self.input_col[0]] = res - for idx in sorted(self.input_col[1:], reverse=True): - del data[idx] - else: - data[self.input_col] = res - elif self.output_col == -1: - data.append(res) - else: - data[self.output_col] = res - - if t_flag: - data = tuple(data) - - new_batch.append(data) - return new_batch + return _merge_batch_with_result(batch, results, self.input_col, self.output_col) @functional_datapipe("async_map_batches") @@ -602,14 +608,21 @@ def __init__( **threadpool_kwargs, ): self.source_datapipe = source_datapipe - self.fn = fn # type: ignore[assignment] + self.fn = fn + self.input_col = input_col + validate_input_col(fn, input_col) if input_col is None and output_col is not None: raise ValueError("`output_col` must be None when `input_col` is None.") - self.input_col = input_col if isinstance(output_col, (list, tuple)): if len(output_col) > 1: raise ValueError("`output_col` must be a single-element list or tuple") output_col = output_col[0] + if isinstance(self.input_col, (list, tuple)): + + def wrapper_fn(args): + return fn(*args) + + self.fn = wrapper_fn self.output_col = output_col self.max_workers = max_workers self.threadpool_kwargs = threadpool_kwargs @@ -619,8 +632,7 @@ def __iter__(self): for batch in self.source_datapipe: prepared_batch = self.preparebatch(batch) results = executor.map(self.fn, prepared_batch) - return_batch = self.merge_batch_with_result(batch, results) - yield return_batch + yield _merge_batch_with_result(batch, results, self.input_col, self.output_col) def preparebatch(self, batch): if self.input_col is None: @@ -629,40 +641,12 @@ def preparebatch(self, batch): prepared_batch = [] for data in batch: if isinstance(self.input_col, (list, tuple)): - args = tuple(batch[col] for col in self.input_col) + args = tuple(data[col] for col in self.input_col) prepared_batch.append(args) else: prepared_batch.append(data[self.input_col]) return prepared_batch - def merge_batch_with_result(self, orig_batch, results): - if self.input_col is None: - return results - - new_batch = [] - for data, res in zip(orig_batch, results): - t_flag = isinstance(data, tuple) - if t_flag: - data = list(data) - - if self.output_col is None: - if isinstance(self.input_col, (list, tuple)): - data[self.input_col[0]] = res - for idx in sorted(self.input_col[1:], reverse=True): - del data[idx] - else: - data[self.input_col] = res - elif self.output_col == -1: - data.append(res) - else: - data[self.output_col] = res - - if t_flag: - data = tuple(data) - - new_batch.append(data) - return new_batch - @functional_datapipe("thread_map_batches") class BatchThreadPoolMapperIterDataPipe(IterDataPipe): @@ -760,5 +744,4 @@ def __new__( dp = source_datapipe.batch(batch_size) dp = _BatchThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) dp = dp.flatmap() - return dp From 57d2f9e56998411df54813342b66ae641656f25d Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Tue, 28 Feb 2023 13:45:46 +0100 Subject: [PATCH 03/20] Fix mypy assign issue --- torchdata/datapipes/iter/transform/callable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 357604654..6fbf0daf9 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -608,7 +608,7 @@ def __init__( **threadpool_kwargs, ): self.source_datapipe = source_datapipe - self.fn = fn + self.fn = fn # type: ignore[assignment] self.input_col = input_col validate_input_col(fn, input_col) if input_col is None and output_col is not None: @@ -622,7 +622,7 @@ def __init__( def wrapper_fn(args): return fn(*args) - self.fn = wrapper_fn + self.fn = wrapper_fn # type: ignore[assignment] self.output_col = output_col self.max_workers = max_workers self.threadpool_kwargs = threadpool_kwargs From 197fd41b542c8475ff9886e37e536a9952490f5e Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Tue, 28 Feb 2023 14:18:07 +0100 Subject: [PATCH 04/20] Small doc fixes --- torchdata/datapipes/iter/transform/callable.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 6fbf0daf9..b77889f8b 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -539,15 +539,19 @@ class BatchAsyncMapperIterDataPipe(IterDataPipe): async_fn: The coroutine function to be applied to each batch of data batch_size: The size of batch to be aggregated from ``source_datapipe`` input_col: Index or indices of data which ``fn`` is applied, such as: + - ``None`` as default to apply ``fn`` to the data directly. - Integer(s) is used for list/tuple. - Key(s) is used for dict. + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified only when ``input_col`` is not ``None`` + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with multiple indices, the left-most one is used, and other indices will be removed. - Integer is used for list/tuple. ``-1`` represents to append result at the end. - Key is used for dict. New key is acceptable. + max_concurrency: Maximum concurrency to call async functions. (Default value: 32) Example: @@ -630,11 +634,11 @@ def wrapper_fn(args): def __iter__(self): with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: for batch in self.source_datapipe: - prepared_batch = self.preparebatch(batch) + prepared_batch = self._prepare_batch(batch) results = executor.map(self.fn, prepared_batch) yield _merge_batch_with_result(batch, results, self.input_col, self.output_col) - def preparebatch(self, batch): + def _prepare_batch(self, batch): if self.input_col is None: return batch @@ -660,19 +664,25 @@ class BatchThreadPoolMapperIterDataPipe(IterDataPipe): fn: The function to be applied to each element within batch of data batch_size: The size of batch to be aggregated from ``source_datapipe`` input_col: Index or indices of data which ``fn`` is applied, such as: + - ``None`` as default to apply ``fn`` to the data directly. - Integer(s) is used for list/tuple. - Key(s) is used for dict. + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified only when ``input_col`` is not ``None`` + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with multiple indices, the left-most one is used, and other indices will be removed. - Integer is used for list/tuple. ``-1`` represents to append result at the end. - Key is used for dict. New key is acceptable. - max_workers: Maximum num of threads to execute function calls. (Default value: None) + + max_workers: Maximum number of threads to execute function calls. (Default value: None) + **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` Note: - For more information about ``max_workers`` and please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor + For more information about ``max_workers`` and additional arguments for the ``ThreadPoolExecutor`` + please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor Example: From 04f4f974f42cf69e5cffdb9a8ad744324d9a3d66 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 2 Mar 2023 10:34:30 +0100 Subject: [PATCH 05/20] Test: call thread_map instead of map --- test/test_iterdatapipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index ecbb510aa..fcab2665a 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1710,7 +1710,7 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) list(res_dp) else: - res_dp = datapipe.map(fn, input_col, output_col) + res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) ref_dp = datapipe.map(ref_fn) if constr is list: ref_dp = ref_dp.map(list) From ca21fdcd846d46dd54c8355619141a885b8b7988 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 2 Mar 2023 17:58:12 +0100 Subject: [PATCH 06/20] Addressing PR comments --- docs/source/torchdata.datapipes.iter.rst | 2 +- test/test_iterdatapipe.py | 24 +++++++++---------- tools/gen_pyi.py | 2 +- torchdata/datapipes/iter/__init__.py | 2 +- .../datapipes/iter/transform/callable.py | 10 ++++---- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst index 706934886..50275af0a 100644 --- a/docs/source/torchdata.datapipes.iter.rst +++ b/docs/source/torchdata.datapipes.iter.rst @@ -165,9 +165,9 @@ These DataPipes apply the a given function to each element in the DataPipe. BatchAsyncMapper BatchMapper - BatchThreadPoolMapper FlatMapper Mapper + ThreadPoolMapper Other DataPipes ------------------------- diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index fcab2665a..cad0033d4 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1632,13 +1632,13 @@ def fn(item, dtype=torch.float, *, sum=False): return data if not sum else data.sum() # Functional Test: apply to each element correctly - map_dp = input_dp.thread_map_batches(fn, batch_size) + map_dp = input_dp.threadpool_map(fn, batch_size) # self.assertEqual(target_length, len(map_dp)) for x, y in zip(map_dp, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.float)) # Functional Test: works with partial function - map_dp = input_dp.thread_map_batches(partial(fn, dtype=torch.int, sum=True), batch_size) + map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True), batch_size) for x, y in zip(map_dp, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) @@ -1647,7 +1647,7 @@ def fn(item, dtype=torch.float, *, sum=False): # self.assertEqual(target_length, len(map_dp)) input_dp_nl = IDP_NoLen(range(target_length)) - map_dp_nl = input_dp_nl.thread_map_batches((lambda x: x), batch_size) + map_dp_nl = input_dp_nl.threadpool_map((lambda x: x), batch_size) for x, y in zip(map_dp_nl, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.float)) @@ -1657,7 +1657,7 @@ def fn(item, dtype=torch.float, *, sum=False): # len(map_dp_nl) # Test: two independent ThreadPoolExecutors running at the same time - map_dp_parallel = input_dp_parallel.thread_map_batches(fn, batch_size) + map_dp_parallel = input_dp_parallel.threadpool_map(fn, batch_size) for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)): self.assertEqual(x, torch.tensor(z, dtype=torch.float)) self.assertEqual(y, torch.tensor(z, dtype=torch.float)) @@ -1707,16 +1707,16 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) if ref_fn is None: with self.assertRaises(error): - res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) list(res_dp) else: - res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) ref_dp = datapipe.map(ref_fn) if constr is list: ref_dp = ref_dp.map(list) - self.assertEqual(list(res_dp), list(ref_dp)) + self.assertEqual(list(res_dp), list(ref_dp), "First test failed") # Reset - self.assertEqual(list(res_dp), list(ref_dp)) + self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") _helper(lambda data: data, fn_n1_def, 0, 1) _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2) @@ -1825,14 +1825,14 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}]) if ref_fn is None: with self.assertRaises(error): - res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) list(res_dp) else: - res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) ref_dp = datapipe.map(ref_fn) - self.assertEqual(list(res_dp), list(ref_dp)) + self.assertEqual(list(res_dp), list(ref_dp), "First test failed") # Reset - self.assertEqual(list(res_dp), list(ref_dp)) + self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") _helper(lambda data: data, fn_n1_def, "x", "y") _helper(lambda data: data, p_fn_n1, "x", "y") diff --git a/tools/gen_pyi.py b/tools/gen_pyi.py index 29fc54ce9..ab188223c 100644 --- a/tools/gen_pyi.py +++ b/tools/gen_pyi.py @@ -77,7 +77,7 @@ def gen_pyi() -> None: "read_from_xz": "IterDataPipe", "read_from_zip": "IterDataPipe", "round_robin_demux": "List[IterDataPipe]", - "thread_map_batches": "IterDataPipe", + "threadpool_map": "IterDataPipe", "to_map_datapipe": "MapDataPipe", "unzip": "List[IterDataPipe]", } diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 3a6708c0a..40027e8ee 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -68,11 +68,11 @@ from torchdata.datapipes.iter.transform.callable import ( BatchAsyncMapperIterDataPipe as BatchAsyncMapper, BatchMapperIterDataPipe as BatchMapper, - BatchThreadPoolMapperIterDataPipe as BatchThreadPoolMapper, DropperIterDataPipe as Dropper, FlatMapperIterDataPipe as FlatMapper, FlattenIterDataPipe as Flattener, SliceIterDataPipe as Slicer, + ThreadPoolMapperIterDataPipe as ThreadPoolMapper, ) from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader from torchdata.datapipes.iter.util.cacheholder import ( diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index b77889f8b..a88fdca14 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -598,7 +598,7 @@ def __new__( return dp -class _BatchThreadPoolMapperIterDataPipe(IterDataPipe): +class _ThreadPoolMapperIterDataPipe(IterDataPipe): datapipe: IterDataPipe fn: Callable @@ -652,12 +652,12 @@ def _prepare_batch(self, batch): return prepared_batch -@functional_datapipe("thread_map_batches") -class BatchThreadPoolMapperIterDataPipe(IterDataPipe): +@functional_datapipe("threadpool_map") +class ThreadPoolMapperIterDataPipe(IterDataPipe): r""" Combines elements from the source DataPipe to batches and applies a function over each element within the batch concurrently using ``ThreadPoolExecutor``, then flattens the output to a - single, unnested IterDataPipe (functional name: ``thread_map_batches``). + single, unnested IterDataPipe (functional name: ``threadpool_map``). Args: source_datapipe: Source IterDataPipe @@ -752,6 +752,6 @@ def __new__( **threadpool_kwargs, ): dp = source_datapipe.batch(batch_size) - dp = _BatchThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) + dp = _ThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) dp = dp.flatmap() return dp From 9a042e35183488b4fe04626a23714ae7f83aefc8 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 2 Mar 2023 18:15:24 +0100 Subject: [PATCH 07/20] Fix doctest --- torchdata/datapipes/iter/transform/callable.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index a88fdca14..d8568ca3f 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -702,7 +702,7 @@ def mul_ten(x): time.sleep(0.1) return x * 10 dp = IterableWrapper(range(50)) - dp = dp.thread_map_batches(mul_ten, 16) + dp = dp.threadpool_map(mul_ten, 16) print(list(dp)) .. testoutput:: @@ -712,7 +712,7 @@ def mul_ten(x): .. testcode:: dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.thread_map_batches(mul_ten, 16, input_col=1) + dp = dp.threadpool_map(mul_ten, 16, input_col=1) print(list(dp)) .. testoutput:: @@ -722,7 +722,7 @@ def mul_ten(x): .. testcode:: dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.thread_map_batches(mul_ten, 16, input_col=1, output_col=-1) + dp = dp.threadpool_map(mul_ten, 16, input_col=1, output_col=-1) print(list(dp)) .. testoutput:: @@ -737,7 +737,7 @@ def fetch_html(url: str, **kwargs): r.raise_for_status() return r.content dp = IterableWrapper(urls) - dp = dp.thread_map_batches(fetch_html, 16) + dp = dp.threadpool_map(fetch_html, 16) """ From d79451c9e754d51f5bbb6ffdd7122eadd882efd4 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 2 Mar 2023 20:26:04 +0100 Subject: [PATCH 08/20] Change test names --- test/test_iterdatapipe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index cad0033d4..fe44227c0 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1621,7 +1621,7 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_ self.assertEqual(v1, exp) self.assertEqual(v2, exp) - def test_threadpool_map_batches(self): + def test_threadpool_map(self): batch_size = 16 target_length = 30 input_dp = IterableWrapper(range(target_length)) @@ -1669,7 +1669,7 @@ def fn(item, dtype=torch.float, *, sum=False): self.assertEqual(list(range(target_length)), res_after_reset) @suppress_warnings # Suppress warning for lambda fn - def test_threadpool_map_batches_tuple_list_with_col_iterdatapipe(self): + def test_threadpool_map_tuple_list_with_col_iterdatapipe(self): batch_size = 3 def fn_11(d): @@ -1775,7 +1775,7 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) @suppress_warnings # Suppress warning for lambda fn - def test_threadpool_map_batches_dict_with_col_iterdatapipe(self): + def test_threadpool_map_dict_with_col_iterdatapipe(self): batch_size = 3 def fn_11(d): From 1e62f3343bfdc10d1d117ae6593eafe6a8479915 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Fri, 3 Mar 2023 11:07:58 +0100 Subject: [PATCH 09/20] Add prefetch to ThreadMapper --- torchdata/datapipes/iter/transform/callable.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index d8568ca3f..57eb4e59f 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -753,5 +753,6 @@ def __new__( ): dp = source_datapipe.batch(batch_size) dp = _ThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) + dp = dp.prefetch(buffer_size=2) # start working on the next batch before previous batch is exhausted dp = dp.flatmap() return dp From c948a779bee9007ef422a128ec3fd61f57db012b Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 11:31:09 +0100 Subject: [PATCH 10/20] No longer use batches in ThreadPoolMap --- test/test_iterdatapipe.py | 31 +-- .../datapipes/iter/transform/callable.py | 241 ++++++++++-------- 2 files changed, 140 insertions(+), 132 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index fe44227c0..d0df5caf4 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1622,7 +1622,6 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_ self.assertEqual(v2, exp) def test_threadpool_map(self): - batch_size = 16 target_length = 30 input_dp = IterableWrapper(range(target_length)) input_dp_parallel = IterableWrapper(range(target_length)) @@ -1632,32 +1631,30 @@ def fn(item, dtype=torch.float, *, sum=False): return data if not sum else data.sum() # Functional Test: apply to each element correctly - map_dp = input_dp.threadpool_map(fn, batch_size) - # self.assertEqual(target_length, len(map_dp)) + map_dp = input_dp.threadpool_map(fn) + self.assertEqual(target_length, len(map_dp)) for x, y in zip(map_dp, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.float)) # Functional Test: works with partial function - map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True), batch_size) + map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True)) for x, y in zip(map_dp, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) - # this doesn't work atm # __len__ Test: inherits length from source DataPipe - # self.assertEqual(target_length, len(map_dp)) + self.assertEqual(target_length, len(map_dp)) input_dp_nl = IDP_NoLen(range(target_length)) - map_dp_nl = input_dp_nl.threadpool_map((lambda x: x), batch_size) + map_dp_nl = input_dp_nl.threadpool_map(lambda x: x) for x, y in zip(map_dp_nl, range(target_length)): self.assertEqual(x, torch.tensor(y, dtype=torch.float)) - # this doesn't work atm # __len__ Test: inherits length from source DataPipe - raises error when invalid - # with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): - # len(map_dp_nl) + with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): + len(map_dp_nl) # Test: two independent ThreadPoolExecutors running at the same time - map_dp_parallel = input_dp_parallel.threadpool_map(fn, batch_size) + map_dp_parallel = input_dp_parallel.threadpool_map(fn) for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)): self.assertEqual(x, torch.tensor(z, dtype=torch.float)) self.assertEqual(y, torch.tensor(z, dtype=torch.float)) @@ -1670,8 +1667,6 @@ def fn(item, dtype=torch.float, *, sum=False): @suppress_warnings # Suppress warning for lambda fn def test_threadpool_map_tuple_list_with_col_iterdatapipe(self): - batch_size = 3 - def fn_11(d): return -d @@ -1707,10 +1702,10 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) if ref_fn is None: with self.assertRaises(error): - res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, input_col, output_col) list(res_dp) else: - res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, input_col, output_col) ref_dp = datapipe.map(ref_fn) if constr is list: ref_dp = ref_dp.map(list) @@ -1776,8 +1771,6 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): @suppress_warnings # Suppress warning for lambda fn def test_threadpool_map_dict_with_col_iterdatapipe(self): - batch_size = 3 - def fn_11(d): return -d @@ -1825,10 +1818,10 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}]) if ref_fn is None: with self.assertRaises(error): - res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, input_col, output_col) list(res_dp) else: - res_dp = datapipe.threadpool_map(fn, batch_size, input_col, output_col) + res_dp = datapipe.threadpool_map(fn, input_col, output_col) ref_dp = datapipe.map(ref_fn) self.assertEqual(list(res_dp), list(ref_dp), "First test failed") # Reset diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 57eb4e59f..dd598de5e 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -7,6 +7,7 @@ import asyncio import inspect import warnings +from collections import deque from concurrent import futures from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union @@ -27,35 +28,6 @@ def _no_op_fn(*args): return args -def _merge_batch_with_result(orig_batch, results, input_col, output_col): - if input_col is None: - return results - - new_batch = [] - for data, res in zip(orig_batch, results): - t_flag = isinstance(data, tuple) - if t_flag: - data = list(data) - - if output_col is None: - if isinstance(input_col, (list, tuple)): - data[input_col[0]] = res - for idx in sorted(input_col[1:], reverse=True): - del data[idx] - else: - data[input_col] = res - elif output_col == -1: - data.append(res) - else: - data[output_col] = res - - if t_flag: - data = tuple(data) - - new_batch.append(data) - return new_batch - - @functional_datapipe("map_batches") class BatchMapperIterDataPipe(IterDataPipe[T_co]): r""" @@ -524,7 +496,30 @@ async def controlled_async_fn(async_fn, *data): else: coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col])) results = await asyncio.gather(*coroutines) - return _merge_batch_with_result(batch, results, self.input_col, self.output_col) + + new_batch = [] + for data, res in zip(batch, results): + t_flag = isinstance(data, tuple) + if t_flag: + data = list(data) + + if self.output_col is None: + if isinstance(self.input_col, (list, tuple)): + data[self.input_col[0]] = res + for idx in sorted(self.input_col[1:], reverse=True): + del data[idx] + else: + data[self.input_col] = res + elif self.output_col == -1: + data.append(res) + else: + data[self.output_col] = res + + if t_flag: + data = tuple(data) + + new_batch.append(data) + return new_batch @functional_datapipe("async_map_batches") @@ -598,60 +593,6 @@ def __new__( return dp -class _ThreadPoolMapperIterDataPipe(IterDataPipe): - datapipe: IterDataPipe - fn: Callable - - def __init__( - self, - source_datapipe: IterDataPipe, - fn: Callable, - input_col=None, - output_col=None, - max_workers: Optional[int] = None, - **threadpool_kwargs, - ): - self.source_datapipe = source_datapipe - self.fn = fn # type: ignore[assignment] - self.input_col = input_col - validate_input_col(fn, input_col) - if input_col is None and output_col is not None: - raise ValueError("`output_col` must be None when `input_col` is None.") - if isinstance(output_col, (list, tuple)): - if len(output_col) > 1: - raise ValueError("`output_col` must be a single-element list or tuple") - output_col = output_col[0] - if isinstance(self.input_col, (list, tuple)): - - def wrapper_fn(args): - return fn(*args) - - self.fn = wrapper_fn # type: ignore[assignment] - self.output_col = output_col - self.max_workers = max_workers - self.threadpool_kwargs = threadpool_kwargs - - def __iter__(self): - with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: - for batch in self.source_datapipe: - prepared_batch = self._prepare_batch(batch) - results = executor.map(self.fn, prepared_batch) - yield _merge_batch_with_result(batch, results, self.input_col, self.output_col) - - def _prepare_batch(self, batch): - if self.input_col is None: - return batch - - prepared_batch = [] - for data in batch: - if isinstance(self.input_col, (list, tuple)): - args = tuple(data[col] for col in self.input_col) - prepared_batch.append(args) - else: - prepared_batch.append(data[self.input_col]) - return prepared_batch - - @functional_datapipe("threadpool_map") class ThreadPoolMapperIterDataPipe(IterDataPipe): r""" @@ -662,7 +603,6 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe): Args: source_datapipe: Source IterDataPipe fn: The function to be applied to each element within batch of data - batch_size: The size of batch to be aggregated from ``source_datapipe`` input_col: Index or indices of data which ``fn`` is applied, such as: - ``None`` as default to apply ``fn`` to the data directly. @@ -677,6 +617,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe): - Integer is used for list/tuple. ``-1`` represents to append result at the end. - Key is used for dict. New key is acceptable. + scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 64) max_workers: Maximum number of threads to execute function calls. (Default value: None) **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` @@ -684,6 +625,10 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe): For more information about ``max_workers`` and additional arguments for the ``ThreadPoolExecutor`` please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor + Note: + For optimal use of all threads, we recommend ``scheduled_tasks`` > ``max_workers``. High value of ``scheduled_tasks`` + might lead to long waiting period until the first element is yielded. + Example: .. testsetup:: @@ -696,20 +641,21 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe): requests.get = MagicMock() urls = [] + .. testcode:: + + # fetching html from remote + def fetch_html(url: str, **kwargs): + r = requests.get(url, **kwargs) + r.raise_for_status() + return r.content + dp = IterableWrapper(urls) + dp = dp.threadpool_map(fetch_html, batch_size=32,max_workers=16) + .. testcode:: def mul_ten(x): time.sleep(0.1) return x * 10 - dp = IterableWrapper(range(50)) - dp = dp.threadpool_map(mul_ten, 16) - print(list(dp)) - - .. testoutput:: - - [0, 10, 20, 30, ...] - - .. testcode:: dp = IterableWrapper([(i, i) for i in range(50)]) dp = dp.threadpool_map(mul_ten, 16, input_col=1) @@ -729,30 +675,99 @@ def mul_ten(x): [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...] - .. testcode:: - - # fetching html from remote - def fetch_html(url: str, **kwargs): - r = requests.get(url, **kwargs) - r.raise_for_status() - return r.content - dp = IterableWrapper(urls) - dp = dp.threadpool_map(fetch_html, 16) - """ - def __new__( - cls, - source_datapipe, + datapipe: IterDataPipe + fn: Callable + + def __init__( + self, + datapipe: IterDataPipe, fn: Callable, - batch_size: int, input_col=None, output_col=None, + scheduled_tasks: int = 64, max_workers: Optional[int] = None, **threadpool_kwargs, ): - dp = source_datapipe.batch(batch_size) - dp = _ThreadPoolMapperIterDataPipe(dp, fn, input_col, output_col, max_workers, **threadpool_kwargs) - dp = dp.prefetch(buffer_size=2) # start working on the next batch before previous batch is exhausted - dp = dp.flatmap() - return dp + self.datapipe = datapipe + + _check_unpickable_fn(fn) + self.fn = fn # type: ignore[assignment] + + if scheduled_tasks <= 0: + raise ValueError("'scheduled_tasks' is required to be a positive integer.") + self.scheduled_tasks = scheduled_tasks + if max_workers is not None and max_workers <= 0: + raise ValueError("'max_workers' is required to be a positive integer.") + self.max_workers = max_workers + self.threadpool_kwargs = threadpool_kwargs + + self.input_col = input_col + if input_col is None and output_col is not None: + raise ValueError("`output_col` must be None when `input_col` is None.") + if isinstance(output_col, (list, tuple)): + if len(output_col) > 1: + raise ValueError("`output_col` must be a single-element list or tuple") + output_col = output_col[0] + self.output_col = output_col + validate_input_col(fn, input_col) + + def _apply_fn(self, data): + if self.input_col is None and self.output_col is None: + return self.fn(data) + + if self.input_col is None: + res = self.fn(data) + elif isinstance(self.input_col, (list, tuple)): + args = tuple(data[col] for col in self.input_col) + res = self.fn(*args) + else: + res = self.fn(data[self.input_col]) + + # Copy tuple to list and run in-place modification because tuple is immutable. + if isinstance(data, tuple): + t_flag = True + data = list(data) + else: + t_flag = False + + if self.output_col is None: + if isinstance(self.input_col, (list, tuple)): + data[self.input_col[0]] = res + for idx in sorted(self.input_col[1:], reverse=True): + del data[idx] + else: + data[self.input_col] = res + else: + if self.output_col == -1: + data.append(res) + else: + data[self.output_col] = res + + # Convert list back to tuple + return tuple(data) if t_flag else data + + def __iter__(self): + with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: + futures_deque: deque = deque() + has_next = True + itr = iter(self.datapipe) + for _ in range(self.scheduled_tasks): + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + break + while len(futures_deque) > 0: + if has_next: + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + yield futures_deque.popleft().result() + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") From 4dbc1b58edc7a5f640844874ca517bd47871ff22 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 11:38:20 +0100 Subject: [PATCH 11/20] Fix doctest --- torchdata/datapipes/iter/transform/callable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index dd598de5e..63c1739c8 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -649,7 +649,7 @@ def fetch_html(url: str, **kwargs): r.raise_for_status() return r.content dp = IterableWrapper(urls) - dp = dp.threadpool_map(fetch_html, batch_size=32,max_workers=16) + dp = dp.threadpool_map(fetch_html,max_workers=16) .. testcode:: @@ -658,7 +658,7 @@ def mul_ten(x): return x * 10 dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.threadpool_map(mul_ten, 16, input_col=1) + dp = dp.threadpool_map(mul_ten, input_col=1) print(list(dp)) .. testoutput:: @@ -668,7 +668,7 @@ def mul_ten(x): .. testcode:: dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.threadpool_map(mul_ten, 16, input_col=1, output_col=-1) + dp = dp.threadpool_map(mul_ten, input_col=1, output_col=-1) print(list(dp)) .. testoutput:: From e3d5532c81856bf1cbb433e45f4cb3c23608f6c5 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 11:44:29 +0100 Subject: [PATCH 12/20] Fix documentation --- torchdata/datapipes/iter/transform/callable.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 63c1739c8..c9f35ae25 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -596,13 +596,14 @@ def __new__( @functional_datapipe("threadpool_map") class ThreadPoolMapperIterDataPipe(IterDataPipe): r""" - Combines elements from the source DataPipe to batches and applies a function - over each element within the batch concurrently using ``ThreadPoolExecutor``, then flattens the output to a - single, unnested IterDataPipe (functional name: ``threadpool_map``). + Applies a function over each item from the source DataPipe concurrently + using ``ThreadPoolExecutor`` (functional name: ``threadpool_map``). + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. Args: source_datapipe: Source IterDataPipe - fn: The function to be applied to each element within batch of data + fn: Function being applied over each item input_col: Index or indices of data which ``fn`` is applied, such as: - ``None`` as default to apply ``fn`` to the data directly. @@ -627,7 +628,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe): Note: For optimal use of all threads, we recommend ``scheduled_tasks`` > ``max_workers``. High value of ``scheduled_tasks`` - might lead to long waiting period until the first element is yielded. + might lead to long waiting period until the first element is yielded as tasks are executed out of order. Example: From f68328d1ba720ff3ebc8c48f105956646c2b7c35 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 12:48:07 +0100 Subject: [PATCH 13/20] Add serialization test --- test/test_serialization.py | 2 ++ tools/gen_pyi.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index a837a3948..b891ed4f8 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -297,6 +297,7 @@ def test_serializable(self): (iterdp.TarArchiveLoader, None, (), {}), # TODO(594): Add serialization tests for optional DataPipe # (iterdp.TFRecordLoader, None, (), {}), + (iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}), (iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}), (iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}), (iterdp.XzFileLoader, None, (), {}), @@ -366,6 +367,7 @@ def test_serializable_with_dill(self): (iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}), (iterdp.OnDiskCacheHolder, (lambda x: x,), {}), (iterdp.ParagraphAggregator, (lambda x: x,), {}), + (iterdp.ThreadPoolMapper, (lambda x: x,), {}), ] # Skipping value comparison for these DataPipes dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator} diff --git a/tools/gen_pyi.py b/tools/gen_pyi.py index ab188223c..7e63ce126 100644 --- a/tools/gen_pyi.py +++ b/tools/gen_pyi.py @@ -77,7 +77,6 @@ def gen_pyi() -> None: "read_from_xz": "IterDataPipe", "read_from_zip": "IterDataPipe", "round_robin_demux": "List[IterDataPipe]", - "threadpool_map": "IterDataPipe", "to_map_datapipe": "MapDataPipe", "unzip": "List[IterDataPipe]", } From 9592c5d9cd655a96985dc169de67c4376ee52151 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 13:04:25 +0100 Subject: [PATCH 14/20] add to init + type fix --- torchdata/datapipes/iter/__init__.py | 1 + torchdata/datapipes/iter/transform/callable.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 40027e8ee..22f038f45 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -214,6 +214,7 @@ "StreamReader", "TFRecordLoader", "TarArchiveLoader", + "ThreadPoolMapper", "UnBatcher", "UnZipper", "WebDataset", diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index c9f35ae25..53301ddab 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -594,7 +594,7 @@ def __new__( @functional_datapipe("threadpool_map") -class ThreadPoolMapperIterDataPipe(IterDataPipe): +class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): r""" Applies a function over each item from the source DataPipe concurrently using ``ThreadPoolExecutor`` (functional name: ``threadpool_map``). From b09c53bd5c349c39a4daa484512da808d31033cb Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sat, 4 Mar 2023 13:20:22 +0100 Subject: [PATCH 15/20] pls help in fixing domain test issue --- torchdata/datapipes/iter/transform/callable.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 53301ddab..c6f395a63 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -690,7 +690,8 @@ def __init__( scheduled_tasks: int = 64, max_workers: Optional[int] = None, **threadpool_kwargs, - ): + ) -> None: + super().__init__() self.datapipe = datapipe _check_unpickable_fn(fn) @@ -749,7 +750,7 @@ def _apply_fn(self, data): # Convert list back to tuple return tuple(data) if t_flag else data - def __iter__(self): + def __iter__(self) -> Iterator[T_co]: with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: futures_deque: deque = deque() has_next = True From 87c3bf81c05d2a7f9027268ec2a99df1fa6145a0 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sun, 5 Mar 2023 00:19:28 +0100 Subject: [PATCH 16/20] improve time to first element by hiding yield --- .../datapipes/iter/transform/callable.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index c6f395a63..3e2d1b85c 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -751,23 +751,31 @@ def _apply_fn(self, data): return tuple(data) if t_flag else data def __iter__(self) -> Iterator[T_co]: - with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: - futures_deque: deque = deque() - has_next = True - itr = iter(self.datapipe) - for _ in range(self.scheduled_tasks): - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - break + executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) + futures_deque: deque = deque() + has_next = True + itr = iter(self.datapipe) + for _ in range(self.scheduled_tasks): + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + break + + # Yield must be hidden in closure so that the futures are submitted + # before the first iterator value is required. + def result_iterator(executor): while len(futures_deque) > 0: + nonlocal has_next if has_next: try: futures_deque.append(executor.submit(self._apply_fn, next(itr))) except StopIteration: has_next = False yield futures_deque.popleft().result() + executor.shutdown() + + return result_iterator(executor) def __len__(self) -> int: if isinstance(self.datapipe, Sized): From e0e1111324b3216ccfb76886d2fce6fee27d09c6 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Sun, 5 Mar 2023 01:03:59 +0100 Subject: [PATCH 17/20] Small doc fixes --- torchdata/datapipes/iter/transform/callable.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 3e2d1b85c..b7493c2f4 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -619,7 +619,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): - Key is used for dict. New key is acceptable. scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 64) - max_workers: Maximum number of threads to execute function calls. (Default value: None) + max_workers: Maximum number of threads to execute function calls **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` Note: @@ -628,7 +628,8 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): Note: For optimal use of all threads, we recommend ``scheduled_tasks`` > ``max_workers``. High value of ``scheduled_tasks`` - might lead to long waiting period until the first element is yielded as tasks are executed out of order. + might lead to long waiting period until the first element is yielded as ``next`` is called + ``scheduled_tasks`` many times on ``source_datapipe`` before yielding. Example: @@ -683,7 +684,7 @@ def mul_ten(x): def __init__( self, - datapipe: IterDataPipe, + source_datapipe: IterDataPipe, fn: Callable, input_col=None, output_col=None, @@ -692,7 +693,7 @@ def __init__( **threadpool_kwargs, ) -> None: super().__init__() - self.datapipe = datapipe + self.datapipe = source_datapipe _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] From 1126b0eb4d5cfef8dae20e70fdbaab1a3ef60725 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Mon, 6 Mar 2023 09:41:15 +0100 Subject: [PATCH 18/20] Ensure shutdown of executor --- .../datapipes/iter/transform/callable.py | 54 ++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index b7493c2f4..efec4a060 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -752,31 +752,37 @@ def _apply_fn(self, data): return tuple(data) if t_flag else data def __iter__(self) -> Iterator[T_co]: - executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) - futures_deque: deque = deque() - has_next = True - itr = iter(self.datapipe) - for _ in range(self.scheduled_tasks): - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - break - - # Yield must be hidden in closure so that the futures are submitted - # before the first iterator value is required. - def result_iterator(executor): - while len(futures_deque) > 0: - nonlocal has_next - if has_next: - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - yield futures_deque.popleft().result() + try: + executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) + futures_deque: deque = deque() + has_next = True + itr = iter(self.datapipe) + for _ in range(self.scheduled_tasks): + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + break + + # Yield must be hidden in closure so that the futures are submitted + # before the first iterator value is required. + def result_iterator(): + try: + while len(futures_deque) > 0: + nonlocal has_next + if has_next: + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + yield futures_deque.popleft().result() + finally: + executor.shutdown() + + return result_iterator() + except Exception: executor.shutdown() - - return result_iterator(executor) + raise def __len__(self) -> int: if isinstance(self.datapipe, Sized): From da1a27b9f231e5bd5ec519d1657f7ef5b4b792e0 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 9 Mar 2023 14:44:02 +0100 Subject: [PATCH 19/20] Do not hide yield --- .../datapipes/iter/transform/callable.py | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index efec4a060..0d44b47b6 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -752,8 +752,7 @@ def _apply_fn(self, data): return tuple(data) if t_flag else data def __iter__(self) -> Iterator[T_co]: - try: - executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) + with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: futures_deque: deque = deque() has_next = True itr = iter(self.datapipe) @@ -764,25 +763,13 @@ def __iter__(self) -> Iterator[T_co]: has_next = False break - # Yield must be hidden in closure so that the futures are submitted - # before the first iterator value is required. - def result_iterator(): - try: - while len(futures_deque) > 0: - nonlocal has_next - if has_next: - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - yield futures_deque.popleft().result() - finally: - executor.shutdown() - - return result_iterator() - except Exception: - executor.shutdown() - raise + while len(futures_deque) > 0: + if has_next: + try: + futures_deque.append(executor.submit(self._apply_fn, next(itr))) + except StopIteration: + has_next = False + yield futures_deque.popleft().result() def __len__(self) -> int: if isinstance(self.datapipe, Sized): From 7ac9015e8bc56714fb95da78fdea2f3057f8cd02 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Mon, 13 Mar 2023 17:47:11 +0100 Subject: [PATCH 20/20] Add suggestion to documentation --- torchdata/datapipes/iter/transform/callable.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 0d44b47b6..23c58a011 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -618,7 +618,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): - Integer is used for list/tuple. ``-1`` represents to append result at the end. - Key is used for dict. New key is acceptable. - scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 64) + scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 128) max_workers: Maximum number of threads to execute function calls **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` @@ -627,9 +627,16 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor Note: - For optimal use of all threads, we recommend ``scheduled_tasks`` > ``max_workers``. High value of ``scheduled_tasks`` - might lead to long waiting period until the first element is yielded as ``next`` is called - ``scheduled_tasks`` many times on ``source_datapipe`` before yielding. + For optimal use of all threads, ``scheduled_tasks`` > ``max_workers`` is strongly recommended. The higher the + variance of the time needed to finish execution of the given ``fn`` is, the higher the value + of ``scheduled_tasks`` needs to be to avoid threads sitting idle while waiting + for the next result (as results are returned in correct order). + + However, too high value of ``scheduled_tasks`` might lead to long waiting period until the first element is yielded + as ``next`` is called ``scheduled_tasks`` many times on ``source_datapipe`` before yielding. + + We encourage you to try out different values of ``max_workers`` and ``scheduled_tasks`` + in search for optimal values for your use-case. Example: @@ -688,7 +695,7 @@ def __init__( fn: Callable, input_col=None, output_col=None, - scheduled_tasks: int = 64, + scheduled_tasks: int = 128, max_workers: Optional[int] = None, **threadpool_kwargs, ) -> None: