Skip to content

Commit

Permalink
Implementing thread based PrefetcherIterDataPipe (#770)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #770

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D39816751

Pulled By: VitalyFedyunin

fbshipit-source-id: 9cd38309b8c008518cb820bf3748b7654c86f214
  • Loading branch information
VitalyFedyunin authored and facebook-github-bot committed Oct 3, 2022
1 parent 57357b7 commit 9ad8efb
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 0 deletions.
10 changes: 10 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,16 @@ def odd_even_bug(i: int) -> int:
result_dp = source_dp.zip_with_map(map_dp, odd_even)
self.assertEqual(len(source_dp), len(result_dp))

def test_prefetcher_iterdatapipe(self) -> None:
    """Prefetcher yields exactly the source elements, even after a restart."""
    source_dp = IterableWrapper(range(50000))
    prefetched_dp = source_dp.prefetch(10)
    # Consume only a prefix, then abandon the iterator: the child thread
    # must be torn down and reset cleanly before the next full pass.
    partial_it = iter(prefetched_dp)
    for _ in range(100):
        next(partial_it)
    self.assertEqual(list(source_dp), list(prefetched_dp))

def test_repeater_iterdatapipe(self) -> None:
import itertools

Expand Down
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.prefetch import FullSyncIterDataPipe as FullSync
from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher
from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
Expand Down Expand Up @@ -190,6 +191,7 @@
"OnlineReader",
"ParagraphAggregator",
"ParquetDataFrameLoader",
"Prefetcher",
"RandomSplitter",
"RarArchiveLoader",
"Repeater",
Expand Down
105 changes: 105 additions & 0 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import time

from typing import Optional

from torchdata.dataloader2 import communication

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

PRODUCER_SLEEP_INTERVAL = 0.0001  # Interval between buffer fulfillment checks
CONSUMER_SLEEP_INTERVAL = 0.0001  # Interval between checking item availability in buffer


class _PrefetchData:
def __init__(self, source_datapipe, buffer_size):
self.run_prefetcher = True
# TODO: Potential optimization is changing buffer from list to dequeue
self.prefetch_buffer = []
self.buffer_size = buffer_size
self.source_datapipe = source_datapipe


@functional_datapipe("prefetch")
class PrefetcherIterDataPipe(IterDataPipe):
    r"""
    Prefetches elements from the source DataPipe on a background thread and
    serves them from an in-memory buffer (functional name: ``prefetch``).

    Args:
        source_datapipe: the DataPipe whose elements are prefetched
        buffer_size: maximum number of elements held in the buffer; must be
            a positive integer (default: 10)

    Raises:
        ValueError: if ``buffer_size`` is not positive.
    """

    def __init__(self, source_datapipe, buffer_size: int = 10):
        self.source_datapipe = source_datapipe
        if buffer_size <= 0:
            raise ValueError("'buffer_size' is required to be a positive integer.")
        self.buffer_size = buffer_size
        self.thread: Optional[threading.Thread] = None

    @staticmethod
    def thread_worker(prefetch_data):
        """Producer loop: fill ``prefetch_data.prefetch_buffer`` until stopped.

        Runs on the background thread; exits when the consumer clears
        ``run_prefetcher`` or the source is exhausted and fully drained.
        """
        itr = iter(prefetch_data.source_datapipe)
        stop_iteration = False
        while prefetch_data.run_prefetcher:
            if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not stop_iteration:
                try:
                    item = next(itr)
                    prefetch_data.prefetch_buffer.append(item)
                except StopIteration:
                    stop_iteration = True
                except communication.iter.InvalidStateResetRequired:
                    stop_iteration = True
                except communication.iter.TerminateRequired:
                    prefetch_data.run_prefetcher = False
            elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
                # Source exhausted and every buffered item consumed: shut down.
                prefetch_data.run_prefetcher = False
            else:  # Buffer is full, waiting for main thread to consume items
                # TODO: Calculate sleep interval based on previous consumption speed
                time.sleep(PRODUCER_SLEEP_INTERVAL)

    def __iter__(self):
        if self.buffer_size < 1:
            # Unreachable via __init__ (which validates buffer_size), kept as
            # a no-thread fallback in case state was restored without it.
            yield from self.source_datapipe
        else:
            try:
                prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
                self.prefetch_data = prefetch_data
                self.thread = threading.Thread(
                    target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True
                )
                self.thread.start()
                while prefetch_data.run_prefetcher:
                    if len(prefetch_data.prefetch_buffer) > 0:
                        # pop(0) is a single list operation, atomic under the GIL.
                        # The previous `buffer = buffer[1:]` re-assignment could
                        # silently drop an element the producer appended between
                        # slice creation and the assignment.
                        yield prefetch_data.prefetch_buffer.pop(0)
                    else:
                        # TODO: Calculate sleep interval based on previous availability speed
                        time.sleep(CONSUMER_SLEEP_INTERVAL)
            finally:
                # Always stop the producer and join it, so an early-terminated
                # iteration leaves no live worker thread behind.
                prefetch_data.run_prefetcher = False
                if self.thread is not None:
                    self.thread.join()
                    self.thread = None

    def __getstate__(self):
        """
        Getting state in a threading environment requires the following operations:
        1) Stopping of the producer thread.
        2) Saving buffer.
        3) Adding lazy restart of producer thread when __next__ is called again
           (this will guarantee that you only change state of the source_datapipe
           after entire state of the graph is saved).
        """
        # TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration
        # buffer_size is persisted as well; previously a restored instance
        # crashed with AttributeError in __iter__ because it was missing.
        return dict(source_datapipe=self.source_datapipe, buffer_size=self.buffer_size)

    def __setstate__(self, state):
        self.source_datapipe = state["source_datapipe"]
        # .get keeps pickles written by older versions (without buffer_size) loadable.
        self.buffer_size = state.get("buffer_size", 10)
        # No thread survives pickling; __iter__ starts a fresh one lazily.
        self.thread = None

    def reset(self):
        """Stop and reap the worker thread (if any) so iteration can restart."""
        if self.thread is not None:
            self.prefetch_data.run_prefetcher = False
            self.thread.join()
            # Clear the handle so a second reset() does not re-join a dead thread.
            self.thread = None

    def reset_iterator(self):
        self.reset()

0 comments on commit 9ad8efb

Please sign in to comment.