diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py
index 5f9d8ebb0..3a4a766dc 100644
--- a/test/dataloader2/test_dataloader2.py
+++ b/test/dataloader2/test_dataloader2.py
@@ -122,6 +122,28 @@ def test_dataloader2_load_state_dict(self) -> None:
 
         restored_data_loader.shutdown()
 
+    def test_dataloader2_iterates_correctly(self) -> None:
+        test_data_pipe = IterableWrapper(range(10)).sharding_filter()
+        reading_services = [
+            None,
+            TestReadingService(),
+            MultiProcessingReadingService(num_workers=4),
+            PrototypeMultiProcessingReadingService(num_workers=4, prefetch_worker=0),
+        ]
+        for reading_service in reading_services:
+            data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+            self.assertEqual(list(range(10)), list(data_loader))
+            actual = []
+            for i in data_loader:
+                actual.append(i)
+            self.assertEqual(list(range(10)), actual)
+            actual = []
+            for i in data_loader:
+                actual.append(i)
+            self.assertEqual(list(range(10)), actual)
+
     def test_dataloader2_reset(self) -> None:
 
         test_data_pipe = IterableWrapper(range(10))
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index d60775e4f..548a97513 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -162,16 +162,20 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
     num_workers: int
     processes: List
     datapipes: List
-    combined_datapipes: Optional[_IterateQueueDataPipes]
+    combined_datapipes: Optional[IterDataPipe]
 
     def __init__(
         self,
         num_workers: int = 0,
         multiprocessing_context=None,
+        prefetch_worker: int = 10,
+        prefetch_mainloop: int = 10,
     ) -> None:
         self.num_workers = num_workers
         # TODO(613): Should be one of 'fork', 'spawn'
         self.multiprocessing_context = multiprocessing_context
+        self.prefetch_worker = prefetch_worker
+        self.prefetch_mainloop = prefetch_mainloop
         self.processes = []
         self.datapipes = []
         self.combined_datapipes = None
@@ -196,6 +200,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         if self.num_workers == 0:
             # TODO(616): Warn and recommend usage of InProcessReadingService
             return datapipe
+
+        if self.prefetch_worker > 0:
+            datapipe = datapipe.prefetch(self.prefetch_worker)
+
         for worker_id in range(self.num_workers):
             # TODO(617): Separate into function, because we also need to apply distributed seed
             # and call it inside process
@@ -216,11 +224,20 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
             self.datapipes.append(local_datapipe)
 
         self.combined_datapipes = _IterateQueueDataPipes(self.datapipes)
+        if self.prefetch_mainloop > 0:
+            self.combined_datapipes = self.combined_datapipes.prefetch(self.prefetch_mainloop)
         return self.combined_datapipes  # type: ignore[return-value]
 
     def initialize_iteration(self) -> None:
         if self.combined_datapipes is not None:
-            self.combined_datapipes.reset_epoch()
+            if self.prefetch_mainloop > 0:
+                # Stop prefetching first
+                self.combined_datapipes.reset()
+                self.combined_datapipes.source_datapipe.reset_epoch()
+                self.combined_datapipes.source_datapipe.reset()
+            else:
+                self.combined_datapipes.reset_epoch()
+                self.combined_datapipes.reset()
 
     def __del__(self):
         self.finalize()
diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py
index c0867f969..d3b1003e9 100644
--- a/torchdata/datapipes/iter/util/prefetcher.py
+++ b/torchdata/datapipes/iter/util/prefetcher.py
@@ -102,6 +102,3 @@ def reset(self):
         if self.thread is not None:
             self.prefetch_data.run_prefetcher = False
             self.thread.join()
-
-    def reset_iterator(self):
-        self.reset()