From 9707a13a61bb47d6185acef3bfb1e481edd965ef Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Tue, 14 Feb 2023 06:42:42 -0800
Subject: [PATCH] Officially graduate PrototypeMPRS to MPRS (#1009)

Summary:
Pull Request resolved: https://github.com/pytorch/data/pull/1009

Fixes: https://github.com/pytorch/data/issues/932

- Convert all references from `PrototypeMPRS` to `MPRS`
- Remove the legacy `MPRS` implementation and its usage

Reviewed By: wenleix, NivekT

Differential Revision: D43245136

fbshipit-source-id: 2195a7aae64b8ded5483f5b37383a4fa2d94f13e
---
 benchmarks/cloud/aws_s3.py                  |  4 +-
 docs/source/dataloader2.rst                 |  1 -
 docs/source/dlv2_tutorial.rst               |  6 +-
 docs/source/reading_service.rst             |  4 +-
 test/dataloader2/test_dataloader2.py        | 70 ++++-------------
 test/dataloader2/test_random.py             |  4 +-
 test/test_distributed.py                    |  2 +-
 test/test_graph.py                          | 11 +--
 torchdata/dataloader2/reading_service.py    | 87 +++------------------
 torchdata/datapipes/iter/util/prefetcher.py |  2 +-
 10 files changed, 39 insertions(+), 152 deletions(-)

diff --git a/benchmarks/cloud/aws_s3.py b/benchmarks/cloud/aws_s3.py
index bb03f9004..32676c78c 100644
--- a/benchmarks/cloud/aws_s3.py
+++ b/benchmarks/cloud/aws_s3.py
@@ -14,7 +14,7 @@
 import pandas as pd
 import psutil

-from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
+from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
 from torchdata.datapipes.iter import IterableWrapper


@@ -64,7 +64,7 @@ def check_and_output_speed(prefix: str, create_dp_fn: Callable, n_prefetch: int,
     dp = create_dp_fn()

     rs_type = "DataLoader2 w/ tar archives"
-    new_rs = PrototypeMultiProcessingReadingService(
+    new_rs = MultiProcessingReadingService(
         num_workers=n_workers, worker_prefetch_cnt=n_prefetch, main_prefetch_cnt=n_prefetch
     )
     dl: DataLoader2 = DataLoader2(dp, reading_service=new_rs)
diff --git a/docs/source/dataloader2.rst b/docs/source/dataloader2.rst
index 9c58aede3..0a6eaa785 100644
--- a/docs/source/dataloader2.rst
+++ b/docs/source/dataloader2.rst
@@ -32,7 +32,6 @@ ReadingService

     DistributedReadingService
     MultiProcessingReadingService
-    PrototypeMultiProcessingReadingService
     SequentialReadingService

 Each ``ReadingService`` takes the ``DataPipe`` graph and rewrites it to provide features such as dynamic sharding, shared random seeds, and snapshotting for multi-/distributed processes. For more detail about those features, please refer to `the documentation `_.
diff --git a/docs/source/dlv2_tutorial.rst b/docs/source/dlv2_tutorial.rst
index 99a83cebf..efb64feb7 100644
--- a/docs/source/dlv2_tutorial.rst
+++ b/docs/source/dlv2_tutorial.rst
@@ -24,11 +24,11 @@ Here is an example of a ``DataPipe`` graph:
 Multiprocessing
 ----------------

-``PrototypeMultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.
+``MultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.

 .. code:: python

-    rs = PrototypeMultiProcessingReadingService(num_workers=4)
+    rs = MultiProcessingReadingService(num_workers=4)
     dl = DataLoader2(datapipe, reading_service=rs)
     for epoch in range(10):
         dl.seed(epoch)
@@ -58,7 +58,7 @@ Multiprocessing + Distributed

 .. code:: python

-    mp_rs = PrototypeMultiProcessingReadingService(num_workers=4)
+    mp_rs = MultiProcessingReadingService(num_workers=4)
     dist_rs = DistributedReadingService()
     rs = SequentialReadingService(dist_rs, mp_rs)

diff --git a/docs/source/reading_service.rst b/docs/source/reading_service.rst
index 1d72b6acc..915d846f1 100644
--- a/docs/source/reading_service.rst
+++ b/docs/source/reading_service.rst
@@ -11,7 +11,7 @@ Features
 Dynamic Sharding
 ^^^^^^^^^^^^^^^^

-Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users to define the sharding place within the pipeline.
+Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService``, which shard the pipeline based on the information of the corresponding multiprocessing and distributed workers. TorchData offers two types of ``DataPipe`` that let users define where sharding takes place within the pipeline.

 - ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from one replica of the ``DataPipe`` graph and skips the data that does not belong to the corresponding worker at the point of ``sharding_filter``.

@@ -121,7 +121,7 @@ Determinism

 In ``DataLoader2``, a ``SeedGenerator`` becomes the single source of randomness, and each ``ReadingService`` accesses it via ``initialize_iteration()`` and generates corresponding random seeds for random ``DataPipe`` operations.

-In order to make sure that the Dataset shards are mutually exclusive and collectively exhaunsitve on multiprocessing processes and distributed nodes, ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` would help ``DataLoader2`` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.
+In order to make sure that the dataset shards are mutually exclusive and collectively exhaustive across multiprocessing workers and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` help ``DataLoader2`` synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, each ``ReadingService`` generates unique random states based on the distributed rank and worker process id, in order to perform different random transformations.
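
As a minimal, illustrative sketch (assuming only the post-rename API shown in this patch), the per-epoch seeding that drives this synchronization looks like:

.. code:: python

    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    # Random operations before ``sharding_filter`` receive a synchronized seed,
    # so every worker shuffles identically and the resulting shards stay disjoint.
    datapipe = IterableWrapper(range(100)).shuffle().sharding_filter().batch(4)

    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(datapipe, reading_service=rs)
    for epoch in range(2):
        dl.seed(epoch)  # one seed per epoch: reproducible but epoch-varying order
        for batch in dl:
            pass
    dl.shutdown()
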
Graph Mode ^^^^^^^^^^^ diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 509f989bb..0475f6d71 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -27,7 +27,6 @@ DataLoader2, DistributedReadingService, MultiProcessingReadingService, - PrototypeMultiProcessingReadingService, ReadingServiceInterface, SequentialReadingService, ) @@ -120,16 +119,6 @@ def test_dataloader2_reading_service(self) -> None: self.assertEqual(batch, expected_batch) expected_batch += 1 - def test_dataloader2_multi_process_reading_service(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - reading_service = MultiProcessingReadingService() - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) - - expected_batch = 0 - for batch in iter(data_loader): - self.assertEqual(batch, expected_batch) - expected_batch += 1 - def test_dataloader2_load_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) reading_service = TestReadingService() @@ -165,7 +154,7 @@ def test_dataloader2_iterates_correctly(self) -> None: None, TestReadingService(), MultiProcessingReadingService(num_workers=4), - PrototypeMultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0), + MultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0), ] for reading_service in reading_services: data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) @@ -232,18 +221,10 @@ def _get_no_reading_service(): def _get_mp_reading_service(): return MultiProcessingReadingService(num_workers=2) - @staticmethod - def _get_proto_reading_service(): - return PrototypeMultiProcessingReadingService(num_workers=2) - @staticmethod def _get_mp_reading_service_zero_workers(): return MultiProcessingReadingService(num_workers=0) - @staticmethod - def _get_proto_reading_service_zero_workers(): - return PrototypeMultiProcessingReadingService(num_workers=0) - def _collect_data(self, datapipe, reading_service_gen): dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen()) result = [] @@ -265,9 +246,7 @@ def test_dataloader2_batch_collate(self) -> None: reading_service_generators = ( self._get_mp_reading_service, - self._get_proto_reading_service, self._get_mp_reading_service_zero_workers, - self._get_proto_reading_service_zero_workers, ) for reading_service_gen in reading_service_generators: actual = self._collect_data(dp, reading_service_gen=reading_service_gen) @@ -279,27 +258,6 @@ def test_dataloader2_shuffle(self) -> None: pass -class DataLoader2IntegrationTest(TestCase): - @staticmethod - def _get_mp_reading_service(): - return MultiProcessingReadingService(num_workers=2) - - def test_lazy_load(self): - source_dp = IterableWrapper([(i, i) for i in range(10)]) - map_dp = source_dp.to_map_datapipe() - - reading_service_generators = (self._get_mp_reading_service,) - for reading_service_gen in reading_service_generators: - dl: DataLoader2 = DataLoader2(datapipe=map_dp, reading_service=reading_service_gen()) - # Lazy loading - dp = dl.datapipe - self.assertTrue(dp._map is None) - it = iter(dl) - self.assertEqual(list(it), list(range(10))) - # Lazy loading in multiprocessing - self.assertTrue(map_dp._map is None) - - @unittest.skipIf( TEST_WITH_TSAN, "Fails with TSAN with the following error: starting new threads after multi-threaded " @@ -382,7 +340,7 @@ def is_replicable(self): return False -class PrototypeMultiProcessingReadingServiceTest(TestCase): +class 
MultiProcessingReadingServiceTest(TestCase): @staticmethod def _worker_init_fn(datapipe, worker_info): datapipe = datapipe.sharding_filter() @@ -403,7 +361,7 @@ def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator def test_worker_fns(self, ctx): dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle() - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=2, multiprocessing_context=ctx, worker_init_fn=self._worker_init_fn, @@ -448,7 +406,7 @@ def _assert_deterministic_dl_res(dl, exp): sf_dp = single_br_dp.sharding_filter() replace_dp(graph, single_br_dp, sf_dp) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism and dynamic sharding # _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) @@ -462,7 +420,7 @@ def _assert_deterministic_dl_res(dl, exp): sf_dp = map_dp.sharding_filter() replace_dp(graph, map_dp, sf_dp) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism for non-replicable pipeline _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) @@ -476,7 +434,7 @@ def _assert_deterministic_dl_res(dl, exp): round_robin_dispatcher = ShardingRoundRobinDispatcher(map_dp, SHARDING_PRIORITIES.MULTIPROCESSING) replace_dp(graph, map_dp, round_robin_dispatcher) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism for non-replicable pipeline _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) @@ -518,7 +476,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2): replace_dp(graph, branch1_dp, sf1_dp) replace_dp(graph, branch2_dp, sf2_dp) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism and dynamic sharding _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) @@ -533,7 +491,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2): sf_dp = branch2_dp.sharding_filter() replace_dp(graph, branch2_dp, sf_dp) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism for non-replicable pipeline _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) @@ -547,7 +505,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2): non_replicable_dp2 = ShardingRoundRobinDispatcher(branch2_dp, SHARDING_PRIORITIES.MULTIPROCESSING) replace_dp(graph, branch2_dp, non_replicable_dp2) dl = DataLoader2( - end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) + end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) ) # Determinism for non-replicable pipeline _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) @@ -558,7 +516,7 @@ def 
test_multi_worker_determinism(self, ctx): dp = dp.shuffle().sharding_filter() dp = dp.batch(2) - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=2, multiprocessing_context=ctx, ) @@ -589,7 +547,7 @@ def test_dispatching_worker_determinism(self, ctx): dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) dp = dp.batch(2) - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=2, multiprocessing_context=ctx, ) @@ -625,7 +583,7 @@ def test_non_replicable_datapipe(self, ctx) -> None: dp = dp.batch(2) non_rep_dp = NonReplicableDataPipe(dp) - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=2, multiprocessing_context=ctx, ) @@ -775,7 +733,7 @@ def _make_dispatching_dp(data_length): @staticmethod def _make_rs(num_workers, ctx): - mp_rs = PrototypeMultiProcessingReadingService( + mp_rs = MultiProcessingReadingService( num_workers=num_workers, multiprocessing_context=ctx, ) @@ -850,7 +808,7 @@ def test_sequential_reading_service_dispatching_dp(self, ctx): self.assertNotEqual(result[1][rank][1], result[3][rank][1]) -instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest) +instantiate_parametrized_tests(MultiProcessingReadingServiceTest) instantiate_parametrized_tests(SequentialReadingServiceTest) diff --git a/test/dataloader2/test_random.py b/test/dataloader2/test_random.py index 3560cceb1..59f8f894e 100644 --- a/test/dataloader2/test_random.py +++ b/test/dataloader2/test_random.py @@ -15,7 +15,7 @@ import torch from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize -from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService +from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService from torchdata.dataloader2.graph.settings import set_graph_random_seed from torchdata.dataloader2.random import SeedGenerator from torchdata.datapipes.iter import IterableWrapper @@ -40,7 +40,7 @@ def test_proto_rs_determinism(self, num_workers): data_source = IterableWrapper(exp) dp = data_source.shuffle().sharding_filter().map(_random_fn) - rs = PrototypeMultiProcessingReadingService(num_workers=num_workers) + rs = MultiProcessingReadingService(num_workers=num_workers) dl = DataLoader2(dp, reading_service=rs) # No seed diff --git a/test/test_distributed.py b/test/test_distributed.py index 5fb397eeb..9d70d1a86 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -23,7 +23,7 @@ from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize from torch.utils.data import DataLoader -from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService +from torchdata.dataloader2 import DataLoader2, DistributedReadingService from torchdata.datapipes.iter import IterableWrapper from torchdata.datapipes.iter.util.distributed import PrefetchTimeoutError diff --git a/test/test_graph.py b/test/test_graph.py index 79d1064cb..c0c346554 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -16,7 +16,7 @@ from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface +from torchdata.dataloader2 import DataLoader2, ReadingServiceInterface from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, 
replace_dp, traverse_dps
 from torchdata.dataloader2.graph.utils import _find_replicable_branches
 from torchdata.dataloader2.random import SeedGenerator

@@ -254,15 +254,6 @@ def test_reading_service(self) -> None:

         self.assertEqual(res, list(dl))

-    @unittest.skipIf(IS_WINDOWS, "Fork is required for lambda")
-    def test_multiprocessing_reading_service(self) -> None:
-        _, (*_, dp) = self._get_datapipes()  # pyre-ignore
-        rs = MultiProcessingReadingService(2, persistent_workers=True, multiprocessing_context="fork")
-        dl = DataLoader2(dp, reading_service=rs)
-        d1 = list(dl)
-        d2 = list(dl)
-        self.assertEqual(d1, d2)
-

 def insert_round_robin_sharding(graph, datapipe):
     dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 3c76bf61d..b16501724 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -18,7 +18,6 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp

-from torch.utils.data import DataLoader
 from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

 from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
@@ -29,7 +28,7 @@
 from torchdata.dataloader2.random import dist_share_seed, SeedGenerator
 from torchdata.dataloader2.utils import process_init_fn, process_reset_fn, WorkerInfo
 from torchdata.dataloader2.utils.dispatch import _DummyIterDataPipe, find_lca_round_robin_sharding_dp
-from torchdata.datapipes.iter import FullSync, IterableWrapper
+from torchdata.datapipes.iter import FullSync


 class ReadingServiceInterface(ABC):
@@ -140,6 +139,15 @@ def _collate_no_op(batch):


 class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
+    def __new__(cls, *args, **kwargs):
+        warnings.warn(
+            "`PrototypeMultiProcessingReadingService` is deprecated and will be removed in TorchData 0.8. "
+            "Please use `MultiProcessingReadingService`."
+        )
+        return MultiProcessingReadingService(*args, **kwargs)
+
+
+class MultiProcessingReadingService(ReadingServiceInterface):
     r"""
     Spawns multiple worker processes to load data from the ``DataPipe`` graph.
     If any non-replicable ``DataPipe`` (``sharding_round_robin_dispatch``) is present in the graph,
@@ -163,14 +171,6 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
         worker_reset_fn: (Callable, optional): Function to be called at the beginning
             of each epoch in each worker process with ``DataPipe``, ``WorkerInfo``
             and ``SeedGenerator`` as the expected arguments.
-
-    Note:
-      - This ``ReadingService`` is still in prototype mode and will replace
-        :class:`MultiProcessingReadingService`.
-      - It currently does both distributed and multiprocessing sharding over the pipeline.
-        The distributed-related code is going to be removed when ``SequentialReadingService``
-        is provided to combine the :class:`DistributedReadingService` and this ``ReadingService``.
-
     """
     num_workers: int
     multiprocessing_context: Optional[str]

@@ -216,7 +216,7 @@ def __init__(

     def initialize(self, datapipe: DataPipe) -> DataPipe:
         r"""
-        ``PrototypeMultiProcessingReadingService`` finds information about sharding,
+        ``MultiProcessingReadingService`` finds information about sharding,
         separates the graph into multiple pieces, reconnects them using queues,
         and creates subprocesses.
        """
@@ -260,7 +260,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
             replicable_dps = _find_replicable_branches(graph)
             assert (
                 len(replicable_dps) == 1
-            ), "PrototypeMultiProcessingReadingService only supports single replicable branch currently"
+            ), "MultiProcessingReadingService only supports a single replicable branch currently"
             replicable_dp = replicable_dps[0]

             if self.worker_prefetch_cnt > 0:
@@ -337,7 +337,7 @@ def initialize_iteration(

     def finalize(self) -> None:
         r"""
-        ``PrototypeMultiProcessingReadingService`` invalidate states & properly exits all subprocesses.
+        ``MultiProcessingReadingService`` invalidates states and properly exits all subprocesses.
         """
         # TODO(618): Check if anyone is stuck with messages
         def clean_me(process, req_queue, res_queue):
@@ -406,67 +406,6 @@ def _resume(self):
             self._main_prefetch_datapipe.resume()  # type: ignore[union-attr]


-class MultiProcessingReadingService(ReadingServiceInterface):
-    r"""
-    ``MultiProcessingReadingService`` that utilizes ``torch.utils.data.DataLoader`` to
-    launch subprocesses for ``DataPipe`` graph. Please refer to documents of ``DataLoader``
-    in https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader for all arguments.
-
-    Note:
-        This ``ReadingService`` be replaced by :class:`PrototypeMultiProcessingReadingService`.
-    """
-    num_workers: int
-    pin_memory: bool
-    timeout: float
-    worker_init_fn: Optional[Callable[[int], None]]
-    prefetch_factor: Optional[int]
-    persistent_workers: bool
-
-    def __init__(
-        self,
-        num_workers: int = 0,
-        pin_memory: bool = False,
-        timeout: float = 0,
-        worker_init_fn: Optional[Callable[[int], None]] = None,
-        multiprocessing_context=None,
-        prefetch_factor: Optional[int] = None,
-        persistent_workers: bool = False,
-    ) -> None:
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.timeout = timeout
-        self.worker_init_fn = worker_init_fn
-        self.multiprocessing_context = multiprocessing_context
-        self.prefetch_factor = prefetch_factor
-        self.persistent_workers = persistent_workers
-        if self.num_workers == 0:
-            self.prefetch_factor = None
-            self.persistent_workers = False
-        self.dl_: Optional[DataLoader] = None
-
-    # Wrap the DataLoader with IterableWrapper to respect type annotation
-    def initialize(self, datapipe: DataPipe) -> DataPipe:
-        self.dl_ = DataLoader(
-            datapipe,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            timeout=self.timeout,
-            worker_init_fn=self.worker_init_fn,
-            multiprocessing_context=self.multiprocessing_context,
-            prefetch_factor=self.prefetch_factor,
-            persistent_workers=self.persistent_workers,
-            # TODO(621): `collate_fn` is necessary until we stop using DLv1 https://github.com/pytorch/data/issues/530
-            collate_fn=_collate_no_op,
-            batch_size=1,  # This reading service assume batching is done via DataPipe
-        )
-        return IterableWrapper(self.dl_)  # type: ignore[return-value]
-
-    def finalize(self) -> None:
-        if self.persistent_workers and self.dl_ is not None and self.dl_._iterator is not None:
-            self.dl_._iterator._shutdown_workers()  # type: ignore[attr-defined]
-            self.dl_._iterator = None
-
-
 class DistributedReadingService(ReadingServiceInterface):
     r"""
     ``DistributedReadingService`` handles distributed sharding on the graph of ``DataPipe`` and
diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py
index 1dc18c2bb..5cdb5374f 100644
--- a/torchdata/datapipes/iter/util/prefetcher.py
+++ b/torchdata/datapipes/iter/util/prefetcher.py
@@ -34,7 +34,7 @@ class 
PrefetcherIterDataPipe(IterDataPipe): and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside from getting the sample ready ahead of time. - This is used by ``PrototypeMultiProcessingReadingService`` when the arguments + This is used by ``MultiProcessingReadingService`` when the arguments ``worker_prefetch_cnt`` (for prefetching at each worker process) or ``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0.
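
For reference, a minimal sketch of what the deprecation shim in ``reading_service.py`` implies for existing call sites. This assumes ``PrototypeMultiProcessingReadingService`` remains importable from ``torchdata.dataloader2`` during the deprecation window; the argument values are illustrative:

.. code:: python

    import warnings

    from torchdata.dataloader2 import (
        MultiProcessingReadingService,
        PrototypeMultiProcessingReadingService,
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        rs = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=10)

    # ``__new__`` in the shim constructs and returns the graduated class, so old
    # call sites transparently receive a fully initialized replacement instance.
    assert isinstance(rs, MultiProcessingReadingService)
    assert any("deprecated" in str(w.message) for w in caught)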