Skip to content

Commit

Permalink
Adding Prefetcher into the PrototypeRS and fixing prefetcher bug (pytorch#826)

Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: pytorch#826

Test Plan: Imported from OSS

Reviewed By: NivekT, msaroufim

Differential Revision: D40308358

Pulled By: VitalyFedyunin

fbshipit-source-id: f4b706c29ceffc52c376e9cddfc56a75819ff1ee
  • Loading branch information
VitalyFedyunin authored and ejguan committed Oct 23, 2022
1 parent fabce15 commit affb0fa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
22 changes: 22 additions & 0 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,28 @@ def test_dataloader2_load_state_dict(self) -> None:

restored_data_loader.shutdown()

def test_dataloader2_iterates_correctly(self) -> None:
    """Each reading service yields 0..9 in order and supports repeated full iterations."""
    source_pipe = IterableWrapper(range(10)).sharding_filter()
    expected = list(range(10))
    for rs in [
        None,
        TestReadingService(),
        MultiProcessingReadingService(num_workers=4),
        PrototypeMultiProcessingReadingService(num_workers=4, prefetch_worker=0),
    ]:
        loader: DataLoader2 = DataLoader2(datapipe=source_pipe, reading_service=rs)
        # Exhaust via list() three times — the loader must reset cleanly between epochs.
        self.assertEqual(expected, list(loader))
        self.assertEqual(expected, list(loader))
        self.assertEqual(expected, list(loader))
        # Also drive the iterator with plain for-loops to cover manual iteration.
        collected = []
        for item in loader:
            collected.append(item)
        self.assertEqual(expected, collected)
        collected = []
        for item in loader:
            collected.append(item)
        self.assertEqual(expected, collected)

def test_dataloader2_reset(self) -> None:

test_data_pipe = IterableWrapper(range(10))
Expand Down
21 changes: 19 additions & 2 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,20 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
num_workers: int
processes: List
datapipes: List
combined_datapipes: Optional[_IterateQueueDataPipes]
combined_datapipes: Optional[IterDataPipe]

def __init__(
    self,
    num_workers: int = 0,
    multiprocessing_context=None,
    prefetch_worker: int = 10,
    prefetch_mainloop: int = 10,
) -> None:
    """Configure the multiprocessing reading service.

    Args:
        num_workers: Number of worker processes; 0 means no workers
            (``initialize`` returns the datapipe unchanged in that case).
        multiprocessing_context: Context passed to multiprocessing.
        prefetch_worker: Prefetch buffer depth applied to the datapipe
            inside each worker; 0 disables worker-side prefetching.
        prefetch_mainloop: Prefetch buffer depth applied on top of the
            combined worker queues in the main loop; 0 disables it.
    """
    self.num_workers = num_workers
    # TODO(613): Should be one of 'fork', 'spawn'
    self.multiprocessing_context = multiprocessing_context
    self.prefetch_worker = prefetch_worker
    self.prefetch_mainloop = prefetch_mainloop
    # Populated later by initialize(); empty/None until then.
    self.processes = []
    self.datapipes = []
    self.combined_datapipes = None
Expand All @@ -196,6 +200,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
if self.num_workers == 0:
# TODO(616): Warn and recommend usage of InProcessReadingService
return datapipe

if self.prefetch_worker > 0:
datapipe = datapipe.prefetch(self.prefetch_worker)

for worker_id in range(self.num_workers):
# TODO(617): Separate into function, because we also need to apply distributed seed
# and call it inside process
Expand All @@ -216,11 +224,20 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
self.datapipes.append(local_datapipe)

self.combined_datapipes = _IterateQueueDataPipes(self.datapipes)
if self.prefetch_mainloop > 0:
self.combined_datapipes = self.combined_datapipes.prefetch(self.prefetch_mainloop)
return self.combined_datapipes # type: ignore[return-value]

def initialize_iteration(self) -> None:
    """Reset per-epoch state on the combined worker datapipe before a new epoch.

    When mainloop prefetching is enabled, ``combined_datapipes`` is a
    prefetcher wrapping the ``_IterateQueueDataPipes``; the prefetch
    thread must be stopped (``reset``) before the underlying queue pipe
    is reset, otherwise the reset races with in-flight prefetching.
    """
    # NOTE(review): the scraped diff interleaved the removed pre-commit
    # line with the new body; this is the post-commit version.
    if self.combined_datapipes is not None:
        if self.prefetch_mainloop > 0:
            # Stop prefetching first
            self.combined_datapipes.reset()
            self.combined_datapipes.source_datapipe.reset_epoch()
            self.combined_datapipes.source_datapipe.reset()
        else:
            self.combined_datapipes.reset_epoch()
            self.combined_datapipes.reset()

def __del__(self):
    # Best-effort teardown when the reading service is garbage-collected.
    # NOTE(review): relies on finalize() being safe to call on a
    # partially-initialized instance — confirm against finalize()'s body.
    self.finalize()
Expand Down
3 changes: 0 additions & 3 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,3 @@ def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.thread.join()

def reset_iterator(self):
self.reset()

0 comments on commit affb0fa

Please sign in to comment.