diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 39e8e3ac1f9..4153c1951ae 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -207,7 +207,7 @@ def test_listdirfiles_iterable_datapipe(self): self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files)) self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files)) - def test_loadfilesfromdisk_iterable_datapipe(self): + def test_readfilesfromdisk_iterable_datapipe(self): # test import datapipe class directly from torch.utils.data.datapipes.iter import ( FileLister, @@ -437,24 +437,39 @@ class TestFunctionalIterDataPipe(TestCase): def test_serializable(self): input_dp = dp.iter.IterableWrapper(range(10)) picklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Batcher, (3, True,), {}), + (dp.iter.Collator, (_fake_fn,), {}), + (dp.iter.Concater, (dp.iter.IterableWrapper(range(5)),), {}), + (dp.iter.Demultiplexer, (2, _fake_filter_fn), {}), + (dp.iter.FileLister, (), {}), + (dp.iter.FileOpener, (), {}), + (dp.iter.Filter, (_fake_filter_fn,), {}), + (dp.iter.Filter, (partial(_fake_filter_fn_constant, 5),), {}), + (dp.iter.Forker, (2,), {}), + (dp.iter.Grouper, (_fake_filter_fn,), {"group_size": 2}), + (dp.iter.IterableWrapper, (), {}), (dp.iter.Mapper, (_fake_fn, ), {}), (dp.iter.Mapper, (partial(_fake_add, 1), ), {}), - (dp.iter.Collator, (_fake_fn, ), {}), - (dp.iter.Filter, (_fake_filter_fn, ), {}), - (dp.iter.Filter, (partial(_fake_filter_fn_constant, 5), ), {}), - (dp.iter.Demultiplexer, (2, _fake_filter_fn), {}), + (dp.iter.Multiplexer, (input_dp,), {}), + (dp.iter.Sampler, (), {}), + (dp.iter.Shuffler, (), {}), + (dp.iter.StreamReader, (), {}), + (dp.iter.UnBatcher, (), {}), + (dp.iter.Zipper, (input_dp,), {}), ] for dpipe, dp_args, dp_kwargs in picklable_datapipes: print(dpipe) _ = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] def test_serializable_with_dill(self): + """Only for DataPipes that take in a function or buffer as argument""" input_dp = dp.iter.IterableWrapper(range(10)) unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [ + (dp.iter.Collator, (lambda x: x,), {}), + (dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), + (dp.iter.Filter, (lambda x: x >= 5,), {}), + # (dp.iter.Grouper, (lambda x: x >= 5,), {}), # TODO: Need custom __getstate__ for Grouper (dp.iter.Mapper, (lambda x: x, ), {}), - (dp.iter.Collator, (lambda x: x, ), {}), - (dp.iter.Filter, (lambda x: x >= 5, ), {}), - (dp.iter.Demultiplexer, (2, lambda x: x % 2, ), {}) ] if HAS_DILL: for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index fe2f24bca28..a39a5303268 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -13,7 +13,7 @@ class SamplerIterDataPipe(IterDataPipe[T_co]): Args: datapipe: IterDataPipe to sample from - sampler: Sampler class to genereate sample elements from input DataPipe. + sampler: Sampler class to generate sample elements from input DataPipe. Default is :class:`SequentialSampler` for IterDataPipe """ datapipe: IterDataPipe