Skip to content

Commit

Permalink
[DataLoader2] Saving and restoring initial seed generator
Browse files Browse the repository at this point in the history
ghstack-source-id: 9aa77a0f2de7b83ba406552228880f720e20a3b2
Pull Request resolved: #998
  • Loading branch information
NivekT committed Mar 17, 2023
1 parent 7d97475 commit 70a7cc8
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
59 changes: 59 additions & 0 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,65 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
res.append(x)
self.assertEqual(9, len(res))

def test_initial_epoch_checkpointing(self):
dp = IterableWrapper(range(20)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)

# Functional Test: Saving state before iterator is created
dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
initial_state = dl.state_dict()
it1 = iter(dl)

restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
self.assertEqual(list(it1), list(restored_dl))

dl.shutdown()
restored_dl.shutdown()

# Functional Test: Saving state after iterator is created and began iterating
dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)
it1 = iter(dl)
next(it1) # Starts iterating
initial_state = dl.state_dict()

restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl._initial_seed_generator = None

print(dl._seed_generator)
print(restored_dl._seed_generator)

print(dl._seed_generator == restored_dl._seed_generator)

restored_dl2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type]
restored_dl2._restore_checkpoint_beginning_of_epoch()

res1 = list(it1)
res2 = list(restored_dl)
res3 = list(restored_dl2)

print(res1)
print(res2)
print(res3)
# self.assertEqual(list(it1), list(restored_dl)[1:]) # Note skipping over 1st element from actual result

dl.shutdown()
restored_dl.shutdown()
restored_dl2.shutdown()

# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
# Currently, using sharding_round_robin raises a warning
# def test_round_robin_dispatching_pause_limit(self):
Expand Down
30 changes: 28 additions & 2 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pickle
import warnings

from dataclasses import dataclass
Expand All @@ -20,6 +19,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
RANDOMNESS_STATE_KEY_NAME = "randomness_state"


@dataclass
Expand Down Expand Up @@ -185,6 +185,8 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
Expand All @@ -207,6 +209,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand Down Expand Up @@ -278,10 +283,17 @@ def state_dict(self) -> Dict[str, Any]:

# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_randomness_state = (
self._seed,
self._reset_seed,
pickle.dumps(self._seed_generator),
pickle.dumps(self._initial_seed_generator),
)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
}

@classmethod
Expand All @@ -303,6 +315,12 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Expand All @@ -329,12 +347,20 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = clone(self.datapipe)

def _restore_checkpoint_beginning_of_epoch(self) -> None:
self._seed_generator = self._initial_seed_generator

def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
Expand Down

0 comments on commit 70a7cc8

Please sign in to comment.