From 0bfd696bd6b774cc9577e3163bd2c41ca8ffdb09 Mon Sep 17 00:00:00 2001
From: Kevin Tse
Date: Wed, 15 Feb 2023 18:05:40 -0500
Subject: [PATCH] [PrototypeRS] Adding support for naive snapshotting

ghstack-source-id: ad4f5ebbee3ec1e87d3b2b12a206f9d9ec70aa1e
Pull Request resolved: https://github.com/pytorch/data/pull/915
---
 test/dataloader2/test_proto_multi_rs.py  | 41 ++++++++++++++++++++----
 torchdata/dataloader2/dataloader2.py     | 33 ++++++++++++++++++-
 torchdata/dataloader2/reading_service.py | 15 ++++++++
 3 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/test/dataloader2/test_proto_multi_rs.py b/test/dataloader2/test_proto_multi_rs.py
index 52c5c239e..ea773474c 100644
--- a/test/dataloader2/test_proto_multi_rs.py
+++ b/test/dataloader2/test_proto_multi_rs.py
@@ -276,12 +276,41 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
 #             cumulative_res.extend(res)
 #         self.assertEqual(list(range(n_elements)), sorted(cumulative_res))
 
-    # TODO: Implemented in an upcoming PR
-    # def test_reading_service_snapshot(self) -> None:
-    #     pass
-    #
-    # def test_dataloader2_snapshot(self) -> None:
-    #     pass
+    def test_dataloader2_snapshot(self) -> None:
+
+        rs1 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=0)
+        # rs2 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=2)
+        # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
+        # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
+        # rs5 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
+        # rs6 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
+
+        n_samples_before_snapshot = 3
+
+        n_samples_yielded = 0
+        initial_seed_rng = None
+
+        test_rss = [rs1]
+        for rs in test_rss:
+            dl: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
+            res = []
+            for i, x in enumerate(dl):
+                res.append(x)
+                if i == n_samples_before_snapshot - 1:
+                    n_samples_yielded, initial_seed_rng = dl._get_naive_datapipe_snapshot()
+                    break
+            dl.shutdown()
+            self.assertEqual(n_samples_before_snapshot, len(res))
+            self.assertEqual(n_samples_before_snapshot, n_samples_yielded)
+
+            dl_restored: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
+            dl_restored._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed_rng)
+            restored_res = list(dl_restored)
+            self.assertEqual(res, restored_res[0 : n_samples_before_snapshot])  # The first n_samples_before_snapshot elements should match
+            self.assertEqual(list(range(self.n_elements)), sorted(restored_res))
+            dl_restored.shutdown()
+
+    # TODO: Need to figure out the reset situation within `_simple_graph_snapshot_restoration` and ProtoRS
 
 
 instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index 5c358e154..1a1921603 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -14,7 +14,11 @@
 from torchdata.dataloader2.graph._serialization import clone, DataPipe, deserialize_datapipe, serialize_datapipe
 from torchdata.dataloader2.random import SeedGenerator
 from torchdata.dataloader2.random.seed_generator import _UINT64_UPPER_BOUND
-from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
+from torchdata.dataloader2.reading_service import (
+    CheckpointableReadingServiceInterface,
+    PrototypeMultiProcessingReadingService,
+    ReadingServiceInterface,
+)
 
 T_co = TypeVar("T_co", covariant=True)
 SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
@@ -214,6 +218,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
 
         if not self._adapted and self.reading_service is not None:
             if self.reading_service_state is None:
+                # Only called once, while `self._adapted` is still False
                 self.datapipe = self.reading_service.initialize(self.datapipe)
             else:
                 if not isinstance(self.reading_service, CheckpointableReadingServiceInterface):
@@ -381,3 +386,29 @@ def _limit(self, num_batches: Optional[int]) -> None:
             self.reading_service._limit(num_batches)
         else:
             warnings.warn("ReadingService doesn't support `limit`.")
+
+    def _get_naive_datapipe_snapshot(self):
+        """
+        Return a naive snapshot of the DataPipe graph: the number of samples yielded and the initial epoch seed.
+        """
+        if not isinstance(self.reading_service, PrototypeMultiProcessingReadingService):
+            raise RuntimeError(
+                "Only `PrototypeMultiProcessingReadingService` currently supports naive DataPipe snapshotting."
+            )
+        self._pause()
+        n_samples_yielded, _initial_seed = self.reading_service._get_naive_datapipe_snapshot()
+        self._resume()
+        return n_samples_yielded, _initial_seed
+
+    def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed) -> None:
+        if not isinstance(self.reading_service, PrototypeMultiProcessingReadingService):
+            raise RuntimeError(
+                "Only `PrototypeMultiProcessingReadingService` currently supports naive DataPipe snapshotting."
+            )
+        if not self._adapted:
+            self.datapipe = self.reading_service.initialize(self.datapipe)
+            self._adapted = True
+        self.reading_service._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed)
+        # TODO: We may want to skip `initialize_iteration` after restoring a snapshot
+
+    # TODO: Integrate this with the existing checkpointing API? Is anyone using those methods at the moment?
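
[Editor's note, not part of the patch] A minimal sketch of the round trip the two private `DataLoader2` methods above are meant to enable, assuming a single-worker ProtoRS (the only configuration the new test exercises) and an illustrative shuffled pipeline. Given the open TODO in reading_service.py about storing the integer seed, read this as the intended flow rather than a guaranteed-working recipe:

    from torchdata.dataloader2 import DataLoader2
    from torchdata.dataloader2.reading_service import PrototypeMultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(10)).shuffle().sharding_filter()  # illustrative pipeline
    dl = DataLoader2(dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=1))

    it = iter(dl)
    seen = [next(it) for _ in range(3)]  # consume a few samples

    # Pauses the reading service, reads (n_samples_yielded, initial_seed), then resumes
    n_yielded, initial_seed = dl._get_naive_datapipe_snapshot()
    dl.shutdown()

    # A fresh DataLoader2 over the same graph is fast-forwarded to the same
    # position, reusing the original epoch seed
    dl2 = DataLoader2(dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=1))
    dl2._restore_naive_datapipe_snapshot(n_yielded, initial_seed)
    remaining = list(dl2)
    dl2.shutdown()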
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index cc8eaded0..e6f941fbe 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -19,6 +19,7 @@
 import torch.multiprocessing as mp
 
 from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
+from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
 from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
 from torchdata.dataloader2 import communication
@@ -213,6 +214,7 @@ def __init__(
         self._main_prefetch_datapipe = None
         self._end_datapipe = None
         self._mp = num_workers > 0
+        self._initial_seed = None
 
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         r"""
@@ -312,10 +314,14 @@
     def initialize_iteration(
         self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
     ) -> Optional[Callable[[DataPipe], DataPipe]]:
+
+        # TODO: Store the initial integer seed here; `_restore_naive_datapipe_snapshot` passes it to `manual_seed`
+        self._initial_seed = seed_generator
+
         assert self._end_datapipe is not None
         set_graph_random_seed(self._end_datapipe, seed_generator)
 
         if self._mp:
             if self.main_prefetch_cnt > 0:
                 # Stop prefetching first
@@ -412,6 +418,15 @@ def _limit(self, num_batches: Optional[int]) -> None:
         """
         pass
 
+    def _get_naive_datapipe_snapshot(self):
+        return self._end_datapipe._number_of_samples_yielded, self._initial_seed
+
+    def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed):
+        initial_seed_generator = torch.Generator()
+        initial_seed_generator.manual_seed(initial_seed)
+        _simple_graph_snapshot_restoration(self._end_datapipe, n_samples_yielded, initial_seed_generator)
+        # TODO: We may want to skip `initialize_iteration` after restoring a snapshot
+
 
 class DistributedReadingService(ReadingServiceInterface):
     r"""
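
[Editor's note, not part of the patch] What makes this snapshot "naive": the reading service records only how many samples the graph has yielded plus the seed the epoch started from, and `_simple_graph_snapshot_restoration` restores position by replaying. A conceptual single-pipe sketch of that idea (`conceptual_naive_restore` is a made-up helper, not the torch implementation):

    def conceptual_naive_restore(datapipe, n_samples_yielded):
        # Assumes the graph has already been reseeded exactly as in the original
        # epoch (ProtoRS does this via `set_graph_random_seed`), so shufflers
        # reproduce the same order.
        it = iter(datapipe)
        for _ in range(n_samples_yielded):
            next(it)  # fast-forward: re-yield and discard
        return it

Restoration cost is therefore linear in the number of samples already consumed: cheap early in an epoch, expensive deep into one. That trade-off, and the unresolved reset/`initialize_iteration` TODOs above, are why the API is prefixed "naive" and kept private for now.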