Skip to content

Commit

Permalink
[DataLoader2] Adding guard to randomness state for backward compatibi…
Browse files Browse the repository at this point in the history
…lity

ghstack-source-id: 6ff617871e11362ca378216fa9e912921349c651
Pull Request resolved: #1122
  • Loading branch information
NivekT committed Apr 6, 2023
1 parent 837ede1 commit f3290c3
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,12 @@ def from_state(
)
data_loader.reading_service_state = reading_service_state

randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])
# This check is needed for backward compatibility of `state_dict` for users loading from older version
if RANDOMNESS_STATE_KEY_NAME in state:
randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

Expand Down Expand Up @@ -339,10 +341,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])
# This check is needed for backward compatibility of `state_dict` for users loading from older version
if RANDOMNESS_STATE_KEY_NAME in state_dict:
randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
Expand Down

0 comments on commit f3290c3

Please sign in to comment.