Skip to content

Commit

Permalink
[PrototypeRS] Adding support for naive snapshotting
Browse files Browse the repository at this point in the history
ghstack-source-id: 713a9b208cf7596f0f81c305396e4c27ef51e950
Pull Request resolved: #915
  • Loading branch information
NivekT committed Jan 18, 2023
1 parent cc3e90f commit 8c85df6
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 7 deletions.
41 changes: 35 additions & 6 deletions test/dataloader2/test_proto_multi_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,41 @@ def test_reading_service_limit(self) -> None:
f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
)

# TODO: Implemented in an upcoming PR
# def test_reading_service_snapshot(self) -> None:
# pass
#
# def test_dataloader2_snapshot(self) -> None:
# pass
def test_dataloader2_snapshot(self) -> None:
    """Take a naive snapshot mid-iteration, then restore it into a fresh
    DataLoader2 over the same DataPipe and verify the two runs line up.
    """
    # TODO: Extend `test_rss` to prefetching / multi-worker configurations once
    #       naive snapshotting supports them, e.g.
    #       (num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=2),
    #       (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0),
    #       (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0),
    #       (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2),
    #       (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2).
    rs1 = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=0, main_prefetch_cnt=0)

    n_samples_before_snapshot = 3

    n_samples_yielded = 0
    initial_seed_rng = None

    test_rss = [rs1]
    for rs in test_rss:
        dl: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            # Snapshot right after yielding the `n_samples_before_snapshot`-th sample.
            # (Was `if i in {n_samples_before_snapshot - 1}` — a membership test
            # against a one-element set; plain equality is the idiom.)
            if i == n_samples_before_snapshot - 1:
                n_samples_yielded, initial_seed_rng = dl._get_naive_datapipe_snapshot()
                break
        dl.shutdown()
        self.assertEqual(n_samples_before_snapshot, len(res))
        self.assertEqual(n_samples_before_snapshot, n_samples_yielded)

        dl_restored: DataLoader2 = DataLoader2(self.dp1, reading_service=rs)
        dl_restored._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed_rng)
        restored_res = list(dl_restored)
        # NOTE(review): `res` holds `n_samples_before_snapshot` elements but the slice
        # below has one fewer — looks like an off-by-one; confirm the intended
        # restoration semantics before tightening this assertion.
        self.assertEqual(res, restored_res[0 : n_samples_before_snapshot - 1])  # Check if elements are the same
        self.assertEqual(list(range(self.n_elements)), sorted(restored_res))
        dl_restored.shutdown()

    # TODO: Need to figure out the reset situation within `_simple_graph_snapshot_restoration` and ProtoRS


if __name__ == "__main__":
Expand Down
34 changes: 33 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
serialize_datapipe,
wrap_datapipe_for_serialization,
)
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
Expand Down Expand Up @@ -203,6 +207,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
if self._reset_iter:
if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
# Only called once when `self._adapted = False`
self.datapipe = self.reading_service.initialize(self.datapipe)
else:
if not isinstance(self.reading_service, CheckpointableReadingServiceInterface):
Expand Down Expand Up @@ -328,3 +333,30 @@ def resume(self):
else:
warnings.warn("ReadingService doesn't support resume.")
self._is_paused = False

def _get_naive_datapipe_snapshot(self):
    """
    Return a naive snapshot of the DataPipe as a
    ``(n_samples_yielded, initial_seed)`` pair.

    The reading service is paused while the snapshot is read so the sample
    count cannot advance mid-read, then resumed before returning.

    Raises:
        RuntimeError: if the attached reading service is not a
            ``PrototypeMultiProcessingReadingService``.
    """
    supported = isinstance(self.reading_service, PrototypeMultiProcessingReadingService)
    if not supported:
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` currently supports naive DataPipe snapshotting."
        )
    self.pause()
    snapshot = self.reading_service._get_naive_datapipe_snapshot()
    self.resume()
    return snapshot

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed) -> None:
    """
    Restore a naive DataPipe snapshot previously taken with
    ``_get_naive_datapipe_snapshot``.

    Args:
        n_samples_yielded: number of samples the snapshotted run had yielded.
        initial_seed: seed captured from the snapshotted run, used to re-seed
            the graph so the restored run replays the same sample order.

    Raises:
        RuntimeError: if the attached reading service is not a
            ``PrototypeMultiProcessingReadingService``.
    """
    if not isinstance(self.reading_service, PrototypeMultiProcessingReadingService):
        raise RuntimeError(
            "Only `PrototypeMultiProcessingReadingService` " "currently supports naive DataPipe snapshotting."
        )
    if not self._adapted:
        # Mirror the lazy setup in `__iter__`: the reading service must have
        # adapted the graph before restoration can be applied to it.
        self.datapipe = self.reading_service.initialize(self.datapipe)
        self._adapted = True
    self.reading_service._restore_naive_datapipe_snapshot(n_samples_yielded, initial_seed)
    # TODO: I might want to skip `initialize_iteration` after this????

# TODO: Integrate this with the existing API? Is anyone using these at the moment?
# NOTE: removed a leftover merge-conflict marker (`>>>>>>> 99bf9cf5 ...`) that
# followed this method — it would have been a syntax error at import time.
12 changes: 12 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration

from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
from torchdata.dataloader2 import communication
Expand Down Expand Up @@ -189,6 +190,7 @@ def __init__(
self._pg = None
self._world_size = 1
self._rank = 0
self._initial_seed = None

def initialize(self, datapipe: DataPipe) -> DataPipe:
r"""
Expand Down Expand Up @@ -277,6 +279,7 @@ def initialize_iteration(self) -> None:
shared_seed_int: int = shared_seed.item() # type: ignore[assignment]
_seed_generator = torch.Generator()
_seed_generator.manual_seed(shared_seed_int)
self._initial_seed = shared_seed_int
torch.utils.data.graph_settings.apply_random_seed(
self.end_datapipe, # type: ignore[arg-type]
_seed_generator,
Expand Down Expand Up @@ -394,6 +397,15 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self.end_datapipe.resume() # type: ignore[union-attr]

def _get_naive_datapipe_snapshot(self):
    """Return ``(n_samples_yielded, initial_seed)`` for the current pipeline:
    the sample count tracked on ``end_datapipe`` and the seed recorded during
    ``initialize_iteration``."""
    n_yielded = self.end_datapipe._number_of_samples_yielded
    return (n_yielded, self._initial_seed)

def _restore_naive_datapipe_snapshot(self, n_samples_yielded, initial_seed):
    """Restore ``end_datapipe`` to a prior state via
    ``_simple_graph_snapshot_restoration``, seeding a fresh generator with
    ``initial_seed`` and fast-forwarding by ``n_samples_yielded`` samples."""
    # `manual_seed` returns the generator, so seeding can be chained.
    rng = torch.Generator().manual_seed(initial_seed)
    _simple_graph_snapshot_restoration(self.end_datapipe, n_samples_yielded, rng)
    # TODO: I might want to skip `initialize_iteration` after this????


class MultiProcessingReadingService(ReadingServiceInterface):
r"""
Expand Down

0 comments on commit 8c85df6

Please sign in to comment.