Skip to content

Commit

Permalink
Update Prefetcher buffer to use deque (pytorch#842)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#842

This was one of the TODOs in Prefetcher. It makes the implementation cleaner, marginally faster (but not enough to make a significant difference).

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D40517618

Pulled By: NivekT

fbshipit-source-id: a199c06b060a83e087a3cac17ec57c87a7127c40
  • Loading branch information
NivekT authored and ejguan committed Oct 23, 2022
1 parent 0023047 commit b9e0732
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import threading
import time

from typing import Optional
from collections import deque
from typing import Deque, Optional

from torchdata.dataloader2 import communication

Expand All @@ -19,11 +20,11 @@


class _PrefetchData:
def __init__(self, source_datapipe, buffer_size):
def __init__(self, source_datapipe, buffer_size: int):
self.run_prefetcher = True
# TODO: Potential optimization is changing buffer from list to dequeue
self.prefetch_buffer = []
self.buffer_size = buffer_size
self.prefetch_buffer: Deque = deque()
self.buffer_size: int = buffer_size
self.source_datapipe = source_datapipe


Expand Down Expand Up @@ -92,8 +93,7 @@ def __iter__(self):
self.thread.start()
while prefetch_data.run_prefetcher:
if len(prefetch_data.prefetch_buffer) > 0:
yield prefetch_data.prefetch_buffer[0]
prefetch_data.prefetch_buffer = prefetch_data.prefetch_buffer[1:]
yield prefetch_data.prefetch_buffer.popleft()
else:
# TODO: Calculate sleep interval based on previous availability speed
time.sleep(CONSUMER_SLEEP_INTERVAL)
Expand All @@ -113,7 +113,7 @@ def __getstate__(self):
after entire state of the graph is saved).
"""
# TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration
return dict(source_datapipe=self.source_datapipe, buffer_size=self.buffer_size)
return {"source_datapipe": self.source_datapipe, "buffer_size": self.buffer_size}

def __setstate__(self, state):
self.source_datapipe = state["source_datapipe"]
Expand Down

0 comments on commit b9e0732

Please sign in to comment.