Skip to content

Commit

Permalink
Add serialization logic
Browse files: browse the repository at this point in the history
  • Loading branch information
ejguan committed Aug 3, 2022
1 parent 84bc118 commit c81c881
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions torchdata/datapipes/iter/util/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,8 @@ def __init__(self, datapipe: IterDataPipe, timeout=default_timeout_in_s):
self.datapipe = datapipe
self.timeout = timeout

if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError(
"Torch Distributed is required to be initialized"
)
self._process_group = dist.new_group(backend="gloo")
self._world_size = dist.get_world_size()
self._process_group = None
self._world_size = 1

self._lock = threading.RLock()
self._cv = threading.Condition(lock=self._lock)
Expand Down Expand Up @@ -154,6 +150,13 @@ def _callback_fn(self, exp: Expected) -> None:
self._cv.notify()

def __iter__(self) -> Iterator[T_co]:
if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError(
"Torch Distributed is required to be initialized"
)
self._process_group = dist.new_group(backend="gloo")
self._world_size = dist.get_world_size()

assert self._executor is None
self._executor = _PrefetchExecutor(
iter(self.datapipe),
Expand Down Expand Up @@ -187,3 +190,25 @@ def reset(self):
self._error = None
self._sync_counter = torch.tensor([0], dtype=torch.int32)
self._done_callback = False

def __getstate__(self):
    """Return the state used to pickle this datapipe.

    If ``IterDataPipe.getstate_hook`` is set, its result for this
    instance is returned and nothing else happens. Otherwise any live
    executor is shut down first and the ``(datapipe, timeout)`` pair is
    returned; the remaining attributes are rebuilt by ``__setstate__``.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    # Stop the running executor before handing out the state.
    executor = self._executor
    if executor is not None:
        executor.shutdown()

    return (self.datapipe, self.timeout)

def __setstate__(self, state):
    """Restore the datapipe from the ``(datapipe, timeout)`` pair.

    Only the source datapipe and the timeout come from ``state``; every
    other attribute is recreated with a fresh default value.
    """
    source, timeout = state
    self.datapipe = source
    self.timeout = timeout

    # Distributed settings start out unset (no group, world size of 1).
    self._process_group = None
    self._world_size = 1

    # New synchronisation primitives; the condition shares the RLock.
    lock = threading.RLock()
    self._lock = lock
    self._cv = threading.Condition(lock=lock)

    # Fresh runtime bookkeeping.
    self._executor = None
    self._error = None
    self._sync_counter = torch.tensor([0], dtype=torch.int32)
    self._done_callback = False

0 comments on commit c81c881

Please sign in to comment.