Skip to content

Commit

Permalink
Add Tests and extract merge-with-result-function
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenDS9 committed Feb 28, 2023
1 parent 73fb8bd commit c35ad5d
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 62 deletions.
244 changes: 240 additions & 4 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchdata

from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
from torch.testing._internal.common_utils import suppress_warnings

from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
from torchdata.datapipes.iter import (
Expand Down Expand Up @@ -81,12 +82,12 @@ def _convert_to_tensor(data):


async def _async_mul_ten(x):
await asyncio.sleep(1)
await asyncio.sleep(0.1)
return x * 10


async def _async_x_mul_y(x, y):
await asyncio.sleep(1)
await asyncio.sleep(0.1)
return x * y


Expand Down Expand Up @@ -290,7 +291,7 @@ def odd_even_bug(i: int) -> int:
self.assertEqual(len(source_dp), len(result_dp))

def test_prefetcher_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(50000))
source_dp = IterableWrapper(range(5000))
prefetched_dp = source_dp.prefetch(10)
# check if early termination resets child thread properly
for _, _ in zip(range(100), prefetched_dp):
Expand Down Expand Up @@ -1624,6 +1625,7 @@ def test_threadpool_map_batches(self):
batch_size = 16
target_length = 30
input_dp = IterableWrapper(range(target_length))
input_dp_parallel = IterableWrapper(range(target_length))

def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)
Expand All @@ -1640,25 +1642,259 @@ def fn(item, dtype=torch.float, *, sum=False):
for x, y in zip(map_dp, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

# __len__ Test: inherits length from source DataPipe
# this doesn't work atm
# __len__ Test: inherits length from source DataPipe
# self.assertEqual(target_length, len(map_dp))

input_dp_nl = IDP_NoLen(range(target_length))
map_dp_nl = input_dp_nl.thread_map_batches((lambda x: x), batch_size)
for x, y in zip(map_dp_nl, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

# this doesn't work atm
# __len__ Test: inherits length from source DataPipe - raises error when invalid
# with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
# len(map_dp_nl)

# Test: two independent ThreadPoolExecutors running at the same time
map_dp_parallel = input_dp_parallel.thread_map_batches(fn, batch_size)
for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
self.assertEqual(x, torch.tensor(z, dtype=torch.float))
self.assertEqual(y, torch.tensor(z, dtype=torch.float))

# Reset Test: DataPipe resets properly
n_elements_before_reset = 5
res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
self.assertEqual(list(range(target_length)), res_after_reset)

@suppress_warnings # Suppress warning for lambda fn
def test_threadpool_map_batches_tuple_list_with_col_iterdatapipe(self):
batch_size = 3

def fn_11(d):
return -d

def fn_1n(d):
return -d, d

def fn_n1(d0, d1):
return d0 + d1

def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_sep_pos(d0, *args, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)
p_fn_cmplx = partial(fn_cmplx, d2=2)

def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.map(fn, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
if constr is list:
ref_dp = ref_dp.map(list)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))

_helper(lambda data: data, fn_n1_def, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
_helper(lambda data: data, p_fn_n1, 0, 1)
_helper(lambda data: data, p_fn_cmplx, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
_helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])

# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
# The index of input column is out of range
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
_helper(None, fn_cmplx, 0, 1, ValueError)
_helper(None, fn_n1_pos, 1, error=ValueError)
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
_helper(None, p_fn_n1, [0, 1], error=ValueError)
_helper(None, fn_1n, [1, 2], error=ValueError)
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
# Fn has keyword-only arguments
_helper(None, fn_n1_kwargs, 1, error=ValueError)
_helper(None, fn_cmplx, [0, 1], 2, ValueError)

# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
_helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])

# output_col can only be specified when input_col is not None
_helper(None, fn_n1, None, 1, error=ValueError)
# output_col can only be single-element list or tuple
_helper(None, fn_n1, None, [0, 1], error=ValueError)
# Single-element list as output_col
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
# Replacing with one input column and single specified output column
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
# The index of output column is out of range
_helper(None, fn_1n, 1, 3, error=IndexError)
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)

# Appending the output at the end
_helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
_helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
_helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
_helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)

# Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
_helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
_helper(lambda data: (data[0], data[1], int(data[2])), int, 2)

@suppress_warnings # Suppress warning for lambda fn
def test_threadpool_map_batches_dict_with_col_iterdatapipe(self):
batch_size = 3

def fn_11(d):
return -d

def fn_1n(d):
return -d, d

def fn_n1(d0, d1):
return d0 + d1

def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_kwonly(*, d0, d1):
return d0 + d1

def fn_has_nondefault_kwonly(d0, *, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_cmplx = partial(fn_cmplx, d2=2)

# Prevent modification in-place to support resetting
def _dict_update(data, newdata, remove_idx=None):
_data = dict(data)
_data.update(newdata)
if remove_idx:
for idx in remove_idx:
del _data[idx]
return _data

def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
if ref_fn is None:
with self.assertRaises(error):
res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
list(res_dp)
else:
res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
ref_dp = datapipe.map(ref_fn)
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))

_helper(lambda data: data, fn_n1_def, "x", "y")
_helper(lambda data: data, p_fn_n1, "x", "y")
_helper(lambda data: data, p_fn_cmplx, "x", "y")
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")

_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")

_helper(None, fn_n1_pos, "x", error=ValueError)
_helper(None, fn_n1_kwargs, "x", error=ValueError)
# non-default kw-only args
_helper(None, fn_kwonly, ["x", "y"], error=ValueError)
_helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
_helper(None, fn_cmplx, ["x", "y"], error=ValueError)

# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
# The key of input column is not in dict
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, "y", error=ValueError)
_helper(None, fn_1n, ["x", "y"], error=ValueError)
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
_helper(
lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
fn_nn,
["z", "y"],
)

# output_col can only be specified when input_col is not None
_helper(None, fn_n1, None, "x", error=ValueError)
# output_col can only be single-element list or tuple
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
# Single-element list as output_col
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
# Replacing with one input column and single specified output column
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
_helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
_helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
_helper(
lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
fn_nn,
["y", "z"],
"x",
)

# Adding new key to dict for the output
_helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
_helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
_helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
_helper(
lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
fn_nn,
["y", "z"],
"a",
)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit c35ad5d

Please sign in to comment.