Add ThreadPoolMapper (#1052)
Summary:
Fixes #1045

### Changes

- Add ThreadPoolMapper datapipe
- Add tests

Pull Request resolved: #1052

Reviewed By: NivekT

Differential Revision: D44033894

Pulled By: ejguan

fbshipit-source-id: 608b70a857e4610fc6616a53711c706207ce696a
SvenDS9 authored and facebook-github-bot committed Mar 14, 2023
1 parent f1283eb commit fea20d4
Showing 5 changed files with 474 additions and 3 deletions.
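
As context for the diffs that follow, here is a minimal usage sketch of the new DataPipe. It is based solely on the functional form `threadpool_map` exercised in the tests below; treat it as illustrative rather than authoritative.

# A minimal sketch, assuming only the behavior asserted by the new tests:
# fn runs in a ThreadPoolExecutor and results keep the source order.
from torchdata.datapipes.iter import IterableWrapper

def mul_ten(x):
    return x * 10

source_dp = IterableWrapper(range(10))
mapped_dp = source_dp.threadpool_map(mul_ten)
assert list(mapped_dp) == [x * 10 for x in range(10)]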
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
@@ -167,6 +167,7 @@ These DataPipes apply the given function to each element in the DataPipe.
     BatchMapper
     FlatMapper
     Mapper
+    ThreadPoolMapper

Other DataPipes
-------------------------
275 changes: 272 additions & 3 deletions test/test_iterdatapipe.py
@@ -12,6 +12,7 @@
import warnings

from collections import defaultdict
+from functools import partial
from typing import Dict

import expecttest
@@ -20,6 +21,7 @@
import torchdata

from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
+from torch.testing._internal.common_utils import suppress_warnings

from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
from torchdata.datapipes.iter import (
@@ -80,12 +82,12 @@ def _convert_to_tensor(data):


async def _async_mul_ten(x):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
    return x * 10


async def _async_x_mul_y(x, y):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
    return x * y


@@ -289,7 +291,7 @@ def odd_even_bug(i: int) -> int:
        self.assertEqual(len(source_dp), len(result_dp))

    def test_prefetcher_iterdatapipe(self) -> None:
-        source_dp = IterableWrapper(range(50000))
+        source_dp = IterableWrapper(range(5000))
        prefetched_dp = source_dp.prefetch(10)
        # check if early termination resets child thread properly
        for _, _ in zip(range(100), prefetched_dp):
@@ -1619,6 +1621,273 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_
        self.assertEqual(v1, exp)
        self.assertEqual(v2, exp)

    def test_threadpool_map(self):
        target_length = 30
        input_dp = IterableWrapper(range(target_length))
        input_dp_parallel = IterableWrapper(range(target_length))

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        # Functional Test: apply to each element correctly
        map_dp = input_dp.threadpool_map(fn)
        self.assertEqual(target_length, len(map_dp))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # Functional Test: works with partial function
        map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        # __len__ Test: inherits length from source DataPipe
        self.assertEqual(target_length, len(map_dp))

        input_dp_nl = IDP_NoLen(range(target_length))
        map_dp_nl = input_dp_nl.threadpool_map(lambda x: x)
        for x, y in zip(map_dp_nl, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # __len__ Test: inherits length from source DataPipe - raises error when invalid
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(map_dp_nl)

        # Test: two independent ThreadPoolExecutors running at the same time
        map_dp_parallel = input_dp_parallel.threadpool_map(fn)
        for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
            self.assertEqual(x, torch.tensor(z, dtype=torch.float))
            self.assertEqual(y, torch.tensor(z, dtype=torch.float))

        # Reset Test: DataPipe resets properly
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
        self.assertEqual(list(range(target_length)), res_after_reset)
    @suppress_warnings  # Suppress warning for lambda fn
    def test_threadpool_map_tuple_list_with_col_iterdatapipe(self):
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_sep_pos(d0, *args, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        p_fn_n1 = partial(fn_n1, d1=1)
        p_fn_cmplx = partial(fn_cmplx, d2=2)

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            for constr in (list, tuple):
                datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
                if ref_fn is None:
                    with self.assertRaises(error):
                        res_dp = datapipe.threadpool_map(fn, input_col, output_col)
                        list(res_dp)
                else:
                    res_dp = datapipe.threadpool_map(fn, input_col, output_col)
                    ref_dp = datapipe.map(ref_fn)
                    if constr is list:
                        ref_dp = ref_dp.map(list)
                    self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
                    # Reset
                    self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")

        _helper(lambda data: data, fn_n1_def, 0, 1)
        _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
        _helper(lambda data: data, p_fn_n1, 0, 1)
        _helper(lambda data: data, p_fn_cmplx, 0, 1)
        _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])

        # Replacing with one input column and default output column
        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
        # The index of input column is out of range
        _helper(None, fn_1n, 3, error=IndexError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, 1, error=ValueError)
        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
        _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
        _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
        _helper(None, fn_cmplx, 0, 1, ValueError)
        _helper(None, fn_n1_pos, 1, error=ValueError)
        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
        _helper(None, p_fn_n1, [0, 1], error=ValueError)
        _helper(None, fn_1n, [1, 2], error=ValueError)
        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
        # Fn has keyword-only arguments
        _helper(None, fn_n1_kwargs, 1, error=ValueError)
        _helper(None, fn_cmplx, [0, 1], 2, ValueError)

        # Replacing with multiple input columns and default output column (the left-most input column)
        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
        _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, 1, error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, [0, 1], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
        # Replacing with one input column and single specified output column
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
        # The index of output column is out of range
        _helper(None, fn_1n, 1, 3, error=IndexError)
        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
        _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

        # Appending the output at the end
        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
        _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)

        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)

    @suppress_warnings  # Suppress warning for lambda fn
    def test_threadpool_map_dict_with_col_iterdatapipe(self):
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        p_fn_n1 = partial(fn_n1, d1=1)

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_kwonly(*, d0, d1):
            return d0 + d1

        def fn_has_nondefault_kwonly(d0, *, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        p_fn_cmplx = partial(fn_cmplx, d2=2)

        # Prevent modification in-place to support resetting
        def _dict_update(data, newdata, remove_idx=None):
            _data = dict(data)
            _data.update(newdata)
            if remove_idx:
                for idx in remove_idx:
                    del _data[idx]
            return _data

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
            if ref_fn is None:
                with self.assertRaises(error):
                    res_dp = datapipe.threadpool_map(fn, input_col, output_col)
                    list(res_dp)
            else:
                res_dp = datapipe.threadpool_map(fn, input_col, output_col)
                ref_dp = datapipe.map(ref_fn)
                self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
                # Reset
                self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")

        _helper(lambda data: data, fn_n1_def, "x", "y")
        _helper(lambda data: data, p_fn_n1, "x", "y")
        _helper(lambda data: data, p_fn_cmplx, "x", "y")
        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")

        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")

        _helper(None, fn_n1_pos, "x", error=ValueError)
        _helper(None, fn_n1_kwargs, "x", error=ValueError)
        # non-default kw-only args
        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)

        # Replacing with one input column and default output column
        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
        _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
        # The key of input column is not in dict
        _helper(None, fn_1n, "a", error=KeyError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, "y", error=ValueError)
        _helper(None, fn_1n, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
        # Replacing with multiple input columns and default output column (the left-most input column)
        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
        _helper(
            lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
            fn_nn,
            ["z", "y"],
        )

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, "x", error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
        # Replacing with one input column and single specified output column
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
        _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
        _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
        _helper(
            lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
            fn_nn,
            ["y", "z"],
            "x",
        )

        # Adding new key to dict for the output
        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
        _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
        _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
        _helper(
            lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
            fn_nn,
            ["y", "z"],
            "a",
        )


if __name__ == "__main__":
    unittest.main()
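
To make the column-selection conventions tested above concrete, here is a hedged sketch of the tuple and dict cases; the keyword names `input_col` and `output_col` are inferred from the comments in the tests, not confirmed elsewhere in this diff.

from torchdata.datapipes.iter import IterableWrapper

def neg(d):
    return -d

# Tuple case (mirrors `_helper(..., fn_11, 1, 0)` above): read column 1,
# write the result over column 0.
tup_dp = IterableWrapper([(0, 1, 2), (3, 4, 5)])
assert list(tup_dp.threadpool_map(neg, input_col=1, output_col=0)) == [(-1, 1, 2), (-4, 4, 5)]

# Dict case (mirrors `_helper(..., fn_11, "y", "a")` above): read key "y",
# store the result under the new key "a".
dict_dp = IterableWrapper([{"x": 0, "y": 1}, {"x": 3, "y": 4}])
assert list(dict_dp.threadpool_map(neg, input_col="y", output_col="a")) == [
    {"x": 0, "y": 1, "a": -1},
    {"x": 3, "y": 4, "a": -4},
]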
2 changes: 2 additions & 0 deletions test/test_serialization.py
@@ -297,6 +297,7 @@ def test_serializable(self):
            (iterdp.TarArchiveLoader, None, (), {}),
            # TODO(594): Add serialization tests for optional DataPipe
            # (iterdp.TFRecordLoader, None, (), {}),
+            (iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}),
            (iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
            (iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}),
            (iterdp.XzFileLoader, None, (), {}),
@@ -366,6 +367,7 @@ def test_serializable_with_dill(self):
            (iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}),
            (iterdp.OnDiskCacheHolder, (lambda x: x,), {}),
            (iterdp.ParagraphAggregator, (lambda x: x,), {}),
+            (iterdp.ThreadPoolMapper, (lambda x: x,), {}),
        ]
        # Skipping value comparison for these DataPipes
        dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator}
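
The two new entries register ThreadPoolMapper in the pickle and dill serialization suites. A sketch of the round trip they assert, where the `identity` helper is hypothetical and stands in for the suite's `_fake_fn_ls`:

import pickle

from torchdata.datapipes.iter import IterableWrapper, ThreadPoolMapper

def identity(x):
    return x

dp = ThreadPoolMapper(IterableWrapper(range(3)), identity)
# The new test entries assert the DataPipe survives a pickle round trip
restored = pickle.loads(pickle.dumps(dp))
assert list(restored) == [0, 1, 2]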
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
@@ -72,6 +72,7 @@
    FlatMapperIterDataPipe as FlatMapper,
    FlattenIterDataPipe as Flattener,
    SliceIterDataPipe as Slicer,
+    ThreadPoolMapperIterDataPipe as ThreadPoolMapper,
)
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
from torchdata.datapipes.iter.util.cacheholder import (
@@ -213,6 +214,7 @@
"StreamReader",
"TFRecordLoader",
"TarArchiveLoader",
"ThreadPoolMapper",
"UnBatcher",
"UnZipper",
"WebDataset",