Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataLoader2] Saving and restoring initial seed generator #998

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
05c0bf1
[DataLoader2] Saving and restoring initial seed generator
NivekT Feb 8, 2023
45b6998
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 8, 2023
9d6f38d
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
389e567
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 9, 2023
955e412
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
90278bf
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 10, 2023
1a8ebdd
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 13, 2023
fa1f93f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
4653af3
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
15f774e
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 15, 2023
5de743f
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 16, 2023
18a5c26
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
9f87c01
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Feb 28, 2023
3cdd2ea
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
ef850ed
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
b715066
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 17, 2023
5e56222
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0433509
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 24, 2023
0d4854c
Update on "[DataLoader2] Saving and restoring initial seed generator"
NivekT Mar 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
clone,
DataPipe,
deserialize_datapipe,
deserialize_seed_generator,
serialize_datapipe,
serialize_seed_generator,
wrap_datapipe_for_serialization,
)
from torchdata.dataloader2.random import SeedGenerator
Expand All @@ -26,6 +28,7 @@
T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
INITIAL_SEED_GEN_STATE_KEY_NAME = "initial_seed_gen_state"


@dataclass
Expand Down Expand Up @@ -187,6 +190,9 @@ def __init__(
self._seed_generator: SeedGenerator = SeedGenerator()
self._seed: Optional[int] = None
self._reset_seed: bool = True
# Seed generator as of beginning of each epoch
self._initial_seed_generator: SeedGenerator = clone(self._seed_generator)
self._skip_iteration_seeding: bool = False

def __iter__(self) -> DataLoader2Iterator[T_co]:
r"""
Expand All @@ -206,9 +212,14 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
if self._reset_seed:
self._seed_generator.seed(self._seed)
self._reset_seed = False
NivekT marked this conversation as resolved.
Show resolved Hide resolved
elif self._skip_iteration_seeding:
self._skip_iteration_seeding = False
else:
self._seed_generator.seed()

# Saving initial seed generator state
self._initial_seed_generator = clone(self._seed_generator)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand Down Expand Up @@ -240,6 +251,7 @@ def seed(self, seed: int) -> None:
raise ValueError(f"Expected an uint64 seed, but got {seed}.")
self._seed = seed
self._reset_seed = True
self._skip_iteration_seeding = False

def __del__(self) -> None:
self.shutdown()
Expand Down Expand Up @@ -280,17 +292,20 @@ def state_dict(self) -> Dict[str, Any]:

NivekT marked this conversation as resolved.
Show resolved Hide resolved
# Serialize datapipe after applying adapters and before reading service adaption
serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt)
serialized_initial_seed_generator = serialize_seed_generator(self._initial_seed_generator)

return {
SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe,
READING_SERVICE_STATE_KEY_NAME: reading_service_state,
INITIAL_SEED_GEN_STATE_KEY_NAME: serialized_initial_seed_generator,
}

@classmethod
def from_state(
cls,
state: Dict[str, Any],
reading_service: CheckpointableReadingServiceInterface,
restore_initial_seed_generator: bool = False,
) -> "DataLoader2[T_co]":
"""
Create new ``DataLoader2`` with ``DataPipe`` graph and ``ReadingService`` restored
Expand All @@ -305,9 +320,17 @@ def from_state(
reading_service=reading_service,
)
data_loader.reading_service_state = reading_service_state

if restore_initial_seed_generator:
serialized_generator = state[INITIAL_SEED_GEN_STATE_KEY_NAME]
deserialized_seed_generator = deserialize_seed_generator(serialized_generator)
assert deserialized_seed_generator is not None
data_loader._seed_generator = deserialized_seed_generator
data_loader._skip_iteration_seeding = True

return data_loader

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any], restore_initial_seed_generator: bool = False) -> None:
"""
For the existing ``DataLoader2``, load serialized state to restore ``DataPipe`` graph
and reset the internal state of ``ReadingService``.
Expand All @@ -331,6 +354,13 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.datapipe = deserialized_datapipe
self.reading_service_state = reading_service_state

if restore_initial_seed_generator:
serialized_generator = state_dict[INITIAL_SEED_GEN_STATE_KEY_NAME]
deserialized_seed_generator = deserialize_seed_generator(serialized_generator)
assert deserialized_seed_generator is not None
self._seed_generator = deserialized_seed_generator
self._skip_iteration_seeding = True

# re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt
if self.datapipe_adapter_fns is not None:
for adapter_fn in self.datapipe_adapter_fns:
Expand Down
16 changes: 16 additions & 0 deletions torchdata/dataloader2/graph/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from torchdata.dataloader2.graph import DataPipe
from torchdata.dataloader2.random.seed_generator import SeedGenerator
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe

Expand All @@ -22,10 +23,25 @@
"clone",
"deserialize_datapipe",
"serialize_datapipe",
"serialize_seed_generator",
"wrap_datapipe_for_serialization",
]


def serialize_seed_generator(seed_generator: SeedGenerator) -> bytes:
    """Pickle ``seed_generator`` into bytes so it can be stored in a checkpoint.

    Args:
        seed_generator: The seed generator instance to serialize.

    Returns:
        The pickled byte representation of ``seed_generator``.

    Raises:
        RuntimeError: If ``pickle.dumps`` fails with a ``pickle.PickleError``.
            The original pickle error is attached as ``__cause__``.
    """
    try:
        return pickle.dumps(seed_generator)
    except pickle.PickleError as e:
        # Chain explicitly so the traceback shows the pickle failure as the
        # direct cause rather than as an incidental "during handling" context.
        raise RuntimeError(f"Seed generator should be pickle-able by default for checkpoint: {e}") from e


def deserialize_seed_generator(serialized_generator: bytes) -> SeedGenerator:
    """Restore a seed generator from bytes produced by pickling.

    Args:
        serialized_generator: Pickled bytes of a seed generator.

    Returns:
        The object reconstructed by ``pickle.loads``.

    Raises:
        RuntimeError: If ``pickle.loads`` fails with a ``pickle.PickleError``.
            The original pickle error is attached as ``__cause__``.
    """
    # NOTE(security): ``pickle.loads`` can execute arbitrary code while loading.
    # Only pass checkpoint data that comes from a trusted source.
    try:
        return pickle.loads(serialized_generator)
    except pickle.PickleError as e:
        # The message names unpickling (the operation that actually failed) and
        # the original error is chained as the explicit cause.
        raise RuntimeError(f"Seed generator should be unpickle-able by default for checkpoint: {e}") from e


def serialize_datapipe(datapipe: DataPipe) -> bytes:
try:
return pickle.dumps(datapipe)
Expand Down