Add ThreadPoolMapper #1052

Closed · wants to merge 20 commits
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
@@ -167,6 +167,7 @@ These DataPipes apply a given function to each element in the DataPipe.
BatchMapper
FlatMapper
Mapper
ThreadPoolMapper

Other DataPipes
-------------------------
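The docs entry above only lists the new name; as a quick orientation, here is a minimal usage sketch based on the functional form exercised by the tests below (the worker-count argument is omitted, since its name is not shown in this excerpt):

from torchdata.datapipes.iter import IterableWrapper

# threadpool_map applies fn over the source DataPipe using a pool of threads,
# which helps when fn is I/O-bound (network requests, disk reads).
source_dp = IterableWrapper(range(10))
mapped_dp = source_dp.threadpool_map(lambda x: x * 10)
print(list(mapped_dp))  # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]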
275 changes: 272 additions & 3 deletions test/test_iterdatapipe.py
@@ -12,6 +12,7 @@
import warnings

from collections import defaultdict
from functools import partial
from typing import Dict

import expecttest
@@ -20,6 +21,7 @@
import torchdata

from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
from torch.testing._internal.common_utils import suppress_warnings

from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
from torchdata.datapipes.iter import (
@@ -80,12 +82,12 @@ def _convert_to_tensor(data):


async def _async_mul_ten(x):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
return x * 10


async def _async_x_mul_y(x, y):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
return x * y


@@ -289,7 +291,7 @@ def odd_even_bug(i: int) -> int:
self.assertEqual(len(source_dp), len(result_dp))

def test_prefetcher_iterdatapipe(self) -> None:
-        source_dp = IterableWrapper(range(50000))
+        source_dp = IterableWrapper(range(5000))
prefetched_dp = source_dp.prefetch(10)
# check if early termination resets child thread properly
for _, _ in zip(range(100), prefetched_dp):
@@ -1619,6 +1621,273 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_
self.assertEqual(v1, exp)
self.assertEqual(v2, exp)

def test_threadpool_map(self):
target_length = 30
input_dp = IterableWrapper(range(target_length))
input_dp_parallel = IterableWrapper(range(target_length))

def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)
return data if not sum else data.sum()

# Functional Test: apply to each element correctly
map_dp = input_dp.threadpool_map(fn)
self.assertEqual(target_length, len(map_dp))
for x, y in zip(map_dp, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

# Functional Test: works with partial function
map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True))
for x, y in zip(map_dp, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

# __len__ Test: inherits length from source DataPipe
self.assertEqual(target_length, len(map_dp))

input_dp_nl = IDP_NoLen(range(target_length))
map_dp_nl = input_dp_nl.threadpool_map(lambda x: x)
for x, y in zip(map_dp_nl, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

# __len__ Test: inherits length from source DataPipe - raises error when invalid
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(map_dp_nl)

# Test: two independent ThreadPoolExecutors running at the same time
map_dp_parallel = input_dp_parallel.threadpool_map(fn)
for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
self.assertEqual(x, torch.tensor(z, dtype=torch.float))
self.assertEqual(y, torch.tensor(z, dtype=torch.float))

# Reset Test: DataPipe resets properly
n_elements_before_reset = 5
res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
self.assertEqual(list(range(target_length)), res_after_reset)

@suppress_warnings # Suppress warning for lambda fn
def test_threadpool_map_tuple_list_with_col_iterdatapipe(self):
def fn_11(d):
return -d

def fn_1n(d):
return -d, d

def fn_n1(d0, d1):
return d0 + d1

def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_sep_pos(d0, *args, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)
p_fn_cmplx = partial(fn_cmplx, d2=2)

def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
if constr is list:
ref_dp = ref_dp.map(list)
self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
# Reset
self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")

_helper(lambda data: data, fn_n1_def, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
_helper(lambda data: data, p_fn_n1, 0, 1)
_helper(lambda data: data, p_fn_cmplx, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
_helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])

# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
# The index of input column is out of range
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
_helper(None, fn_cmplx, 0, 1, ValueError)
_helper(None, fn_n1_pos, 1, error=ValueError)
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
_helper(None, p_fn_n1, [0, 1], error=ValueError)
_helper(None, fn_1n, [1, 2], error=ValueError)
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
# Fn has keyword-only arguments
_helper(None, fn_n1_kwargs, 1, error=ValueError)
_helper(None, fn_cmplx, [0, 1], 2, ValueError)

# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
_helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

# output_col can only be specified when input_col is not None
_helper(None, fn_n1, None, 1, error=ValueError)
# output_col can only be single-element list or tuple
_helper(None, fn_n1, None, [0, 1], error=ValueError)
# Single-element list as output_col
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
# Replacing with one input column and single specified output column
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
# The index of output column is out of range
_helper(None, fn_1n, 1, 3, error=IndexError)
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

# Appending the output at the end
_helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
_helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
_helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
_helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)

# Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
_helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
_helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
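
Restating the column semantics these cases exercise in direct form (same API as above; values mirror the test fixtures):

from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper([(0, 1, 2), (3, 4, 5)])

# input_col=1: fn receives element[1]; its result replaces that column.
neg_dp = dp.threadpool_map(lambda d: -d, 1)
# -> (0, -1, 2), (3, -4, 5)

# input_col=[0, 2], output_col=-1: fn receives elements 0 and 2 as two
# arguments; the result is appended at the end of each tuple.
sum_dp = dp.threadpool_map(lambda d0, d1: d0 + d1, [0, 2], -1)
# -> (0, 1, 2, 2), (3, 4, 5, 8)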

@suppress_warnings # Suppress warning for lambda fn
def test_threadpool_map_dict_with_col_iterdatapipe(self):
def fn_11(d):
return -d

def fn_1n(d):
return -d, d

def fn_n1(d0, d1):
return d0 + d1

def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_kwonly(*, d0, d1):
return d0 + d1

def fn_has_nondefault_kwonly(d0, *, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_cmplx = partial(fn_cmplx, d2=2)

# Prevent modification in-place to support resetting
def _dict_update(data, newdata, remove_idx=None):
_data = dict(data)
_data.update(newdata)
if remove_idx:
for idx in remove_idx:
del _data[idx]
return _data

def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
# Reset
self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")

_helper(lambda data: data, fn_n1_def, "x", "y")
_helper(lambda data: data, p_fn_n1, "x", "y")
_helper(lambda data: data, p_fn_cmplx, "x", "y")
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")

_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")

_helper(None, fn_n1_pos, "x", error=ValueError)
_helper(None, fn_n1_kwargs, "x", error=ValueError)
# non-default kw-only args
_helper(None, fn_kwonly, ["x", "y"], error=ValueError)
_helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
_helper(None, fn_cmplx, ["x", "y"], error=ValueError)

# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
# The key of input column is not in dict
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, "y", error=ValueError)
_helper(None, fn_1n, ["x", "y"], error=ValueError)
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
_helper(
lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
fn_nn,
["z", "y"],
)

# output_col can only be specified when input_col is not None
_helper(None, fn_n1, None, "x", error=ValueError)
# output_col can only be single-element list or tuple
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
# Single-element list as output_col
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
# Replacing with one input column and single specified output column
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
_helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
_helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
_helper(
lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
fn_nn,
["y", "z"],
"x",
)

# Adding new key to dict for the output
_helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
_helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
_helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
_helper(
lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
fn_nn,
["y", "z"],
"a",
)


if __name__ == "__main__":
unittest.main()
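
The ThreadPoolMapper implementation itself is outside this excerpt; purely as a mental model for the behavior these tests assert (thread-based mapping that preserves input order), a rough approximation in terms of concurrent.futures might look as follows — an illustrative sketch, not the PR's actual code:

from concurrent.futures import ThreadPoolExecutor

def threadpool_map_sketch(iterable, fn, max_workers=10):
    # Executor.map fans fn out across worker threads but yields results in
    # input order, matching the ordering the tests above rely on.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        yield from pool.map(fn, iterable)

# list(threadpool_map_sketch(range(5), lambda x: x * 10))  # [0, 10, 20, 30, 40]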
2 changes: 2 additions & 0 deletions test/test_serialization.py
@@ -297,6 +297,7 @@ def test_serializable(self):
(iterdp.TarArchiveLoader, None, (), {}),
# TODO(594): Add serialization tests for optional DataPipe
# (iterdp.TFRecordLoader, None, (), {}),
(iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}),
(iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
(iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}),
(iterdp.XzFileLoader, None, (), {}),
Expand Down Expand Up @@ -366,6 +367,7 @@ def test_serializable_with_dill(self):
(iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}),
(iterdp.OnDiskCacheHolder, (lambda x: x,), {}),
(iterdp.ParagraphAggregator, (lambda x: x,), {}),
(iterdp.ThreadPoolMapper, (lambda x: x,), {}),
]
# Skipping value comparison for these DataPipes
dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator}
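A note on the split above: the plain-pickle list uses the named helper _fake_fn_ls, while the lambda-based case is confined to the dill test, because standard pickle cannot serialize lambdas. A generic illustration, independent of this PR's code:

import pickle

def named_fn(x):
    return x

pickle.dumps(named_fn)  # fine: resolvable by its module-level name
try:
    pickle.dumps(lambda x: x)  # fails: lambdas have no importable name
except Exception as e:
    print(type(e).__name__)  # PicklingError at module scope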
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
@@ -72,6 +72,7 @@
FlatMapperIterDataPipe as FlatMapper,
FlattenIterDataPipe as Flattener,
SliceIterDataPipe as Slicer,
ThreadPoolMapperIterDataPipe as ThreadPoolMapper,
)
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
from torchdata.datapipes.iter.util.cacheholder import (
@@ -213,6 +214,7 @@
"StreamReader",
"TFRecordLoader",
"TarArchiveLoader",
"ThreadPoolMapper",
"UnBatcher",
"UnZipper",
"WebDataset",