Skip to content

Commit

Permalink
[DataPipe] Ensures Prefetcher shuts down properly
Browse files Browse the repository at this point in the history
ghstack-source-id: dd778b5c86cc13a8b084c86385cc10bd163659d2
Pull Request resolved: #1166
  • Loading branch information
NivekT committed May 19, 2023
1 parent ba31745 commit f9b6e36
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
7 changes: 7 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def _test_fullsync(rank, world_size, backend, q):
it2 = iter(dp3) # Reset
next(it2)

dp4 = dp.prefetch(2)
it = iter(dp4)
next(it)
dp4.pause()
it2 = iter(dp4) # Reset
next(it2)

_finalize_distributed_queue(rank, q)

@world_size_parametrize
Expand Down
33 changes: 24 additions & 9 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
raise ValueError("'buffer_size' is required to be a positive integer.")
self.buffer_size = buffer_size
self.thread: Optional[threading.Thread] = None
self.prefetch_data: Optional[_PrefetchData] = None

@staticmethod
def thread_worker(prefetch_data: _PrefetchData):
Expand All @@ -83,6 +84,8 @@ def thread_worker(prefetch_data: _PrefetchData):
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

def __iter__(self):
if self.thread is not None:
self.shutdown()
try:
prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
self.prefetch_data = prefetch_data
Expand All @@ -104,9 +107,7 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
self.shutdown()

def __getstate__(self):
"""
Expand Down Expand Up @@ -145,13 +146,26 @@ def pause(self):

@final
def resume(self):
    """Resume prefetching after a ``pause``.

    This is a no-op unless all of the following hold: a worker thread
    exists, the prefetch state has been initialized, and there is still
    work to do — either the source iterator is not yet exhausted
    (``stop_iteration`` is ``False``) or buffered elements remain to be
    consumed.
    """
    if (
        self.thread is not None
        and self.prefetch_data is not None
        and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0)
    ):
        # Flipping these flags lets the producer loop in ``thread_worker``
        # continue filling the buffer; the explicit ``is not None`` check
        # above makes a separate assert on ``prefetch_data`` unnecessary.
        self.prefetch_data.run_prefetcher = True
        self.prefetch_data.paused = False

@final
def shutdown(self):
    """Stop the prefetching worker thread and release its state.

    Safe to call repeatedly; after it returns, both ``self.prefetch_data``
    and ``self.thread`` are ``None``.
    """
    if self.prefetch_data is not None:
        # Signal the producer loop in ``thread_worker`` to exit. The worker
        # holds its own reference to the ``_PrefetchData``, so dropping ours
        # before joining is safe.
        self.prefetch_data.run_prefetcher = False
        self.prefetch_data.stop_iteration = True
        self.prefetch_data = None
    if self.thread is not None:
        self.thread.join()
        # Clear the handle so callers that gate on ``self.thread is not None``
        # (e.g. ``__iter__`` and ``resume``) see a clean state instead of a
        # dead-but-joined thread, and repeated shutdowns are true no-ops.
        self.thread = None

def __del__(self):
    # Best-effort cleanup on garbage collection: ensure the prefetching
    # worker thread is signalled to stop and joined even if the consumer
    # never exhausted the iterator or called shutdown() explicitly.
    self.shutdown()

def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe)
Expand Down Expand Up @@ -185,6 +199,7 @@ def __init__(self, source_datapipe, device=None, pin_memory_fn=pin_memory_fn):
device = torch.cuda.current_device()
self.device = device
self.pin_memory_fn = pin_memory_fn
self.prefetch_data: Optional[_PrefetchData] = None

def is_replicable(self) -> bool:
return False
Expand All @@ -210,6 +225,8 @@ def thread_worker(prefetch_data: _PrefetchData, pin_memory_fn, device): # type:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

def __iter__(self):
if self.thread is not None:
self.shutdown()
try:
prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
self.prefetch_data = prefetch_data
Expand All @@ -235,9 +252,7 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
self.shutdown()

def __getstate__(self):
state = super().__getstate__()
Expand Down

0 comments on commit f9b6e36

Please sign in to comment.