Skip to content

Commit

Permalink
Update mux_longest data pipe (pytorch#372)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#372

OSS issue discussion: pytorch#346
This diff updates `mux_longest` data pipe.

`mux_longest`: Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux_longest``). As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, and so on. It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted. This is the same as the current `MultiplexerIterDataPipe` in pytorch (https://github.com/pytorch/pytorch/blob/4fb7fa081e4fb5df3bf7bc85dcb9a3a9a3ac7133/torch/utils/data/datapipes/iter/combining.py#L375-L390)

`mux_longest` example:

```
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
>>> list(dp1.mux_longest(dp2, dp3))
[0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]
```

Reviewed By: NivekT, ejguan

Differential Revision: D35805772

fbshipit-source-id: db629550c51a5cd9ac90ee77e9942686f995e079
  • Loading branch information
ninginthecloud authored and facebook-github-bot committed Apr 27, 2022
1 parent b6ade8f commit 7cf40fb
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ These tend to involve multiple DataPipes, combining them or splitting one to man
IterKeyZipper
MapKeyZipper
Multiplexer
MultiplexerLongest
SampleMultiplexer
UnZipper
Zipper
Expand Down
34 changes: 34 additions & 0 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,40 @@ def test_itertomap_mapdatapipe(self):
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Found duplicate key")

def test_mux_longest_iterdatapipe(self):

    # Functional Test: elements are interleaved round-robin, one per DataPipe, until all are exhausted
    source_a = IterableWrapper(range(4))
    source_b = IterableWrapper(range(4, 8))
    source_c = IterableWrapper(range(8, 12))
    muxed_dp = source_a.mux_longest(source_b, source_c)
    expected = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
    self.assertEqual(len(expected), len(muxed_dp))
    self.assertEqual(expected, list(muxed_dp))

    # Functional Test: inputs of different lengths — shorter pipes are skipped once exhausted
    source_a = IterableWrapper([1, 2, 3, 4])
    source_b = IterableWrapper([10])
    source_c = IterableWrapper([100, 200, 300])
    muxed_dp = source_a.mux_longest(source_b, source_c)
    expected = [1, 10, 100, 2, 200, 3, 300, 4]
    self.assertEqual(len(expected), len(muxed_dp))
    self.assertEqual(expected, list(muxed_dp))

    # Functional Test: an empty input contributes nothing
    source_a = IterableWrapper([0, 1, 2, 3])
    empty_dp = IterableWrapper([])
    muxed_dp = source_a.mux_longest(empty_dp)
    self.assertEqual(len(source_a), len(muxed_dp))
    self.assertEqual(list(source_a), list(muxed_dp))

    # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
    sized_dp = IterableWrapper(range(10))
    unsized_dp = IDP_NoLen(range(10))
    muxed_dp = sized_dp.mux_longest(unsized_dp)
    with self.assertRaises(TypeError):
        len(muxed_dp)


# Standard unittest entry point: discover and run this module's tests when executed as a script.
if __name__ == "__main__":
    unittest.main()
6 changes: 6 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ def test_serializable(self):
(SequenceWrapper({"a": 100, "b": 200, "c": 300}), itemgetter(0)),
{},
),
(
iterdp.MultiplexerLongest,
IterableWrapper(range(10)),
(),
{},
),
(iterdp.OnDiskCacheHolder, None, (), {}),
(iterdp.OnlineReader, None, (), {}),
(
Expand Down
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
CSVParserIterDataPipe as CSVParser,
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.mux_longest import MultiplexerLongestIterDataPipe as MultiplexerLongest
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
Expand Down Expand Up @@ -161,6 +162,7 @@
"Mapper",
"MaxTokenBucketizer",
"Multiplexer",
"MultiplexerLongest",
"OnDiskCacheHolder",
"OnlineReader",
"ParagraphAggregator",
Expand Down
53 changes: 53 additions & 0 deletions torchdata/datapipes/iter/util/mux_longest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Sized, Set, Optional


@functional_datapipe('mux_longest')
class MultiplexerLongestIterDataPipe(IterDataPipe):
    r"""
    Interleaves elements from each of the input Iterable DataPipes in round-robin
    order (functional name: ``mux_longest``): one element from the 1st input
    DataPipe, then one from the 2nd, and so on. Inputs that have run out are
    skipped, and iteration ends only when every input DataPipe is exhausted.
    Args:
        datapipes: Iterable DataPipes that will take turn to yield their elements, until they are all exhausted
    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux_longest(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]
    """
    def __init__(self, *datapipes):
        self.datapipes = datapipes
        # Cached result of __len__: None = not yet computed, -1 = not computable.
        self.length: Optional[int] = None

    def __iter__(self):
        streams = [iter(dp) for dp in self.datapipes]
        exhausted: Set[int] = set()
        # Keep cycling round-robin over the streams until every one has ended.
        while len(exhausted) < len(streams):
            for idx, stream in enumerate(streams):
                if idx in exhausted:
                    continue
                try:
                    yield next(stream)
                except StopIteration:
                    exhausted.add(idx)

    def __len__(self):
        if self.length is None:
            # The total length is the sum of the inputs' lengths — valid only
            # when every input actually reports one.
            if all(isinstance(dp, Sized) for dp in self.datapipes):
                self.length = sum(len(dp) for dp in self.datapipes)
            else:
                self.length = -1
        if self.length == -1:
            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
        return self.length

0 comments on commit 7cf40fb

Please sign in to comment.