diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 8f8f4db0f..0aeb3203e 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -7,7 +7,8 @@ import threading import time -from typing import Optional +from collections import deque +from typing import Deque, Optional from torchdata.dataloader2 import communication @@ -19,11 +20,11 @@ class _PrefetchData: - def __init__(self, source_datapipe, buffer_size): + def __init__(self, source_datapipe, buffer_size: int): self.run_prefetcher = True # TODO: Potential optimization is changing buffer from list to dequeue - self.prefetch_buffer = [] - self.buffer_size = buffer_size + self.prefetch_buffer: Deque = deque() + self.buffer_size: int = buffer_size self.source_datapipe = source_datapipe @@ -92,8 +93,7 @@ def __iter__(self): self.thread.start() while prefetch_data.run_prefetcher: if len(prefetch_data.prefetch_buffer) > 0: - yield prefetch_data.prefetch_buffer[0] - prefetch_data.prefetch_buffer = prefetch_data.prefetch_buffer[1:] + yield prefetch_data.prefetch_buffer.popleft() else: # TODO: Calculate sleep interval based on previous availability speed time.sleep(CONSUMER_SLEEP_INTERVAL) @@ -113,7 +113,7 @@ def __getstate__(self): after entire state of the graph is saved). """ # TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration - return dict(source_datapipe=self.source_datapipe, buffer_size=self.buffer_size) + return {"source_datapipe": self.source_datapipe, "buffer_size": self.buffer_size} def __setstate__(self, state): self.source_datapipe = state["source_datapipe"]