diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst
index b9817a8b9..fdbb1009c 100644
--- a/docs/source/torchdata.datapipes.iter.rst
+++ b/docs/source/torchdata.datapipes.iter.rst
@@ -97,6 +97,7 @@ These tend to involve multiple DataPipes, combining them or splitting one to many
     IterKeyZipper
     MapKeyZipper
     Multiplexer
+    MultiplexerLongest
     SampleMultiplexer
     UnZipper
     Zipper
diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index 52584aba6..8bd9c6dc5 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -868,6 +868,39 @@ def test_itertomap_mapdatapipe(self):
             self.assertEqual(len(wa), 1)
             self.assertRegex(str(wa[0].message), r"Found duplicate key")
 
+    def test_mux_longest_iterdatapipe(self):
+
+        # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
+        input_dp1 = IterableWrapper(range(4))
+        input_dp2 = IterableWrapper(range(4, 8))
+        input_dp3 = IterableWrapper(range(8, 12))
+        output_dp = input_dp1.mux_longest(input_dp2, input_dp3)
+        expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
+        self.assertEqual(len(expected_output), len(output_dp))
+        self.assertEqual(expected_output, list(output_dp))
+
+        # Functional Test: Uneven input Data Pipes
+        input_dp1 = IterableWrapper([1, 2, 3, 4])
+        input_dp2 = IterableWrapper([10])
+        input_dp3 = IterableWrapper([100, 200, 300])
+        output_dp = input_dp1.mux_longest(input_dp2, input_dp3)
+        expected_output = [1, 10, 100, 2, 200, 3, 300, 4]
+        self.assertEqual(len(expected_output), len(output_dp))
+        self.assertEqual(expected_output, list(output_dp))
+
+        # Functional Test: Empty Data Pipe
+        input_dp1 = IterableWrapper([0, 1, 2, 3])
+        input_dp2 = IterableWrapper([])
+        output_dp = input_dp1.mux_longest(input_dp2)
+        self.assertEqual(len(input_dp1), len(output_dp))
+        self.assertEqual(list(input_dp1), list(output_dp))
+
+        # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
+        input_dp1 = IterableWrapper(range(10))
+        input_dp_no_len = IDP_NoLen(range(10))
+        output_dp = input_dp1.mux_longest(input_dp_no_len)
+        with self.assertRaises(TypeError):
+            len(output_dp)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_serialization.py b/test/test_serialization.py
index f66e319d3..2a7f1c0d4 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -234,6 +234,12 @@ def test_serializable(self):
                 (SequenceWrapper({"a": 100, "b": 200, "c": 300}), itemgetter(0)),
                 {},
             ),
+            (
+                iterdp.MultiplexerLongest,
+                IterableWrapper(range(10)),
+                (),
+                {},
+            ),
             (iterdp.OnDiskCacheHolder, None, (), {}),
             (iterdp.OnlineReader, None, (), {}),
             (
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index c76f1ef84..6aa3f43a7 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -94,6 +94,7 @@
     CSVParserIterDataPipe as CSVParser,
     LineReaderIterDataPipe as LineReader,
 )
+from torchdata.datapipes.iter.util.mux_longest import MultiplexerLongestIterDataPipe as MultiplexerLongest
 from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
 from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
 from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
@@ -164,6 +165,7 @@
     "Mapper",
     "MaxTokenBucketizer",
     "Multiplexer",
+    "MultiplexerLongest",
    "OnDiskCacheHolder",
    "OnlineReader",
    "ParagraphAggregator",
diff --git a/torchdata/datapipes/iter/util/mux_longest.py b/torchdata/datapipes/iter/util/mux_longest.py
new file mode 100644
index 000000000..29ae5d368
--- /dev/null
+++ b/torchdata/datapipes/iter/util/mux_longest.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from typing import Sized, Set, Optional
+
+
+@functional_datapipe('mux_longest')
+class MultiplexerLongestIterDataPipe(IterDataPipe):
+    r"""
+    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux_longest``). As in,
+    one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
+    and so on. It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted.
+
+    Args:
+        datapipes: Iterable DataPipes that will take turns to yield their elements, until they are all exhausted
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> list(dp1.mux_longest(dp2, dp3))
+        [0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]
+    """
+    def __init__(self, *datapipes):
+        self.datapipes = datapipes
+        self.length: Optional[int] = None
+
+    def __iter__(self):
+        iterators = [iter(x) for x in self.datapipes]
+        finished: Set[int] = set()
+        while len(finished) < len(iterators):
+            for i in range(len(iterators)):
+                if i not in finished:
+                    try:
+                        value = next(iterators[i])
+                        yield value
+                    except StopIteration:
+                        finished.add(i)
+
+    def __len__(self):
+        if self.length is not None:
+            if self.length == -1:
+                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            return self.length
+        if all(isinstance(dp, Sized) for dp in self.datapipes):
+            self.length = sum(len(dp) for dp in self.datapipes)
+        else:
+            self.length = -1
+        return len(self)
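
A quick way to exercise the new functional locally is a doctest-style session like the one below. This is a sketch assuming a torchdata build that includes this patch; it mirrors the uneven-input case from the new test rather than adding anything beyond it.

    >>> from torchdata.datapipes.iter import IterableWrapper
    >>> dp1 = IterableWrapper([1, 2, 3, 4])
    >>> dp2 = IterableWrapper([10])
    >>> dp3 = IterableWrapper([100, 200, 300])
    >>> list(dp1.mux_longest(dp2, dp3))  # exhausted inputs are skipped, not terminal
    [1, 10, 100, 2, 200, 3, 300, 4]
    >>> len(dp1.mux_longest(dp2, dp3))   # __len__ is the sum of the input lengths
    8

Unlike the existing ``mux``, which ends once the shortest input DataPipe is exhausted, ``mux_longest`` keeps cycling over the remaining DataPipes until every input is drained.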