diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index e8f624732..ecbb510aa 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -21,6 +21,7 @@ import torchdata
 
 from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
+from torch.testing._internal.common_utils import suppress_warnings
 from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
 from torchdata.datapipes.iter import (
@@ -81,12 +82,12 @@ def _convert_to_tensor(data):
 
 
 async def _async_mul_ten(x):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
     return x * 10
 
 
 async def _async_x_mul_y(x, y):
-    await asyncio.sleep(1)
+    await asyncio.sleep(0.1)
     return x * y
 
 
@@ -290,7 +291,7 @@ def odd_even_bug(i: int) -> int:
         self.assertEqual(len(source_dp), len(result_dp))
 
     def test_prefetcher_iterdatapipe(self) -> None:
-        source_dp = IterableWrapper(range(50000))
+        source_dp = IterableWrapper(range(5000))
         prefetched_dp = source_dp.prefetch(10)
         # check if early termination resets child thread properly
         for _, _ in zip(range(100), prefetched_dp):
@@ -1624,6 +1625,7 @@ def test_threadpool_map_batches(self):
         batch_size = 16
         target_length = 30
         input_dp = IterableWrapper(range(target_length))
+        input_dp_parallel = IterableWrapper(range(target_length))
 
         def fn(item, dtype=torch.float, *, sum=False):
             data = torch.tensor(item, dtype=dtype)
@@ -1640,8 +1642,8 @@ def fn(item, dtype=torch.float, *, sum=False):
         for x, y in zip(map_dp, range(target_length)):
             self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
 
-        # __len__ Test: inherits length from source DataPipe
         # this doesn't work atm
+        # __len__ Test: inherits length from source DataPipe
         # self.assertEqual(target_length, len(map_dp))
 
         input_dp_nl = IDP_NoLen(range(target_length))
@@ -1649,16 +1651,250 @@ def fn(item, dtype=torch.float, *, sum=False):
         for x, y in zip(map_dp_nl, range(target_length)):
             self.assertEqual(x, torch.tensor(y, dtype=torch.float))
 
+        # this doesn't work atm
         # __len__ Test: inherits length from source DataPipe - raises error when invalid
         # with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
         #     len(map_dp_nl)
 
+        # Test: two independent ThreadPoolExecutors running at the same time
+        map_dp_parallel = input_dp_parallel.thread_map_batches(fn, batch_size)
+        for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
+            self.assertEqual(x, torch.tensor(z, dtype=torch.int).sum())
+            self.assertEqual(y, torch.tensor(z, dtype=torch.float))
+
         # Reset Test: DataPipe resets properly
         n_elements_before_reset = 5
         res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
         self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
         self.assertEqual(list(range(target_length)), res_after_reset)
 
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_threadpool_map_batches_tuple_list_with_col_iterdatapipe(self):
+        batch_size = 3
+
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        def fn_n1_def(d0, d1=1):
+            return d0 + d1
+
+        def fn_n1_kwargs(d0, d1, **kwargs):
+            return d0 + d1
+
+        def fn_n1_pos(d0, d1, *args):
+            return d0 + d1
+
+        def fn_n1_sep_pos(d0, *args, d1):
+            return d0 + d1
+
+        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
+            return d0 + d1
+
+        p_fn_n1 = partial(fn_n1, d1=1)
+        p_fn_cmplx = partial(fn_cmplx, d2=2)
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
+            for constr in (list, tuple):
+                datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
+                if ref_fn is None:
+                    with self.assertRaises(error):
+                        res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
+                        list(res_dp)
+                else:
+                    res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
+                    ref_dp = datapipe.map(ref_fn)
+                    if constr is list:
+                        ref_dp = ref_dp.map(list)
+                    self.assertEqual(list(res_dp), list(ref_dp))
+                    # Reset
+                    self.assertEqual(list(res_dp), list(ref_dp))
+
+        _helper(lambda data: data, fn_n1_def, 0, 1)
+        _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
+        _helper(lambda data: data, p_fn_n1, 0, 1)
+        _helper(lambda data: data, p_fn_cmplx, 0, 1)
+        _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
+        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
+        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
+        # The index of input column is out of range
+        _helper(None, fn_1n, 3, error=IndexError)
+        # Unmatched input columns with fn arguments
+        _helper(None, fn_n1, 1, error=ValueError)
+        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
+        _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
+        _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
+        _helper(None, fn_cmplx, 0, 1, ValueError)
+        _helper(None, fn_n1_pos, 1, error=ValueError)
+        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
+        _helper(None, p_fn_n1, [0, 1], error=ValueError)
+        _helper(None, fn_1n, [1, 2], error=ValueError)
+        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
+        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
+        # Fn has keyword-only arguments
+        _helper(None, fn_n1_kwargs, 1, error=ValueError)
+        _helper(None, fn_cmplx, [0, 1], 2, ValueError)
+
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
+        _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
+
+        # output_col can only be specified when input_col is not None
+        _helper(None, fn_n1, None, 1, error=ValueError)
+        # output_col can only be single-element list or tuple
+        _helper(None, fn_n1, None, [0, 1], error=ValueError)
+        # Single-element list as output_col
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
+        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
+        # The index of output column is out of range
+        _helper(None, fn_1n, 1, 3, error=IndexError)
+        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
+        _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
+
+        # Appending the output at the end
+        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
+        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
+        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
+        _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
+        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
+        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
+        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
+
+    @suppress_warnings  # Suppress warning for lambda fn
+    def test_threadpool_map_batches_dict_with_col_iterdatapipe(self):
+        batch_size = 3
+
+        def fn_11(d):
+            return -d
+
+        def fn_1n(d):
+            return -d, d
+
+        def fn_n1(d0, d1):
+            return d0 + d1
+
+        def fn_nn(d0, d1):
+            return -d0, -d1, d0 + d1
+
+        def fn_n1_def(d0, d1=1):
+            return d0 + d1
+
+        p_fn_n1 = partial(fn_n1, d1=1)
+
+        def fn_n1_pos(d0, d1, *args):
+            return d0 + d1
+
+        def fn_n1_kwargs(d0, d1, **kwargs):
+            return d0 + d1
+
+        def fn_kwonly(*, d0, d1):
+            return d0 + d1
+
+        def fn_has_nondefault_kwonly(d0, *, d1):
+            return d0 + d1
+
+        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
+            return d0 + d1
+
+        p_fn_cmplx = partial(fn_cmplx, d2=2)
+
+        # Prevent modification in-place to support resetting
+        def _dict_update(data, newdata, remove_idx=None):
+            _data = dict(data)
+            _data.update(newdata)
+            if remove_idx:
+                for idx in remove_idx:
+                    del _data[idx]
+            return _data
+
+        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
+            datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
+            if ref_fn is None:
+                with self.assertRaises(error):
+                    res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
+                    list(res_dp)
+            else:
+                res_dp = datapipe.thread_map_batches(fn, batch_size, input_col, output_col)
+                ref_dp = datapipe.map(ref_fn)
+                self.assertEqual(list(res_dp), list(ref_dp))
+                # Reset
+                self.assertEqual(list(res_dp), list(ref_dp))
+
+        _helper(lambda data: data, fn_n1_def, "x", "y")
+        _helper(lambda data: data, p_fn_n1, "x", "y")
+        _helper(lambda data: data, p_fn_cmplx, "x", "y")
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")
+
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")
+
+        _helper(None, fn_n1_pos, "x", error=ValueError)
+        _helper(None, fn_n1_kwargs, "x", error=ValueError)
+        # non-default kw-only args
+        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
+        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
+        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)
+
+        # Replacing with one input column and default output column
+        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
+        _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
+        # The key of input column is not in dict
+        _helper(None, fn_1n, "a", error=KeyError)
+        # Unmatched input columns with fn arguments
+        _helper(None, fn_n1, "y", error=ValueError)
+        _helper(None, fn_1n, ["x", "y"], error=ValueError)
+        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
+        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
+        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
+        # Replacing with multiple input columns and default output column (the left-most input column)
+        _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
+        _helper(
+            lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
+            fn_nn,
+            ["z", "y"],
+        )
+
+        # output_col can only be specified when input_col is not None
+        _helper(None, fn_n1, None, "x", error=ValueError)
+        # output_col can only be single-element list or tuple
+        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
+        # Single-element list as output_col
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
+        # Replacing with one input column and single specified output column
+        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
+        _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
+        _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
+        _helper(
+            lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
+            fn_nn,
+            ["y", "z"],
+            "x",
+        )
+
+        # Adding new key to dict for the output
+        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
+        _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
+        _helper(
+            lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
+            fn_nn,
+            ["y", "z"],
+            "a",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py
index c48548146..357604654 100644
--- a/torchdata/datapipes/iter/transform/callable.py
+++ b/torchdata/datapipes/iter/transform/callable.py
@@ -27,6 +27,35 @@ def _no_op_fn(*args):
     return args
 
 
+def _merge_batch_with_result(orig_batch, results, input_col, output_col):
+    if input_col is None:
+        return results
+
+    new_batch = []
+    for data, res in zip(orig_batch, results):
+        t_flag = isinstance(data, tuple)
+        if t_flag:
+            data = list(data)
+
+        if output_col is None:
+            if isinstance(input_col, (list, tuple)):
+                data[input_col[0]] = res
+                for idx in sorted(input_col[1:], reverse=True):
+                    del data[idx]
+            else:
+                data[input_col] = res
+        elif output_col == -1:
+            data.append(res)
+        else:
+            data[output_col] = res
+
+        if t_flag:
+            data = tuple(data)
+
+        new_batch.append(data)
+    return new_batch
+
+
 @functional_datapipe("map_batches")
 class BatchMapperIterDataPipe(IterDataPipe[T_co]):
     r"""
@@ -495,30 +524,7 @@ async def controlled_async_fn(async_fn, *data):
             else:
                 coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col]))
         results = await asyncio.gather(*coroutines)
-
-        new_batch = []
-        for data, res in zip(batch, results):
-            t_flag = isinstance(data, tuple)
-            if t_flag:
-                data = list(data)
-
-            if self.output_col is None:
-                if isinstance(self.input_col, (list, tuple)):
-                    data[self.input_col[0]] = res
-                    for idx in sorted(self.input_col[1:], reverse=True):
-                        del data[idx]
-                else:
-                    data[self.input_col] = res
-            elif self.output_col == -1:
-                data.append(res)
-            else:
-                data[self.output_col] = res
-
-            if t_flag:
-                data = tuple(data)
-
-            new_batch.append(data)
-        return new_batch
+        return _merge_batch_with_result(batch, results, self.input_col, self.output_col)
 
 
 @functional_datapipe("async_map_batches")
@@ -602,14 +608,21 @@ def __init__(
         **threadpool_kwargs,
     ):
         self.source_datapipe = source_datapipe
-        self.fn = fn  # type: ignore[assignment]
+        self.fn = fn
+        self.input_col = input_col
+        validate_input_col(fn, input_col)
         if input_col is None and output_col is not None:
             raise ValueError("`output_col` must be None when `input_col` is None.")
-        self.input_col = input_col
         if isinstance(output_col, (list, tuple)):
            if len(output_col) > 1:
                 raise ValueError("`output_col` must be a single-element list or tuple")
             output_col = output_col[0]
+        if isinstance(self.input_col, (list, tuple)):
+
+            def wrapper_fn(args):
+                return fn(*args)
+
+            self.fn = wrapper_fn
         self.output_col = output_col
         self.max_workers = max_workers
         self.threadpool_kwargs = threadpool_kwargs
@@ -619,8 +632,7 @@ def __iter__(self):
             for batch in self.source_datapipe:
                 prepared_batch = self.preparebatch(batch)
                 results = executor.map(self.fn, prepared_batch)
-                return_batch = self.merge_batch_with_result(batch, results)
-                yield return_batch
+                yield _merge_batch_with_result(batch, results, self.input_col, self.output_col)
 
     def preparebatch(self, batch):
         if self.input_col is None:
@@ -629,40 +641,12 @@ def preparebatch(self, batch):
         prepared_batch = []
         for data in batch:
             if isinstance(self.input_col, (list, tuple)):
-                args = tuple(batch[col] for col in self.input_col)
+                args = tuple(data[col] for col in self.input_col)
                 prepared_batch.append(args)
             else:
                 prepared_batch.append(data[self.input_col])
         return prepared_batch
 
-    def merge_batch_with_result(self, orig_batch, results):
-        if self.input_col is None:
-            return results
-
-        new_batch = []
-        for data, res in zip(orig_batch, results):
-            t_flag = isinstance(data, tuple)
-            if t_flag:
-                data = list(data)
-
-            if self.output_col is None:
-                if isinstance(self.input_col, (list, tuple)):
-                    data[self.input_col[0]] = res
-                    for idx in sorted(self.input_col[1:], reverse=True):
-                        del data[idx]
-                else:
-                    data[self.input_col] = res
-            elif self.output_col == -1:
-                data.append(res)
-            else:
-                data[self.output_col] = res
-
-            if t_flag:
-                data = tuple(data)
-
-            new_batch.append(data)
-        return new_batch
-
 
 @functional_datapipe("thread_map_batches")
 class BatchThreadPoolMapperIterDataPipe(IterDataPipe):
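
---

Usage sketch (not part of the patch): a minimal example of the `thread_map_batches` functional exercised by the tests above, assuming the patched torchdata build is installed. The `add` helper and the sample rows are illustrative only; the column semantics follow the tuple/list tests: with input columns `[0, 2]` and output column `1`, columns 0 and 2 of each row are passed to `fn` as positional arguments and the result is written back into column 1.

    from torchdata.datapipes.iter import IterableWrapper

    def add(x, y):
        return x + y

    dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])
    # Arguments are fn, batch_size, input_col, output_col, passed positionally as in the tests.
    dp = dp.thread_map_batches(add, 3, [0, 2], 1)
    print(list(dp))  # expected: [(0, 2, 2), (3, 8, 5), (6, 14, 8)]

Internally, the functional batches the source, maps `fn` over each batch on a ThreadPoolExecutor, merges the results back into the rows via `_merge_batch_with_result`, and flattens the batches again, so element order is preserved.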