Adding Prefetcher into the PrototypeRS and fixing prefetcher bug #826

Closed
22 changes: 22 additions & 0 deletions test/dataloader2/test_dataloader2.py
@@ -122,6 +122,28 @@ def test_dataloader2_load_state_dict(self) -> None:

restored_data_loader.shutdown()

def test_dataloader2_iterates_correctly(self) -> None:
test_data_pipe = IterableWrapper(range(10)).sharding_filter()
reading_services = [
None,
TestReadingService(),
MultiProcessingReadingService(num_workers=4),
PrototypeMultiProcessingReadingService(num_workers=4, prefetch_worker=0),
]
for reading_service in reading_services:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
self.assertEqual(list(range(10)), list(data_loader))
self.assertEqual(list(range(10)), list(data_loader))
self.assertEqual(list(range(10)), list(data_loader))
actual = []
for i in data_loader:
actual.append(i)
self.assertEqual(list(range(10)), actual)
actual = []
for i in data_loader:
actual.append(i)
self.assertEqual(list(range(10)), actual)

def test_dataloader2_reset(self) -> None:

test_data_pipe = IterableWrapper(range(10))
21 changes: 19 additions & 2 deletions torchdata/dataloader2/reading_service.py
@@ -162,16 +162,20 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
num_workers: int
processes: List
datapipes: List
combined_datapipes: Optional[_IterateQueueDataPipes]
combined_datapipes: Optional[IterDataPipe]

def __init__(
self,
num_workers: int = 0,
multiprocessing_context=None,
prefetch_worker: int = 10,
prefetch_mainloop: int = 10,
) -> None:
self.num_workers = num_workers
# TODO(613): Should be one of 'fork', 'spawn'
self.multiprocessing_context = multiprocessing_context
self.prefetch_worker = prefetch_worker
self.prefetch_mainloop = prefetch_mainloop
self.processes = []
self.datapipes = []
self.combined_datapipes = None
@@ -196,6 +200,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
if self.num_workers == 0:
# TODO(616): Warn and recommend usage of InProcessReadingService
return datapipe

if self.prefetch_worker > 0:
datapipe = datapipe.prefetch(self.prefetch_worker)

for worker_id in range(self.num_workers):
# TODO(617): Separate into function, because we also need to apply distributed seed
# and call it inside process
@@ -216,11 +224,20 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
self.datapipes.append(local_datapipe)

self.combined_datapipes = _IterateQueueDataPipes(self.datapipes)
if self.prefetch_mainloop > 0:
self.combined_datapipes = self.combined_datapipes.prefetch(self.prefetch_mainloop)
return self.combined_datapipes # type: ignore[return-value]

def initialize_iteration(self) -> None:
if self.combined_datapipes is not None:
self.combined_datapipes.reset_epoch()
if self.prefetch_mainloop > 0:
# Stop prefetching first
self.combined_datapipes.reset()
ejguan (Contributor) commented on Oct 12, 2022:
Since reset will be called at the beginning of __iter__ as well, I think it's better to add a different API to stop prefetching and call that here.

And reset can be the last resort to stop the thread if that API is not invoked.
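The distinction being suggested could look something like the sketch below. The `MiniPrefetcher` class and its `stop_prefetching` method are hypothetical illustrations of the proposed split, not torchdata's actual API: one narrow call that only halts the background thread, and `reset` as the heavier last resort that also discards buffered state.

```python
import queue
import threading

class MiniPrefetcher:
    """Toy stand-in for a prefetching DataPipe: a background thread
    fills a bounded buffer from a source iterable."""

    _SENTINEL = object()

    def __init__(self, source, buffer_size=10):
        self.source = source
        self.buffer = queue.Queue(maxsize=buffer_size)
        self.run_prefetcher = False
        self.thread = None

    def _worker(self):
        for item in self.source:
            # Retry the put so the thread can notice a stop request
            # even while the buffer is full.
            while self.run_prefetcher:
                try:
                    self.buffer.put(item, timeout=0.05)
                    break
                except queue.Full:
                    continue
            if not self.run_prefetcher:
                return
        self.buffer.put(self._SENTINEL)  # end-of-data marker

    def start(self):
        self.run_prefetcher = True
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def stop_prefetching(self):
        # Narrow API: only halt the background thread, leaving any
        # already-buffered items in place so iteration could resume.
        if self.thread is not None:
            self.run_prefetcher = False
            self.thread.join()
            self.thread = None

    def reset(self):
        # Last resort: stop the thread AND discard buffered state.
        self.stop_prefetching()
        self.buffer = queue.Queue(maxsize=self.buffer.maxsize)
```

With this split, a per-epoch hook can call `stop_prefetching` without paying for a full `reset`, and `reset` stays safe to call at the start of `__iter__`.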

Contributor:

TBH, we should have a graph function to stop prefetchers from the end of the pipeline to the beginning, especially when multiple prefetchers are present.

Contributor (PR author):

I also found them annoyingly duplicated, and ideally we need something like a GraphStop operation to freeze all parallel work. It would be extremely useful for snapshotting, but it is out of scope for this PR. (I will post a separate issue to capture the demand.)

Contributor:

SGTM! There are a few special DataPipes that need to be handled specially by this GraphStop:

  • FullSync
  • Prefetcher
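The end-to-beginning ordering discussed here could be sketched as follows. Everything in this snippet (`PipeNode`, `StoppablePipeNode`, `graph_stop`, the `source_datapipe` chaining) is a hypothetical illustration of the idea, not torchdata's graph utilities:

```python
class PipeNode:
    """Minimal pipeline node, linked to its upstream stage via source_datapipe."""

    def __init__(self, name, source=None):
        self.name = name
        self.source_datapipe = source

class StoppablePipeNode(PipeNode):
    """Node with parallel work (e.g. a Prefetcher or FullSync) that must be stopped."""

    def stop(self, stop_log):
        # A real implementation would halt a thread or a collective here;
        # we only record the order in which stops happen.
        stop_log.append(self.name)

def graph_stop(sink, stop_log):
    # Walk from the end of the pipeline toward the beginning, stopping each
    # downstream stage before its upstream source, so no stopped stage is
    # still being fed by a live one.
    node = sink
    while node is not None:
        if isinstance(node, StoppablePipeNode):
            node.stop(stop_log)
        node = node.source_datapipe
```

For a chain `source -> worker prefetcher -> fullsync -> mainloop prefetcher`, the traversal stops the mainloop prefetcher first and the worker prefetcher last, which is the ordering the comment argues for.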

self.combined_datapipes.source_datapipe.reset_epoch()
self.combined_datapipes.source_datapipe.reset()
else:
self.combined_datapipes.reset_epoch()
self.combined_datapipes.reset()

def __del__(self):
self.finalize()
Expand Down
3 changes: 0 additions & 3 deletions torchdata/datapipes/iter/util/prefetcher.py
@@ -102,6 +102,3 @@ def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.thread.join()

def reset_iterator(self):
self.reset()
Contributor commented on lines -105 to -107:

Yeah, I found the same issue, as reset will be called automatically at the beginning of __iter__.
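The redundancy that motivated deleting `reset_iterator` can be shown with a minimal model. `ResettingPipe` below is an illustration, not torchdata code: once `reset` runs at the start of every `__iter__`, a separate per-iterator hook has nothing left to do.

```python
class ResettingPipe:
    """Minimal model of an iterable whose reset runs automatically at the
    start of each __iter__, making a separate reset_iterator hook redundant."""

    def __init__(self, items):
        self.items = list(items)
        self.reset_calls = 0

    def reset(self):
        # In the real Prefetcher this would stop the worker thread and
        # clear buffered state; here we only count the invocations.
        self.reset_calls += 1

    def __iter__(self):
        self.reset()  # runs once per epoch; no extra hook needed
        yield from self.items
```

Iterating twice yields the full sequence both times and triggers exactly two resets, which is why forwarding `reset_iterator` to `reset` only doubled the work.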