From eb4cdc448b2580a28c6ae5982e759154eae2b847 Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Thu, 29 Sep 2022 13:45:09 -0700
Subject: [PATCH] DataLoader2 initial support for randomness control (#801)

Summary:
Pull Request resolved: https://github.com/pytorch/data/pull/801

Add initial support for `DataLoader2` to control randomness over the pipeline:
- Implement `SeedGenerator`
- Change the `ReadingService` API so that `initialize_iteration` takes a seed generator from `DataLoader2`

Differential Revision: D38947827

fbshipit-source-id: 17db1e13fe8685f6b2817f72c0e199edfaf3a3a1
---
 test/test_adapter.py                     |  4 +--
 test/test_graph.py                       |  7 +++--
 torchdata/dataloader2/__init__.py        | 10 ++++---
 torchdata/dataloader2/dataloader2.py     | 15 ++++++++--
 torchdata/dataloader2/random.py          | 38 ++++++++++++++++++++++++
 torchdata/dataloader2/reading_service.py | 26 +++++++++-------
 6 files changed, 78 insertions(+), 22 deletions(-)
 create mode 100644 torchdata/dataloader2/random.py

diff --git a/test/test_adapter.py b/test/test_adapter.py
index 475fd5520..dfa2f1261 100644
--- a/test/test_adapter.py
+++ b/test/test_adapter.py
@@ -6,9 +6,9 @@
 
 from unittest import TestCase
 
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
+from torchdata.dataloader2 import DataLoader2
 from torchdata.dataloader2.adapter import Shuffle
-from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
+from torchdata.datapipes.iter import IterableWrapper
 
 
 class AdapterTest(TestCase):
diff --git a/test/test_graph.py b/test/test_graph.py
index 303eae681..54d2eb3c4 100644
--- a/test/test_graph.py
+++ b/test/test_graph.py
@@ -12,7 +12,7 @@
 from _utils._common_utils_for_test import IS_WINDOWS
 
 from torch.utils.data import IterDataPipe
-from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
+from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface, SeedGenerator
 from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse
 from torchdata.datapipes.iter import IterableWrapper, Mapper
 
@@ -47,7 +47,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:
 
         return list(graph.values())[0][0]
 
-    def initialize_iteration(self) -> None:
+    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
+        seed_generator.seed(123)
         for dp in self.adaptors:
             dp.started = True
 
@@ -181,10 +182,12 @@ def test_reading_service(self) -> None:
 
         rs = TempReadingService()
         dl = DataLoader2(dp, reading_service=rs)
+        dl._seed_generator.seed(0)
 
         self.assertTrue(len(rs.adaptors) == 0)
 
         it = iter(dl)
+        self.assertEqual(dl._seed_generator._rng.initial_seed(), 123)
         for new_dp in rs.adaptors:
             self.assertTrue(new_dp.started)
 
diff --git a/torchdata/dataloader2/__init__.py b/torchdata/dataloader2/__init__.py
index 949c0bd31..05676a558 100644
--- a/torchdata/dataloader2/__init__.py
+++ b/torchdata/dataloader2/__init__.py
@@ -5,15 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
-from .dataloader2 import DataLoader2, DataLoader2Iterator
-from .error import PauseIteration
-from .reading_service import (
+from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
+from torchdata.dataloader2.error import PauseIteration
+from torchdata.dataloader2.random import SeedGenerator
+from torchdata.dataloader2.reading_service import (
     DistributedReadingService,
     MultiProcessingReadingService,
     PrototypeMultiProcessingReadingService,
     ReadingServiceInterface,
 )
-from .shuffle_spec import ShuffleSpec
+from torchdata.dataloader2.shuffle_spec import ShuffleSpec
 
 __all__ = [
     "DataLoader2",
@@ -23,6 +24,7 @@
     "PauseIteration",
     "PrototypeMultiProcessingReadingService",
     "ReadingServiceInterface",
+    "SeedGenerator",
     "ShuffleSpec",
 ]
 
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index 555839e76..e21913f1b 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -9,12 +9,14 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
 
+import torch
 from torch.utils.data.graph import DataPipe
 
 from torchdata.dataloader2.adapter import Adapter
+from torchdata.dataloader2.error import PauseIteration
+from torchdata.dataloader2.random import SeedGenerator
+from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
 
-from .error import PauseIteration
-from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
 
 T_co = TypeVar("T_co", covariant=True)
 SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
@@ -107,11 +109,18 @@ def __init__(
                 self.datapipe = adapter_fn(self.datapipe)
         self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe
 
+        self._seed_generator = SeedGenerator()
+
     def __iter__(self) -> Iterator[T_co]:
         if self._terminated:
             raise Exception("Cannot iterate over the DataLoader as it has already been shut down")
 
         if self._reset_iter:
+
+            # TODO(ejguan): Provide an exclusive API to seed DataLoader
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            self._seed_generator.seed(seed)
+
             if not self._adapted and self.reading_service is not None:
                 if self.reading_service_state is None:
                     self.datapipe = self.reading_service.initialize(self.datapipe)
@@ -122,7 +131,7 @@ def __iter__(self) -> Iterator[T_co]:
             self._adapted = True
 
         if self.reading_service is not None:
-            self.reading_service.initialize_iteration()
+            self.reading_service.initialize_iteration(self._seed_generator)
 
         self._datapipe_iter = iter(self.datapipe)
         self._reset_iter = False
diff --git a/torchdata/dataloader2/random.py b/torchdata/dataloader2/random.py
new file mode 100644
index 000000000..40d6cd45b
--- /dev/null
+++ b/torchdata/dataloader2/random.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+
+
+__all__ = ["SeedGenerator"]
+
+
+class SeedGenerator:
+    def __init__(self, seed: Optional[int] = None) -> None:
+        self._rng = torch.Generator()
+        if seed is not None:
+            self._rng.manual_seed(seed)
+
+    def seed(self, seed: int) -> None:
+        self._rng.manual_seed(seed)
+
+    def generate(self) -> int:
+        return int(torch.empty((), dtype=torch.int64).random_(generator=self._rng).item())
+
+    def spawn(self, process_id: int) -> "SeedGenerator":
+        return SeedGenerator()
+
+    def __getstate__(self):
+        state = {"_rng_state": self._rng.get_state()}
+        return state
+
+    def __setstate__(self, state):
+        self._rng = torch.Generator()
+        rng_state = state["_rng_state"]
+        self._rng.set_state(rng_state)
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 951393f34..6692501d5 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -21,6 +21,7 @@
 
 from torchdata._constants import default_timeout_in_s
 from torchdata.dataloader2 import communication
+from torchdata.dataloader2.random import SeedGenerator
 from torchdata.datapipes.iter import FullSync, IterableWrapper
 
 
@@ -56,7 +57,7 @@ def finalize(self) -> None:
         """
         pass
 
-    def initialize_iteration(self) -> None:
+    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
         """
         ReadingService spin up service.
         Called at the beginning of every time getting DataLoader iterator.
@@ -170,7 +171,8 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
 
         return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False)  # type: ignore[return-value]
 
-    def initialize_iteration(self) -> None:
+    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
+        # TODO(615): Set seeds for MPRS: 1 shared seed and 1 different seed per worker
         for dp in self.datapipes:
             dp.reset_iterator()
 
@@ -283,24 +285,26 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         self._datapipe = datapipe
         return datapipe
 
-    def initialize_iteration(self) -> None:
+    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
         r"""
         Shares the same seed from rank 0 to other ranks across the distributed processes
         and apply the random seed to the graph of ``DataPipe``.
         """
-        # TODO: Seed Generator should be moved to DataLoader2 after the API
-        #       change of initialize_iteration is landed.
-        seed = self._share_seed()
-        _seed_generator = torch.Generator()
-        _seed_generator.manual_seed(seed)
+        shared_seed = self._share_seed(seed_generator.generate())
+        seed_generator.seed(shared_seed)
+
+        rng = torch.Generator()
+        rng.manual_seed(seed_generator.generate())
         assert self._datapipe is not None
+        # TODO(ejguan): Set the same seed to random ops before sharding
+        #               but different seeds to random ops after sharding
         self._datapipe = torch.utils.data.graph_settings.apply_random_seed(
             self._datapipe,
-            _seed_generator,
+            rng,
         )
 
-    def _share_seed(self):
-        shared_seed = torch.empty((), dtype=torch.int64).random_()
+    def _share_seed(self, seed: int):
+        shared_seed = torch.tensor(seed, dtype=torch.int64)
         dist.broadcast(shared_seed, src=0, group=self._pg)
         return shared_seed.item()
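
Usage sketch (illustrative, not part of the diff above): with this change, `DataLoader2` draws a fresh seed at the start of every iterator and hands its `SeedGenerator` to the reading service's `initialize_iteration`. A minimal custom `ReadingService` built on that contract might look like the following; the class name `EpochSeedingReadingService` is hypothetical, and the seeding call mirrors the `DistributedReadingService` change above.

```python
# Illustrative only. Shows how a custom ReadingService can consume the
# SeedGenerator that DataLoader2 now passes to initialize_iteration.
import torch
import torch.utils.data.graph_settings

from torchdata.dataloader2 import DataLoader2, ReadingServiceInterface, SeedGenerator
from torchdata.datapipes.iter import IterableWrapper


class EpochSeedingReadingService(ReadingServiceInterface):
    def initialize(self, datapipe):
        self._datapipe = datapipe
        return datapipe

    def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
        # DataLoader2 re-seeds `seed_generator` before every epoch, so each epoch
        # derives a fresh seed for the random ops (e.g. shuffle) in the graph.
        rng = torch.Generator()
        rng.manual_seed(seed_generator.generate())
        self._datapipe = torch.utils.data.graph_settings.apply_random_seed(self._datapipe, rng)


dp = IterableWrapper(range(10)).shuffle()
dl = DataLoader2(dp, reading_service=EpochSeedingReadingService())
for epoch in range(2):
    print(list(dl))  # ordering differs across epochs because a new seed is drawn per epoch
```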
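
And a small sketch of the determinism and checkpointing contract `SeedGenerator` is meant to provide (illustrative; assumes `__getstate__` returns the generator state, as in the file above):

```python
import pickle

from torchdata.dataloader2 import SeedGenerator

g1 = SeedGenerator(seed=42)
g2 = SeedGenerator(seed=42)
assert g1.generate() == g2.generate()  # same seed -> same stream of generated seeds

g3 = pickle.loads(pickle.dumps(g1))    # round-trips via __getstate__/__setstate__
assert g3.generate() == g1.generate()  # restored generator continues the same stream
```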