Skip to content

Commit

Permalink
keep_key in MapKeyZipper (#1042)
Browse files Browse the repository at this point in the history
Summary:
As mentioned in #256, it would be useful to have a `keep_key` option in all DataPipes that use a `key_fn`.

### Changes

- Add a `keep_key` option to `MapKeyZipper`
- Add a test for the new option
- Improve the example in the documentation by using Sphinx doctest directives

Pull Request resolved: #1042

Reviewed By: ejguan

Differential Revision: D43550124

Pulled By: NivekT

fbshipit-source-id: 453418560dd586ebb9ce42dd27eb1bc86eb734d0
  • Loading branch information
SvenDS9 authored and NivekT committed Feb 28, 2023
1 parent 106c532 commit b90b467
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
5 changes: 5 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def odd_even_bug(i: int) -> int:
with self.assertRaisesRegex(KeyError, "is not a valid key in the given MapDataPipe"):
next(it)

# Functional test: ensure that keep_key option works
result_dp = source_dp.zip_with_map(map_dp, odd_even, keep_key=True)
expected_res_keep_key = [(key, (i, odd_even_string(i))) for i, key in zip(range(10), [0, 1] * 5)]
self.assertEqual(expected_res_keep_key, list(result_dp))

# Reset Test:
n_elements_before_reset = 4
result_dp = source_dp.zip_with_map(map_dp, odd_even)
Expand Down
39 changes: 29 additions & 10 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,33 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
from ``map_datapipe``
map_datapipe: MapDataPipe that takes a key from ``key_fn``, and returns an item
key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe``
keep_key: If ``True``, yield the matching key along with the merged item as a tuple,
resulting in ``(key, merge_fn(item1, item2))``; defaults to ``False``.
merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item
from ``map_datapipe``, by default a tuple is created
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from torchdata.datapipes.map import SequenceWrapper
>>> from operator import itemgetter
>>> def merge_fn(tuple_from_iter, value_from_map):
>>> return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
>>> dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
>>> mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
>>> res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
>>> list(res_dp)
.. testsetup::
from operator import itemgetter
.. testcode::
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.map import SequenceWrapper
def merge_fn(tuple_from_iter, value_from_map):
return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
print(list(res_dp))
.. testoutput::
[('a', 101), ('b', 202), ('c', 303)]
"""

def __init__(
Expand All @@ -196,6 +209,7 @@ def __init__(
map_datapipe: MapDataPipe,
key_fn: Callable,
merge_fn: Optional[Callable] = None,
keep_key: bool = False,
):
if not isinstance(map_datapipe, MapDataPipe):
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
Expand All @@ -206,6 +220,7 @@ def __init__(
if merge_fn is not None:
_check_unpickable_fn(merge_fn)
self.merge_fn: Optional[Callable] = merge_fn
self.keep_key = keep_key

def __iter__(self) -> Iterator:
for item in self.source_iterdatapipe:
Expand All @@ -214,7 +229,11 @@ def __iter__(self) -> Iterator:
map_item = self.map_datapipe[key]
except (KeyError, IndexError):
raise KeyError(f"key_fn maps {item} to {key}, which is not a valid key in the given MapDataPipe.")
yield self.merge_fn(item, map_item) if self.merge_fn else (item, map_item)
res = self.merge_fn(item, map_item) if self.merge_fn else (item, map_item)
if self.keep_key:
yield key, res
else:
yield res

def __len__(self) -> int:
return len(self.source_iterdatapipe)
Expand Down

0 comments on commit b90b467

Please sign in to comment.