Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Saving and restoring initial seed generator #1123

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import multiprocessing as mp
import unittest
from unittest import TestCase
Expand All @@ -14,7 +13,7 @@
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


def _add_one(x: int) -> int:
Expand Down Expand Up @@ -46,6 +45,17 @@ def _dispatching_dp(n_elements=1000):
return dp


class NonShardableDataPipe(IterDataPipe):
    """Wrapper pipe that declares its source non-replicable.

    Returning ``False`` from ``is_replicable`` tells the reading service that
    this pipe must stay in a single (main-process) branch of the graph rather
    than being cloned into every worker process.
    """

    def __init__(self, dp: IterDataPipe):
        # Keep a handle to the wrapped source; iteration is fully delegated.
        self.dp = dp

    def __iter__(self):
        # Transparent pass-through over the wrapped pipe.
        yield from self.dp

    def is_replicable(self):
        # Mark this node as non-replicable across worker processes.
        return False


class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of MultiProcessingReadingService, notably
Expand All @@ -64,7 +74,7 @@ def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(10):
_ = next(it)
Expand All @@ -82,7 +92,7 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
_ = list(dl)
dl.shutdown()

Expand Down Expand Up @@ -248,6 +258,54 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
res.append(x)
self.assertEqual(9, len(res))

def test_initial_epoch_checkpointing(self):
    dp = IterableWrapper(range(20)).shuffle().sharding_filter()
    # The outer `shuffle` below runs in the main process and therefore draws
    # from a different RNG than the worker-side `shuffle` above.
    dp = NonShardableDataPipe(dp).shuffle()  # type: ignore[assignment, arg-type]
    rs = MultiProcessingReadingService(num_workers=2)

    # Case 1: checkpoint taken before any iterator exists
    dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    checkpoint = dl.state_dict()
    baseline_it = iter(dl)

    resumed_dl: DataLoader2 = DataLoader2.from_state(checkpoint, rs)  # type: ignore[arg-type]
    resumed_dl._restore_checkpoint_beginning_of_epoch()
    self.assertEqual(list(baseline_it), list(resumed_dl))

    dl.shutdown()
    resumed_dl.shutdown()

    # Case 2: checkpoint taken right after the iterator is created
    dl = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    baseline_it = iter(dl)
    checkpoint = dl.state_dict()

    resumed_dl = DataLoader2.from_state(checkpoint, rs)  # type: ignore[arg-type]
    resumed_dl._restore_checkpoint_beginning_of_epoch()
    self.assertEqual(list(baseline_it), list(resumed_dl))

    dl.shutdown()
    resumed_dl.shutdown()

    # Case 3: checkpoint taken after iteration has already begun
    dl = DataLoader2(datapipe=dp, reading_service=rs)
    dl.seed(1)
    baseline_it = iter(dl)
    first_elem = next(baseline_it)  # advance one step before checkpointing
    checkpoint = dl.state_dict()

    resumed_dl = DataLoader2.from_state(checkpoint, rs)  # type: ignore[arg-type]
    resumed_dl._restore_checkpoint_beginning_of_epoch()

    # The resumed loader replays the whole epoch, so prepend the element the
    # baseline iterator already consumed before comparing.
    self.assertEqual([first_elem] + list(baseline_it), list(resumed_dl))

    dl.shutdown()
    resumed_dl.shutdown()

# TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
# Currently, using sharding_round_robin raises a warning
# def test_round_robin_dispatching_pause_limit(self):
Expand Down
37 changes: 36 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pickle
import warnings

from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
Expand All @@ -19,6 +19,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
RANDOMNESS_STATE_KEY_NAME = "randomness_state"


class DataLoader2Iterator(Iterator[T_co]):
Expand Down Expand Up @@ -176,6 +177,8 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
Expand All @@ -198,6 +201,9 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand Down Expand Up @@ -269,10 +275,17 @@ def state_dict(self) -> Dict[str, Any]:

# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_randomness_state = (
self._seed,
self._reset_seed,
pickle.dumps(self._seed_generator),
pickle.dumps(self._initial_seed_generator),
)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state,
}

@classmethod
Expand All @@ -294,6 +307,12 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

randomness_state = state[RANDOMNESS_STATE_KEY_NAME]
data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1]
data_loader._seed_generator = pickle.loads(randomness_state[2])
data_loader._initial_seed_generator = pickle.loads(randomness_state[3])

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Expand All @@ -320,12 +339,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME]
self._seed, self._reset_seed = randomness_state[0], randomness_state[1]
self._seed_generator = pickle.loads(randomness_state[2])
self._initial_seed_generator = pickle.loads(randomness_state[3])

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt = clone(self.datapipe)

def _restore_checkpoint_beginning_of_epoch(self) -> None:
    r"""
    Rewind this ``DataLoader2``'s RNG state to the snapshot captured at the
    beginning of the last epoch.

    ``__iter__`` saves the seed generator at the start of every epoch, and that
    snapshot is carried inside ``state_dict``. Invoking this method after the
    loader's state has been restored (via ``.from_state(...)`` or
    ``load_state_dict(...)``) lets the next iteration resume from the beginning
    of the last-run epoch.
    """
    # Point the live generator back at the start-of-epoch snapshot.
    self._seed_generator = self._initial_seed_generator

def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
Expand Down
10 changes: 10 additions & 0 deletions torchdata/dataloader2/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,13 @@ def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator":
self._worker_rng = self._worker_rng.spawn(worker_id)
return self
return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id)))

def __getstate__(self):
    # Pickle protocol: the two RNG components are the entire state that
    # __setstate__ restores, so serialize exactly that pair.
    return (self._shared_rng, self._worker_rng)

def __setstate__(self, state):
    # Inverse of __getstate__: reinstall the (shared, worker) RNG pair.
    shared, worker = state
    self._shared_rng = shared
    self._worker_rng = worker
2 changes: 2 additions & 0 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def initialize_iteration(
) -> Optional[Callable[[DataPipe], DataPipe]]:
assert self._end_datapipe is not None

# Set random seeds for DataPipe that are in the main process (NOT those in worker processes)
# Worker seeds are set in `process_reset_fn`
set_graph_random_seed(self._end_datapipe, seed_generator)

if self._mp:
Expand Down