Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding mux_longest DataPipe #372

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ These tend to involve multiple DataPipes, combining them or splitting one to man
IterKeyZipper
MapKeyZipper
Multiplexer
MultiplexerLongest
SampleMultiplexer
UnZipper
Zipper
Expand Down
33 changes: 33 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,39 @@ def test_itertomap_mapdatapipe(self):
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Found duplicate key")

def test_mux_longest_iterdatapipe(self):
    """Exercise ``mux_longest``: interleaving order, uneven and empty inputs, and ``__len__``."""

    # Functional Test: three equally sized inputs are interleaved element by element
    first_dp = IterableWrapper(range(4))
    second_dp = IterableWrapper(range(4, 8))
    third_dp = IterableWrapper(range(8, 12))
    muxed_dp = first_dp.mux_longest(second_dp, third_dp)
    expected = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
    self.assertEqual(len(expected), len(muxed_dp))
    self.assertEqual(expected, list(muxed_dp))

    # Functional Test: inputs of different sizes; shorter ones drop out once exhausted
    first_dp = IterableWrapper([1, 2, 3, 4])
    second_dp = IterableWrapper([10])
    third_dp = IterableWrapper([100, 200, 300])
    muxed_dp = first_dp.mux_longest(second_dp, third_dp)
    expected = [1, 10, 100, 2, 200, 3, 300, 4]
    self.assertEqual(len(expected), len(muxed_dp))
    self.assertEqual(expected, list(muxed_dp))

    # Functional Test: an empty input contributes nothing to the output
    first_dp = IterableWrapper([0, 1, 2, 3])
    empty_dp = IterableWrapper([])
    muxed_dp = first_dp.mux_longest(empty_dp)
    self.assertEqual(len(first_dp), len(muxed_dp))
    self.assertEqual(list(first_dp), list(muxed_dp))

    # __len__ Test: TypeError is raised when one of the inputs does not define __len__
    first_dp = IterableWrapper(range(10))
    no_len_dp = IDP_NoLen(range(10))
    muxed_dp = first_dp.mux_longest(no_len_dp)
    with self.assertRaises(TypeError):
        len(muxed_dp)

if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def test_serializable(self):
(SequenceWrapper({"a": 100, "b": 200, "c": 300}), itemgetter(0)),
{},
),
(
iterdp.MultiplexerLongest,
IterableWrapper(range(10)),
(),
{},
),
(iterdp.OnDiskCacheHolder, None, (), {}),
(iterdp.OnlineReader, None, (), {}),
(
Expand Down
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
CSVParserIterDataPipe as CSVParser,
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.mux_longest import MultiplexerLongestIterDataPipe as MultiplexerLongest
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
Expand Down Expand Up @@ -164,6 +165,7 @@
"Mapper",
"MaxTokenBucketizer",
"Multiplexer",
"MultiplexerLongest",
"OnDiskCacheHolder",
"OnlineReader",
"ParagraphAggregator",
Expand Down
53 changes: 53 additions & 0 deletions torchdata/datapipes/iter/util/mux_longest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Sized, Set, Optional


@functional_datapipe('mux_longest')
class MultiplexerLongestIterDataPipe(IterDataPipe):
    r"""
    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux_longest``). As in,
    one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
    and so on. It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted.

    Args:
        datapipes: Iterable DataPipes that will take turns to yield their elements, until they are all exhausted

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux_longest(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]
    """

    def __init__(self, *datapipes):
        self.datapipes = datapipes
        # Cached result of ``__len__``: ``None`` means "not computed yet",
        # ``-1`` means "at least one input is not ``Sized``".
        self.length: Optional[int] = None

    def __iter__(self):
        # Iterators that have not raised StopIteration yet, kept in input order.
        active = [iter(dp) for dp in self.datapipes]
        while active:
            remaining = []
            for iterator in active:
                # Only ``next`` is guarded: a StopIteration arriving at the
                # ``yield`` below must not be mistaken for input exhaustion.
                try:
                    value = next(iterator)
                except StopIteration:
                    continue
                remaining.append(iterator)
                yield value
            active = remaining

    def __len__(self):
        """
        Return the sum of the lengths of all input DataPipes.

        Raises:
            TypeError: if at least one input DataPipe is not ``Sized``.
        """
        if self.length is None:
            # If ``len`` of an input raises here, nothing is cached and the error propagates.
            if all(isinstance(dp, Sized) for dp in self.datapipes):
                self.length = sum(len(dp) for dp in self.datapipes)
            else:
                self.length = -1
        if self.length == -1:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length