DataLoader2 initial support for randomness control (#801)
Summary:
Pull Request resolved: #801

Add initial support for `DataLoader2` to control randomness across the pipeline:
- Implement `SeedGenerator`
- Change API of `ReadingService` to take seed generator from DataLoader2

Differential Revision: D38947827

fbshipit-source-id: 17db1e13fe8685f6b2817f72c0e199edfaf3a3a1
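
For orientation, a minimal, hypothetical usage sketch of the flow this change introduces; the pipeline and reading-service arguments below are illustrative assumptions, not part of the diff:

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

# Illustrative pipeline; the per-iteration seeding is what will eventually
# drive random ops such as shuffle().
dp = IterableWrapper(range(10)).shuffle()
rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(dp, reading_service=rs)

for epoch in range(2):
    # On each reset, DataLoader2 draws a fresh root seed, feeds it to its
    # internal SeedGenerator, and passes that generator to
    # rs.initialize_iteration(...); how the service consumes it is
    # service-specific (per-worker seeding is still a TODO in this commit).
    for batch in dl:
        pass
```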
ejguan authored and facebook-github-bot committed Sep 29, 2022
1 parent a3e1427 commit eb4cdc4
Showing 6 changed files with 78 additions and 22 deletions.
4 changes: 2 additions & 2 deletions test/test_adapter.py
@@ -6,9 +6,9 @@

from unittest import TestCase

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import Shuffle
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.datapipes.iter import IterableWrapper


class AdapterTest(TestCase):
7 changes: 5 additions & 2 deletions test/test_graph.py
@@ -12,7 +12,7 @@

from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface, SeedGenerator
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse
from torchdata.datapipes.iter import IterableWrapper, Mapper

@@ -47,7 +47,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:

return list(graph.values())[0][0]

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
seed_generator.seed(123)
for dp in self.adaptors:
dp.started = True

@@ -181,10 +182,12 @@ def test_reading_service(self) -> None:

rs = TempReadingService()
dl = DataLoader2(dp, reading_service=rs)
dl._seed_generator.seed(0)

self.assertTrue(len(rs.adaptors) == 0)

it = iter(dl)
self.assertEqual(dl._seed_generator._rng.initial_seed(), 123)
for new_dp in rs.adaptors:
self.assertTrue(new_dp.started)

10 changes: 6 additions & 4 deletions torchdata/dataloader2/__init__.py
@@ -5,15 +5,16 @@
# LICENSE file in the root directory of this source tree.


from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.reading_service import (
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from .shuffle_spec import ShuffleSpec
from torchdata.dataloader2.shuffle_spec import ShuffleSpec

__all__ = [
"DataLoader2",
@@ -23,6 +24,7 @@
"PauseIteration",
"PrototypeMultiProcessingReadingService",
"ReadingServiceInterface",
"SeedGenerator",
"ShuffleSpec",
]

15 changes: 12 additions & 3 deletions torchdata/dataloader2/dataloader2.py
@@ -9,12 +9,14 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

import torch
from torch.utils.data.graph import DataPipe

from torchdata.dataloader2.adapter import Adapter
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface

from .error import PauseIteration
from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
@@ -107,11 +109,18 @@ def __init__(
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe

self._seed_generator = SeedGenerator()

def __iter__(self) -> Iterator[T_co]:
if self._terminated:
raise Exception("Cannot iterate over the DataLoader as it has already been shut down")

if self._reset_iter:

# TODO(ejguan): Provide an exclusive API to seed DataLoader
seed = int(torch.empty((), dtype=torch.int64).random_().item())
self._seed_generator.seed(seed)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -122,7 +131,7 @@ def __iter__(self) -> Iterator[T_co]:
self._adapted = True

if self.reading_service is not None:
self.reading_service.initialize_iteration()
self.reading_service.initialize_iteration(self._seed_generator)

self._datapipe_iter = iter(self.datapipe)
self._reset_iter = False
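
As the test in test_graph.py above demonstrates, any seed placed on the (still private) `_seed_generator` before iteration is overwritten by the randomly drawn seed in `__iter__`; only a reading service's `initialize_iteration` can then pin it. A condensed sketch of that behaviour, assuming the private attribute stays as in this diff:

```python
from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper

dl = DataLoader2(IterableWrapper(range(3)))
dl._seed_generator.seed(0)  # replaced on the next reset (see the TODO above)
it = iter(dl)               # __iter__ draws a fresh random seed and re-seeds
# With a reading service attached, initialize_iteration(dl._seed_generator)
# would run at this point and could set a deterministic seed, as
# TempReadingService does with seed_generator.seed(123).
```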
38 changes: 38 additions & 0 deletions torchdata/dataloader2/random.py
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch


__all__ = ["SeedGenerator"]


class SeedGenerator:
    def __init__(self, seed=None) -> None:
        self._rng = torch.Generator()
        if seed is not None:
            self._rng.manual_seed(seed)

    def seed(self, seed: int) -> None:
        self._rng.manual_seed(seed)

    def generate(self) -> int:
        return int(torch.empty((), dtype=torch.int64).random_(generator=self._rng).item())

    def spawn(self, process_id: int) -> "SeedGenerator":
        # Placeholder: per-process seed derivation is not implemented yet,
        # so process_id is ignored for now.
        return SeedGenerator()

    def __getstate__(self):
        state = {"_rng_state": self._rng.get_state()}
        return state

    def __setstate__(self, state):
        self._rng = torch.Generator()
        rng_state = state["_rng_state"]
        self._rng.set_state(rng_state)
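
A brief, hypothetical sketch of how the new class is used (the pickle round-trip relies on the `__getstate__`/`__setstate__` pair above; `spawn` is still a per-process placeholder at this commit):

```python
import pickle

from torchdata.dataloader2 import SeedGenerator

sg = SeedGenerator(seed=42)
first = sg.generate()          # int drawn from the internal torch.Generator
sg.seed(42)
assert sg.generate() == first  # re-seeding reproduces the same stream

# Serialization captures the generator state, so a restored copy continues
# the stream from the same point.
restored = pickle.loads(pickle.dumps(sg))
assert restored.generate() == sg.generate()
```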
26 changes: 15 additions & 11 deletions torchdata/dataloader2/reading_service.py
@@ -21,6 +21,7 @@

from torchdata._constants import default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import FullSync, IterableWrapper


@@ -56,7 +57,7 @@ def finalize(self) -> None:
"""
pass

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
"""
ReadingService spin up service.
Called at the beginning of every time getting DataLoader iterator.
@@ -170,7 +171,8 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:

return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value]

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
# TODO(615): Set seeds for MPRS: 1 shared seed and 1 different seed per worker
for dp in self.datapipes:
dp.reset_iterator()

@@ -283,24 +285,26 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
self._datapipe = datapipe
return datapipe

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
r"""
Shares the same seed from rank 0 to other ranks across the distributed processes
and apply the random seed to the graph of ``DataPipe``.
"""
# TODO: Seed Generator should be moved to DataLoader2 after the API
# change of initialize_iteration is landed.
seed = self._share_seed()
_seed_generator = torch.Generator()
_seed_generator.manual_seed(seed)
shared_seed = self._share_seed(seed_generator.generate())
seed_generator.seed(shared_seed)

rng = torch.Generator()
rng.manual_seed(seed_generator.generate())
assert self._datapipe is not None
# TODO(ejguan): Set the same seed to random ops before sharding
# but different seeds to random ops after sharding
self._datapipe = torch.utils.data.graph_settings.apply_random_seed(
self._datapipe,
_seed_generator,
rng,
)

def _share_seed(self):
shared_seed = torch.empty((), dtype=torch.int64).random_()
def _share_seed(self, seed: int):
shared_seed = torch.tensor(seed, dtype=torch.int64)
dist.broadcast(shared_seed, src=0, group=self._pg)
return shared_seed.item()
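
For third-party reading services, the signature change above means `initialize_iteration` now receives the DataLoader2-owned seed generator. A minimal, hypothetical implementation (the class and attribute names are illustrative):

```python
from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import ReadingServiceInterface, SeedGenerator


class RecordingReadingService(ReadingServiceInterface):
    """Hypothetical service that only records the per-iteration seed."""

    def initialize(self, datapipe: DataPipe) -> DataPipe:
        return datapipe

    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
        # A deterministic service could call seed_generator.seed(...) here
        # (cf. TempReadingService in test_graph.py); this one just derives
        # a value from the generator handed over by DataLoader2.
        self.last_seed = seed_generator.generate()
```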
