Skip to content

Commit

Permalink
[PrototypeRS] Adding support for naive snapshotting
Browse files Browse the repository at this point in the history
ghstack-source-id: ad4f5ebbee3ec1e87d3b2b12a206f9d9ec70aa1e
Pull Request resolved: #915
  • Loading branch information
NivekT committed Feb 15, 2023
1 parent 7e8fba3 commit 0bfd696
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 7 deletions.
41 changes: 35 additions & 6 deletions test/dataloader2/test_proto_multi_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,41 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
# cumulative_res.extend(res)
# self.assertEqual(list(range(n_elements)), sorted(cumulative_res))

# TODO: Implemented in an upcoming PR
# def test_reading_service_snapshot(self) -> None:
# pass
#
# def test_dataloader2_snapshot(self) -> None:
# pass
def test_dataloader2_snapshot(self) -> None:
    """Round-trip a naive DataPipe snapshot through ``DataLoader2``.

    Yields a few samples, captures ``(n_samples_yielded, initial_seed)`` via the
    naive snapshot API, then restores a fresh ``DataLoader2`` from that snapshot
    and checks that it reproduces the same leading elements and, overall, yields
    the full element set (i.e. the restored loader replays the seeded stream).
    """
    rs1 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # TODO: enable the prefetching / multi-worker configurations below once supported
    # rs2 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
    # rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
    # rs5 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    # rs6 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)

    n_samples_before_snapshot = 3

    n_samples_yielded = 0
    initial_seed_rng = None

    test_rss = [rs1]
    for rs in test_rss:
        dl: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            # Snapshot immediately after the `n_samples_before_snapshot`-th sample
            if i == n_samples_before_snapshot - 1:
                n_samples_yielded, initial_seed_rng = dl._get_naive_datapipe_snapshot()
                break
        dl.shutdown()
        self.assertEqual(n_samples_before_snapshot, len(res))
        self.assertEqual(n_samples_before_snapshot, n_samples_yielded)

        dl_restored: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        dl_restored._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed_rng)
        restored_res = list(dl_restored)
        # `res` holds `n_samples_before_snapshot` elements, so the prefix slice must be
        # the same length (the previous `- 1` slice could never compare equal).
        self.assertEqual(res, restored_res[0:n_samples_before_snapshot])  # Check if elements are the same
        self.assertEqual(list(range(self.n_elements)), sorted(restored_res))
        dl_restored.shutdown()

# TODO: Need to figure out the reset situation within `_simple_graph_snapshot_restoration` and ProtoRS


instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
Expand Down
33 changes: 32 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from torchdata.dataloader2.graph._serialization import clone, DataPipe, deserialize_datapipe, serialize_datapipe
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random.seed_generator import _UINT64_UPPER_BOUND
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
Expand Down Expand Up @@ -214,6 +218,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
# Only called once when `self._adapted = False`
self.datapipe = self.reading_service.initialize(self.datapipe)
else:
if not isinstance(self.reading_service, CheckpointableReadingServiceInterface):
Expand Down Expand Up @@ -381,3 +386,29 @@ def _limit(self, num_batches: Optional[int]) -> None:
self.reading_service._limit(num_batches)
else:
warnings.warn("ReadingService doesn't support `limit`.")

def _get_naive_datapipe_snapshot(self):
    """Capture a naive snapshot of the DataPipe graph.

    Pauses the reading service, reads back the number of samples yielded so far
    together with the initial seed recorded at the start of iteration, then
    resumes. Returns the ``(n_samples_yielded, initial_seed)`` pair.

    Raises:
        RuntimeError: if the attached reading service is not a
            ``PrototypeMultiProcessingReadingService`` (the only one that
            currently implements naive snapshotting).
    """
    if not isinstance(self.reading_service, PrototypeMultiProcessingReadingService):
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` currently supports naive DataPipe snapshotting."
        )
    self._pause()
    snapshot = self.reading_service._get_naive_datapipe_snapshot()
    self._resume()
    return snapshot

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed) -> None:
    """Restore the DataPipe graph from a naive snapshot.

    Args:
        n_samples_yielded: sample count previously returned by
            ``_get_naive_datapipe_snapshot``.
        initial_seed: the seed captured alongside it.

    Raises:
        RuntimeError: if the attached reading service is not a
            ``PrototypeMultiProcessingReadingService``.
    """
    rs = self.reading_service
    if not isinstance(rs, PrototypeMultiProcessingReadingService):
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` currently supports naive DataPipe snapshotting."
        )
    if not self._adapted:
        # Mirror the one-time graph initialization that `__iter__` would do.
        self.datapipe = rs.initialize(self.datapipe)
        self._adapted = True
    rs._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed)
    # TODO: I might want to skip `initialize_iteration` after this????

# TODO: Integrate this with the existing API? Is anyone using these at the moment?
19 changes: 19 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch.multiprocessing as mp

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration

from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
from torchdata.dataloader2 import communication
Expand Down Expand Up @@ -213,6 +214,7 @@ def __init__(
self._main_prefetch_datapipe = None
self._end_datapipe = None
self._mp = num_workers > 0
self._initial_seed = None

def initialize(self, datapipe: DataPipe) -> DataPipe:
r"""
Expand Down Expand Up @@ -312,10 +314,18 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
def initialize_iteration(
self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
) -> Optional[Callable[[DataPipe], DataPipe]]:

# TODO: Store the initial state of generator here
self._initial_seed = seed_generator

assert self._end_datapipe is not None

set_graph_random_seed(self._end_datapipe, seed_generator)

assert self.end_datapipe is not None

set_graph_random_seed(self.end_datapipe, seed_generator)

if self._mp:
if self.main_prefetch_cnt > 0:
# Stop prefetching first
Expand Down Expand Up @@ -412,6 +422,15 @@ def _limit(self, num_batches: Optional[int]) -> None:
"""
pass

def _get_naive_datapipe_snapshot(self):
return self.end_datapipe._number_of_samples_yielded, self._initial_seed

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed):
initial_seed_generator = torch.Generator()
initial_seed_generator.manual_seed(initial_seed)
_simple_graph_snapshot_restoration(self.end_datapipe, n_samples_yielded, initial_seed_generator)
# TODO: I might want to skip `initialize_iteration` after this????


class DistributedReadingService(ReadingServiceInterface):
r"""
Expand Down

0 comments on commit 0bfd696

Please sign in to comment.