diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index c591b924c..64f539f26 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -261,6 +261,11 @@ def odd_even_bug(i: int) -> int: with self.assertRaisesRegex(KeyError, "is not a valid key in the given MapDataPipe"): next(it) + # Functional test: ensure that keep_key option works + result_dp = source_dp.zip_with_map(map_dp, odd_even, keep_key=True) + expected_res_keep_key = [(key, (i, odd_even_string(i))) for i, key in zip(range(10), [0, 1] * 5)] + self.assertEqual(expected_res_keep_key, list(result_dp)) + # Reset Test: n_elements_before_reset = 4 result_dp = source_dp.zip_with_map(map_dp, odd_even) diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py index 27a7ee6fa..bd9d70769 100644 --- a/torchdata/datapipes/iter/util/combining.py +++ b/torchdata/datapipes/iter/util/combining.py @@ -174,20 +174,33 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]): from ``map_datapipe`` map_datapipe: MapDataPipe that takes a key from ``key_fn``, and returns an item key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe`` + keep_key: Option to yield the matching key along with the items in a tuple, + resulting in ``(key, merge_fn(item1, item2))``. merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item from ``map_datapipe``, by default a tuple is created Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> from torchdata.datapipes.map import SequenceWrapper - >>> from operator import itemgetter - >>> def merge_fn(tuple_from_iter, value_from_map): - >>> return tuple_from_iter[0], tuple_from_iter[1] + value_from_map - >>> dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)]) - >>> mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) - >>> res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn) - >>> list(res_dp) + + .. testsetup:: + + from operator import itemgetter + + .. testcode:: + + from torchdata.datapipes.iter import IterableWrapper + from torchdata.datapipes.map import SequenceWrapper + + def merge_fn(tuple_from_iter, value_from_map): + return tuple_from_iter[0], tuple_from_iter[1] + value_from_map + dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)]) + mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) + res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn) + print(list(res_dp)) + + .. testoutput:: + [('a', 101), ('b', 202), ('c', 303)] + """ def __init__( @@ -196,6 +209,7 @@ def __init__( map_datapipe: MapDataPipe, key_fn: Callable, merge_fn: Optional[Callable] = None, + keep_key: bool = False, ): if not isinstance(map_datapipe, MapDataPipe): raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.") @@ -206,6 +220,7 @@ def __init__( if merge_fn is not None: _check_unpickable_fn(merge_fn) self.merge_fn: Optional[Callable] = merge_fn + self.keep_key = keep_key def __iter__(self) -> Iterator: for item in self.source_iterdatapipe: @@ -214,7 +229,11 @@ def __iter__(self) -> Iterator: map_item = self.map_datapipe[key] except (KeyError, IndexError): raise KeyError(f"key_fn maps {item} to {key}, which is not a valid key in the given MapDataPipe.") - yield self.merge_fn(item, map_item) if self.merge_fn else (item, map_item) + res = self.merge_fn(item, map_item) if self.merge_fn else (item, map_item) + if self.keep_key: + yield key, res + else: + yield res def __len__(self) -> int: return len(self.source_iterdatapipe)