diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index 15b582187..dfdc98e15 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -308,10 +308,12 @@ def from_state( ) data_loader.reading_service_state = reading_service_state - randomness_state = state[RANDOMNESS_STATE_KEY_NAME] - data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1] - data_loader._seed_generator = pickle.loads(randomness_state[2]) - data_loader._initial_seed_generator = pickle.loads(randomness_state[3]) + # This check is needed for backward compatibility of `state_dict` for users loading from older version + if RANDOMNESS_STATE_KEY_NAME in state: + randomness_state = state[RANDOMNESS_STATE_KEY_NAME] + data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1] + data_loader._seed_generator = pickle.loads(randomness_state[2]) + data_loader._initial_seed_generator = pickle.loads(randomness_state[3]) return data_loader @@ -339,10 +341,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.datapipe = deserialized_datapipe self.reading_service_state = reading_service_state - randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME] - self._seed, self._reset_seed = randomness_state[0], randomness_state[1] - self._seed_generator = pickle.loads(randomness_state[2]) - self._initial_seed_generator = pickle.loads(randomness_state[3]) + # This check is needed for backward compatibility of `state_dict` for users loading from older version + if RANDOMNESS_STATE_KEY_NAME in state_dict: + randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME] + self._seed, self._reset_seed = randomness_state[0], randomness_state[1] + self._seed_generator = pickle.loads(randomness_state[2]) + self._initial_seed_generator = pickle.loads(randomness_state[3]) # re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt if self.datapipe_adapter_fns is not None: