From 851e26a0297e22c73f133b40e462cc42f868c885 Mon Sep 17 00:00:00 2001
From: Kevin Tse
Date: Mon, 10 Apr 2023 11:42:42 -0700
Subject: [PATCH] Saving and restoring initial seed generator (#1124)

Summary:
Pull Request resolved: https://github.com/pytorch/data/pull/1124

Reland of #998 with an added guard while loading randomness state in `DataLoader2` for backward compatibility.

Changes to `DataLoader2`:
- Modifying `state_dict` to store `randomness_state`, which includes:
  - `_seed: int`
  - `_reset_seed: bool` - flag indicating whether `_seed` needs to be set
  - `_seed_generator` - the latest version at the time when `state_dict` is called
  - `_initial_seed_generator` - the version that is saved at the beginning of every epoch
- Modifying `from_state` and `load_state_dict` to restore `randomness_state`
- Adding a method `_restore_checkpoint_beginning_of_epoch`
  - This sets `self._seed_generator = self._initial_seed_generator`, allowing users to re-create an epoch from the beginning.

---

### Considerations

Storing the randomness states provides more flexibility for users to restore them as they see fit. The decision to do that should not be controversial.

I decided to add a new method for checkpointing at the beginning of the epoch, to ensure that users are not confused about what randomness is restored by default. The basic idea is that we want to allow users to restore `dl2._seed_generator` to the previously saved version. From that point on, they can create a new `__iter__` and continue from the beginning of the epoch. (A usage sketch of this flow is included after the patch below.)
- Note that since `_seed` and `_reset_seed` are also saved, if the users were planning to use a different seed or if there was a need to re-seed, those settings remain valid after restoring the checkpoint.
- Finally, if users change their mind at any point (after restoring) and want to manually set `seed`, that `seed` will override any other behavior and will be used.

Test Plan: Imported from OSS

f425956975

Reviewed By: bearzx

Differential Revision: D44748514

Pulled By: NivekT

fbshipit-source-id: 8713592902b1e0680e46e4db4280c84c708dbf55
---
 test/dataloader2/test_mprs.py            | 66 +++++++++++++++++--
 torchdata/dataloader2/dataloader2.py     | 41 +++++++++++-
 .../dataloader2/random/seed_generator.py | 10 +++
 torchdata/dataloader2/reading_service.py |  2 +
 4 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/test/dataloader2/test_mprs.py b/test/dataloader2/test_mprs.py
index b4fb1d7f8..05a7ba27b 100644
--- a/test/dataloader2/test_mprs.py
+++ b/test/dataloader2/test_mprs.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import multiprocessing as mp
 import unittest
 from unittest import TestCase
 
@@ -14,7 +13,7 @@
 from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
 
 from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
-from torchdata.datapipes.iter import IterableWrapper
+from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
 
 
 def _add_one(x: int) -> int:
@@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
     return dp
 
 
+class NonShardableDataPipe(IterDataPipe):
+    def __init__(self, dp: IterDataPipe):
+        self.dp = dp
+
+    def is_replicable(self):
+        return False
+
+    def __iter__(self):
+        yield from self.dp
+
+
 class TestMultiProcessingReadingService(TestCase):
     r"""
     This tests specific functionalities of MultiProcessingReadingService, notably
@@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
             worker_prefetch_cnt=worker_prefetch,
             multiprocessing_context=ctx,
         )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
         it = iter(dl)
         for _ in range(10):
             _ = next(it)
@@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
             worker_prefetch_cnt=worker_prefetch,
             multiprocessing_context=ctx,
         )
-        dl = DataLoader2(dp, reading_service=rs)
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
         _ = list(dl)
         dl.shutdown()
 
@@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
                 res.append(x)
         self.assertEqual(9, len(res))
 
+    def test_initial_epoch_checkpointing(self):
+        dp = IterableWrapper(range(20)).shuffle().sharding_filter()
+        # Note that the second `shuffle` occurs in the main process, which uses a different RNG from
+        # the `shuffle` done in the worker processes
+        dp = NonShardableDataPipe(dp).shuffle()  # type: ignore[assignment, arg-type]
+        rs = MultiProcessingReadingService(num_workers=2)
+
+        # Functional Test: Saving state before iterator is created
+        dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        initial_state = dl.state_dict()
+        it1 = iter(dl)
+
+        restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+        self.assertEqual(list(it1), list(restored_dl))
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
+        # Functional Test: Saving state after iterator is created
+        dl = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        it1 = iter(dl)
+        initial_state = dl.state_dict()
+
+        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+        self.assertEqual(list(it1), list(restored_dl))
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
+        # Functional Test: Saving state after iterator is created and began iterating
+        dl = DataLoader2(datapipe=dp, reading_service=rs)
+        dl.seed(1)
+        it1 = iter(dl)
+        temp = next(it1)  # Starts iterating
+        initial_state = dl.state_dict()
+
+        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
+        restored_dl._restore_checkpoint_beginning_of_epoch()
+
+        self.assertEqual([temp] + list(it1), list(restored_dl))  # Note skipping over 1st element from actual result
+
+        dl.shutdown()
+        restored_dl.shutdown()
+
     # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
     # Currently, using sharding_round_robin raises a warning
     # def test_round_robin_dispatching_pause_limit(self):
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index 893e51180..dfdc98e15 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
+import pickle
 import warnings
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
 
@@ -19,6 +19,7 @@
 T_co = TypeVar("T_co", covariant=True)
 SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
 READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
+RANDOMNESS_STATE_KEY_NAME = "randomness_state"
 
 
 class DataLoader2Iterator(Iterator[T_co]):
@@ -176,6 +177,8 @@ def __init__(
         self._seed_generator: SeedGenerator = SeedGenerator()
         self._seed: Optional[int] = None
         self._reset_seed: bool = True
+        # Seed generator as of beginning of each epoch
+        self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)
 
     def __iter__(self) -> DataLoader2Iterator[T_co]:
         r"""
@@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
             else:
                 self._seed_generator.seed()
 
+            # Saving initial seed generator state
+            self._initial_seed_generator = clone(self._seed_generator)
+
             if not self._adapted and self.reading_service is not None:
                 if self.reading_service_state is None:
                     self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Serialize datapipe after applying adapters and before reading service adaption
         serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
+        serialized_randomness_state = (
+            self._seed,
+            self._reset_seed,
+            pickle.dumps(self._seed_generator),
+            pickle.dumps(self._initial_seed_generator),
+        )
 
         return {
             SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
             READING_SERVICE_STATE_KEY_NAME: reading_service_state,
+            RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
         }
 
     @classmethod
@@ -294,6 +307,14 @@ def from_state(
             reading_service=reading_service,
         )
         data_loader.reading_service_state = reading_service_state
+
+        # This check is needed for backward compatibility of `state_dict` for users loading from older version
+        if RANDOMNESS_STATE_KEY_NAME in state:
+            randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
+            data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
+            data_loader._seed_generator = pickle.loads(randomness_state[2])
+            data_loader._initial_seed_generator = pickle.loads(randomness_state[3])
+
         return data_loader
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -320,12 +341,30 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.datapipe = deserialized_datapipe
         self.reading_service_state = reading_service_state
 
+        # This check is needed for backward compatibility of `state_dict` for users loading from older version
+        if RANDOMNESS_STATE_KEY_NAME in state_dict:
+            randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
+            self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
+            self._seed_generator = pickle.loads(randomness_state[2])
+            self._initial_seed_generator = pickle.loads(randomness_state[3])
+
         # re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
         if self.datapipe_adapter_fns is not None:
            for adapter_fn in self.datapipe_adapter_fns:
                 self.datapipe = adapter_fn(self.datapipe)
         self._datapipe_before_reading_service_adapt = clone(self.datapipe)
+
+    def _restore_checkpoint_beginning_of_epoch(self) -> None:
+        r"""
+        At the beginning of each iteration (epoch), the initial state of randomness is automatically saved.
+        That state is also saved as part of ``state_dict``. This method restores the current DataLoader2 RNG state
+        to that initial state.
+
+        The common use case is to invoke this method after ``DataLoader2``'s state is restored (through
+        ``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-run epoch.
+        """
+        self._seed_generator = self._initial_seed_generator
 
     def _pause(self):
         if hasattr(self.reading_service, "_pause"):
             self._is_paused = True
diff --git a/torchdata/dataloader2/random/seed_generator.py b/torchdata/dataloader2/random/seed_generator.py
index 2f67ee2be..fa4fdfabc 100644
--- a/torchdata/dataloader2/random/seed_generator.py
+++ b/torchdata/dataloader2/random/seed_generator.py
@@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
             self._worker_rng = self._worker_rng.spawn(worker_id)
             return self
         return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))
+
+    def __getstate__(self):
+        state = (
+            self._shared_rng,
+            self._worker_rng,
+        )
+        return state
+
+    def __setstate__(self, state):
+        self._shared_rng, self._worker_rng = state
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 4109c05fe..d8a43e349 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -312,6 +312,8 @@ def initialize_iteration(
     ) -> Optional[Callable[[DataPipe], DataPipe]]:
         assert self._end_datapipe is not None
 
+        # Set random seeds for DataPipe that are in the main process (NOT those in worker processes)
+        # Worker seeds are set in `process_reset_fn`
         set_graph_random_seed(self._end_datapipe, seed_generator)
 
         if self._mp:
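
---

Usage sketch (not part of the patch): a minimal example of the save-and-restore flow described in the Considerations section above. It mirrors the new `test_initial_epoch_checkpointing` test; the pipeline, seed value, and worker count are illustrative only.

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

# Illustrative shuffled pipeline and reading service; any DataPipe graph works the same way.
dp = IterableWrapper(range(20)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)

dl = DataLoader2(datapipe=dp, reading_service=rs)
dl.seed(1)  # optional: pin the seed so the epoch order is reproducible
state = dl.state_dict()  # the snapshot now includes `randomness_state`

# ... iterate, get interrupted, or decide later to replay the epoch ...

# Rebuild DataLoader2 from the snapshot and rewind its seed generator to the value
# captured at the beginning of the epoch, so iteration reproduces the same order.
restored_dl = DataLoader2.from_state(state, rs)
restored_dl._restore_checkpoint_beginning_of_epoch()
replayed_epoch = list(restored_dl)

dl.shutdown()
restored_dl.shutdown()
```

As noted above, calling `seed(...)` on the restored loader afterwards overrides the restored state: an explicitly set seed always takes precedence.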