Skip to content

Commit

Permalink
[PrototypeRS] Adding support for naive snapshotting
Browse files Browse the repository at this point in the history
ghstack-source-id: b4fc8fdbdcd40f52280603a29df82d20b8061de5
Pull Request resolved: #915
  • Loading branch information
NivekT committed Mar 17, 2023
1 parent 562e539 commit ea406b3
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 12 deletions.
41 changes: 35 additions & 6 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,41 @@ def test_initial_epoch_checkpointing(self):
# cumulative_res.extend(res)
# self.assertEqual(list(range(n_elements)), sorted(cumulative_res))

# TODO: Implemented in an upcoming PR
# def test_reading_service_snapshot(self) -> None:
# pass
#
# def test_dataloader2_snapshot(self) -> None:
# pass
def test_dataloader2_snapshot(self) -> None:
    """
    Take a naive snapshot after yielding a few samples, then restore it into a fresh
    DataLoader2 and check that (a) the snapshot recorded the right sample count and
    (b) the restored loader reproduces the same leading samples.
    """
    rs1 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # TODO: enable these variants once multi-worker/prefetch snapshotting is supported
    # rs2 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
    # rs5 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs6 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)

    n_samples_before_snapshot = 3

    n_samples_yielded = 0
    initial_seed_rng = None

    test_rss = [rs1]
    for rs in test_rss:
        dl: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            # Snapshot right after the n-th sample has been yielded.
            if i == n_samples_before_snapshot - 1:
                n_samples_yielded, initial_seed_rng = dl._get_naive_datapipe_snapshot()
                break
        dl.shutdown()
        self.assertEqual(n_samples_before_snapshot, len(res))
        self.assertEqual(n_samples_before_snapshot, n_samples_yielded)

        dl_restored: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        dl_restored._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed_rng)
        restored_res = list(dl_restored)
        # `res` holds exactly `n_samples_before_snapshot` elements, so compare against a
        # prefix of the same length (the previous `- 1` slice was one element short and
        # could never equal `res`).
        self.assertEqual(res, restored_res[:n_samples_before_snapshot])  # Check if elements are the same
        self.assertEqual(list(range(self.n_elements)), sorted(restored_res))
        dl_restored.shutdown()

    # TODO: Need to figure out the reset situation within `_simple_graph_snapshot_restoration` and ProtoRS


instantiate_parametrized_tests(TestMultiProcessingReadingService)
Expand Down
44 changes: 40 additions & 4 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from torchdata.dataloader2.graph._serialization import clone, DataPipe, deserialize_datapipe, serialize_datapipe
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random.seed_generator import _UINT64_UPPER_BOUND
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
MultiProcessingReadingService,
ReadingServiceInterface,
)

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
Expand Down Expand Up @@ -214,6 +218,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
# Only called once when `self._adapted = False`
self.datapipe = self.reading_service.initialize(self.datapipe)
else:
if not isinstance(self.reading_service, CheckpointableReadingServiceInterface):
Expand Down Expand Up @@ -379,6 +384,37 @@ def _resume(self):
else:
self.reading_service._resume()
self._is_paused = False
# TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
warnings.warn("ReadingService doesn't support resume.")
else:
warnings.warn("ReadingService doesn't support `resume`.")

def _limit(self, num_batches: Optional[int]) -> None:
if hasattr(self.reading_service, "_limit"):
self.reading_service._limit(num_batches)
else:
warnings.warn("ReadingService doesn't support `limit`.")

def _get_naive_datapipe_snapshot(self):
    """
    Pause the ReadingService, capture a naive snapshot of the DataPipe graph, resume,
    and return ``(n_samples_yielded, initial_seed)``.

    Raises:
        RuntimeError: if the attached ReadingService is not a
            ``MultiProcessingReadingService``.
    """
    if not isinstance(self.reading_service, MultiProcessingReadingService):
        raise RuntimeError(
            "Only `MultiProcessingReadingService` " "currently supports naive DataPipe snapshotting."
        )
    # Pause so the sample count is consistent while no worker is producing data.
    self._pause()
    # NOTE(review): `MultiProcessingReadingService._get_naive_datapipe_snapshot` appears to
    # return a single int (the `_number_of_samples_yielded` counter), so unpacking two
    # values here would raise a TypeError — confirm which side is meant to supply the
    # initial seed.
    n_samples_yielded, _initial_seed = self.reading_service._get_naive_datapipe_snapshot()
    self._resume()
    return n_samples_yielded, _initial_seed

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed) -> None:
    """
    Naively restore a DataPipe snapshot: initialize the graph (if not already adapted)
    and ask the ReadingService to fast-forward it by ``n_samples_yielded`` samples,
    reseeding from ``initial_seed``.

    Raises:
        RuntimeError: if the attached ReadingService is not a
            ``MultiProcessingReadingService``.
    """
    if not isinstance(self.reading_service, MultiProcessingReadingService):
        raise RuntimeError(
            "Only `MultiProcessingReadingService` " "currently supports naive DataPipe snapshotting."
        )
    if not self._adapted:
        # `initialize` must run before restoration so the ReadingService has an
        # `_end_datapipe`; mark `_adapted` so `__iter__` does not initialize again.
        self.datapipe = self.reading_service.initialize(self.datapipe)
        self._adapted = True
    # Pass by keyword: the ReadingService signature is
    # `(initial_seed_generator, n_samples_yielded)` — the reverse of this method's
    # parameter order — so the previous positional call swapped the two arguments.
    self.reading_service._restore_naive_datapipe_snapshot(
        initial_seed_generator=initial_seed, n_samples_yielded=n_samples_yielded
    )
    # TODO: I might want to skip `initialize_iteration` after this????

# TODO: Integrate this with the existing API? Is anyone using these at the moment?
18 changes: 16 additions & 2 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch.multiprocessing as mp

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration

from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
from torchdata.dataloader2 import communication
Expand Down Expand Up @@ -227,8 +228,6 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
self._end_datapipe = datapipe
return datapipe

graph = traverse_dps(datapipe)

ctx = mp.get_context(self.multiprocessing_context)

# Launch dispatching process for the lowest common ancestor of non-replicable DataPipes
Expand Down Expand Up @@ -406,6 +405,21 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self._main_prefetch_datapipe.resume() # type: ignore[union-attr]

def _limit(self, num_batches: Optional[int]) -> None:
"""
For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
the limit operation, such that nothing needs to be done here.
"""
pass

def _get_naive_datapipe_snapshot(self):
return 0 if self._end_datapipe is None else self._end_datapipe._number_of_samples_yielded

def _restore_naive_datapipe_snapshot(self, initial_seed_generator: SeedGenerator, n_samples_yielded):
assert self._end_datapipe is not None # `self.initialize()` needs to be called prior
_simple_graph_snapshot_restoration(self._end_datapipe, n_samples_yielded, initial_seed_generator)
# TODO: I might want to skip `initialize_iteration` after this????


class DistributedReadingService(ReadingServiceInterface):
r"""
Expand Down

0 comments on commit ea406b3

Please sign in to comment.