Skip to content

Commit

Permalink
[DataPipe] adding serialization test for all core IterDataPipes (#71456)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#71456

Related to pytorch/data#172

cc VitalyFedyunin ejguan NivekT

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D33668748

Pulled By: NivekT

fbshipit-source-id: ea2085d5ed47533ca49258cc52471373c6ae1847
(cherry picked from commit d5f6fde)
  • Loading branch information
NivekT authored and cyyever committed Feb 9, 2022
1 parent 31ce477 commit 99894ae
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
31 changes: 23 additions & 8 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_listdirfiles_iterable_datapipe(self):
self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files))
self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

def test_loadfilesfromdisk_iterable_datapipe(self):
def test_readfilesfromdisk_iterable_datapipe(self):
# test import datapipe class directly
from torch.utils.data.datapipes.iter import (
FileLister,
Expand Down Expand Up @@ -437,24 +437,39 @@ class TestFunctionalIterDataPipe(TestCase):
def test_serializable(self):
input_dp = dp.iter.IterableWrapper(range(10))
picklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Batcher, (3, True,), {}),
(dp.iter.Collator, (_fake_fn,), {}),
(dp.iter.Concater, (dp.iter.IterableWrapper(range(5)),), {}),
(dp.iter.Demultiplexer, (2, _fake_filter_fn), {}),
(dp.iter.FileLister, (), {}),
(dp.iter.FileOpener, (), {}),
(dp.iter.Filter, (_fake_filter_fn,), {}),
(dp.iter.Filter, (partial(_fake_filter_fn_constant, 5),), {}),
(dp.iter.Forker, (2,), {}),
(dp.iter.Grouper, (_fake_filter_fn,), {"group_size": 2}),
(dp.iter.IterableWrapper, (), {}),
(dp.iter.Mapper, (_fake_fn, ), {}),
(dp.iter.Mapper, (partial(_fake_add, 1), ), {}),
(dp.iter.Collator, (_fake_fn, ), {}),
(dp.iter.Filter, (_fake_filter_fn, ), {}),
(dp.iter.Filter, (partial(_fake_filter_fn_constant, 5), ), {}),
(dp.iter.Demultiplexer, (2, _fake_filter_fn), {}),
(dp.iter.Multiplexer, (input_dp,), {}),
(dp.iter.Sampler, (), {}),
(dp.iter.Shuffler, (), {}),
(dp.iter.StreamReader, (), {}),
(dp.iter.UnBatcher, (), {}),
(dp.iter.Zipper, (input_dp,), {}),
]
for dpipe, dp_args, dp_kwargs in picklable_datapipes:
print(dpipe)
_ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg]

def test_serializable_with_dill(self):
"""Only for DataPipes that take in a function or buffer as argument"""
input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda x: x,), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
(dp.iter.Filter, (lambda x: x >= 5,), {}),
# (dp.iter.Grouper, (lambda x: x >= 5,), {}), # TODO: Need custom __getstate__ for Grouper
(dp.iter.Mapper, (lambda x: x, ), {}),
(dp.iter.Collator, (lambda x: x, ), {}),
(dp.iter.Filter, (lambda x: x >= 5, ), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2, ), {})
]
if HAS_DILL:
for dpipe, dp_args, dp_kwargs in unpicklable_datapipes:
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/datapipes/iter/combinatorics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SamplerIterDataPipe(IterDataPipe[T_co]):
Args:
datapipe: IterDataPipe to sample from
sampler: Sampler class to genereate sample elements from input DataPipe.
sampler: Sampler class to generate sample elements from input DataPipe.
Default is :class:`SequentialSampler` for IterDataPipe
"""
datapipe: IterDataPipe
Expand Down

0 comments on commit 99894ae

Please sign in to comment.