Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataPipe] Ensures Prefetcher shuts down properly #1166

Closed
wants to merge 9 commits into from
7 changes: 7 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def _test_fullsync(rank, world_size, backend, q):
it2 = iter(dp3) # Reset
next(it2)

dp4 = dp.prefetch(2)
it = iter(dp4)
next(it)
dp4.pause()
it2 = iter(dp4) # Reset
next(it2)

_finalize_distributed_queue(rank, q)

@world_size_parametrize
Expand Down
37 changes: 22 additions & 15 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
raise ValueError("'buffer_size' is required to be a positive integer.")
self.buffer_size = buffer_size
self.thread: Optional[threading.Thread] = None
self.prefetch_data: Optional[_PrefetchData] = None

@staticmethod
def thread_worker(prefetch_data: _PrefetchData):
Expand Down Expand Up @@ -104,9 +105,7 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
self.shutdown()

def __getstate__(self):
"""
Expand All @@ -127,12 +126,7 @@ def __setstate__(self, state):

@final
def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
self.prefetch_data.paused = False
self.thread.join()
self.thread = None
self.shutdown()

def pause(self):
if self.thread is not None:
Expand All @@ -145,13 +139,28 @@ def pause(self):

@final
def resume(self):
if self.thread is not None and (
not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
if (
self.thread is not None
and self.prefetch_data is not None
and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0)
):
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = True
self.prefetch_data.paused = False

@final
def shutdown(self):
if hasattr(self, "prefetch_data") and self.prefetch_data is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
self.prefetch_data.paused = False
self.prefetch_data = None
if hasattr(self, "thread") and self.thread is not None:
self.thread.join()
self.thread = None

def __del__(self):
self.shutdown()

def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe)
Expand Down Expand Up @@ -235,9 +244,7 @@ def __iter__(self):
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
thread.join()
self.shutdown()

def __getstate__(self):
state = super().__getstate__()
Expand Down