From d36a3ed6f2d655bcde408eba8d15c2404775ab80 Mon Sep 17 00:00:00 2001 From: Erjia Guan Date: Thu, 6 Oct 2022 09:37:16 -0700 Subject: [PATCH] Add SeedGenerator to DataLoader2 and enable re-seeding for ReadingService (#801) Summary: Pull Request resolved: https://github.com/pytorch/data/pull/801 Add the initial support for DataLoader2 to control randomness over the pipeline: - Implement `SeedGenerator` - Change API of `ReadingService` to take seed generator from DataLoader2 Reviewed By: NivekT Differential Revision: D38947827 fbshipit-source-id: 21761db17cab2f1c9ef89058b6a53f53abe0590f --- test/test_adapter.py | 4 +- test/test_graph.py | 7 +- test/test_random.py | 103 +++++++++++++++ torchdata/dataloader2/__init__.py | 10 +- torchdata/dataloader2/dataloader2.py | 20 ++- torchdata/dataloader2/random/__init__.py | 10 ++ torchdata/dataloader2/random/_philox.py | 122 ++++++++++++++++++ .../dataloader2/random/seed_generator.py | 38 ++++++ torchdata/dataloader2/reading_service.py | 42 ++++-- 9 files changed, 329 insertions(+), 27 deletions(-) create mode 100644 test/test_random.py create mode 100644 torchdata/dataloader2/random/__init__.py create mode 100644 torchdata/dataloader2/random/_philox.py create mode 100644 torchdata/dataloader2/random/seed_generator.py diff --git a/test/test_adapter.py b/test/test_adapter.py index 475fd5520..dfa2f1261 100644 --- a/test/test_adapter.py +++ b/test/test_adapter.py @@ -6,9 +6,9 @@ from unittest import TestCase -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface +from torchdata.dataloader2 import DataLoader2 from torchdata.dataloader2.adapter import Shuffle -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe +from torchdata.datapipes.iter import IterableWrapper class AdapterTest(TestCase): diff --git a/test/test_graph.py b/test/test_graph.py index 41aa53d30..8655c957e 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -12,7 +12,7 @@ from 
_utils._common_utils_for_test import IS_WINDOWS from torch.utils.data import IterDataPipe -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface +from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface, SeedGenerator from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps from torchdata.datapipes.iter import IterableWrapper, Mapper @@ -47,7 +47,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe: return list(graph.values())[0][0] - def initialize_iteration(self) -> None: + def initialize_iteration(self, seed_generator: SeedGenerator) -> None: + seed_generator.seed(123) for dp in self.adaptors: dp.started = True @@ -181,10 +182,12 @@ def test_reading_service(self) -> None: rs = TempReadingService() dl = DataLoader2(dp, reading_service=rs) + dl._seed_generator.seed(0) self.assertTrue(len(rs.adaptors) == 0) it = iter(dl) + self.assertEqual(dl._seed_generator._rng._seed[0], 123) for new_dp in rs.adaptors: self.assertTrue(new_dp.started) diff --git a/test/test_random.py b/test/test_random.py new file mode 100644 index 000000000..82410bccf --- /dev/null +++ b/test/test_random.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import unittest

from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random._philox import PhiloxEngine


class TestRandom(unittest.TestCase):
    def test_philox_engine_generate(self):
        engine = PhiloxEngine()

        # Generating from an unseeded engine is an error
        with self.assertRaises(AssertionError):
            engine.generate()

        # Seeding via ``seed()`` and via the constructor yields the same stream
        engine.seed(123)
        stream_a = [engine.generate() for _ in range(10)]

        engine = PhiloxEngine(seed=123)
        stream_b = [engine.generate() for _ in range(10)]
        self.assertEqual(stream_a, stream_b)

        # Re-seeding the same engine replays the stream
        engine.seed(123)
        stream_c = [engine.generate() for _ in range(10)]
        self.assertEqual(stream_b, stream_c)

        # A different seed produces a different stream
        engine = PhiloxEngine(seed=321)
        stream_d = [engine.generate() for _ in range(10)]
        self.assertNotEqual(stream_a, stream_d)

    def test_philox_engine_spawn(self):
        engine = PhiloxEngine()

        # Spawning from an unseeded engine is an error
        with self.assertRaises(AssertionError):
            engine.spawn(0)

        # Seeding via ``seed()`` and via the constructor spawn identically
        engine.seed(123)
        spawned_a = [engine.spawn(idx) for idx in range(10)]

        engine = PhiloxEngine(seed=123)
        spawned_b = [engine.spawn(idx) for idx in range(10)]
        self.assertEqual(spawned_a, spawned_b)

        # Re-seeding the same engine replays the spawn seeds
        engine.seed(123)
        spawned_c = [engine.spawn(idx) for idx in range(10)]
        self.assertEqual(spawned_b, spawned_c)

        # A different seed produces different spawn seeds
        engine = PhiloxEngine(seed=321)
        spawned_d = [engine.spawn(idx) for idx in range(10)]
        self.assertNotEqual(spawned_a, spawned_d)

    def test_seed_generator(self):
        gen = SeedGenerator(123)
        seeds_a = [gen.generate() for _ in range(10)]

        # Re-seeding replays the same seed stream
        gen.seed(123)
        seeds_b = [gen.generate() for _ in range(10)]
        self.assertEqual(seeds_a, seeds_b)

        # A different base seed yields a different stream
        gen.seed(321)
        seeds_c = [gen.generate() for _ in range(10)]
        self.assertNotEqual(seeds_a, seeds_c)

        # Sub-generators spawned with different ids are independent
        child_one = gen.spawn(1)
        child_two = gen.spawn(2)
        for _ in range(10):
            self.assertNotEqual(child_one.generate(), child_two.generate())

        # Spawning twice with the same id is reproducible
        twin_a = gen.spawn(1)
        twin_b = gen.spawn(1)
        for _ in range(10):
            self.assertEqual(twin_a.generate(), twin_b.generate())
unittest.main() diff --git a/torchdata/dataloader2/__init__.py b/torchdata/dataloader2/__init__.py index 949c0bd31..05676a558 100644 --- a/torchdata/dataloader2/__init__.py +++ b/torchdata/dataloader2/__init__.py @@ -5,15 +5,16 @@ # LICENSE file in the root directory of this source tree. -from .dataloader2 import DataLoader2, DataLoader2Iterator -from .error import PauseIteration -from .reading_service import ( +from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator +from torchdata.dataloader2.error import PauseIteration +from torchdata.dataloader2.random import SeedGenerator +from torchdata.dataloader2.reading_service import ( DistributedReadingService, MultiProcessingReadingService, PrototypeMultiProcessingReadingService, ReadingServiceInterface, ) -from .shuffle_spec import ShuffleSpec +from torchdata.dataloader2.shuffle_spec import ShuffleSpec __all__ = [ "DataLoader2", @@ -23,6 +24,7 @@ "PauseIteration", "PrototypeMultiProcessingReadingService", "ReadingServiceInterface", + "SeedGenerator", "ShuffleSpec", ] diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py index 7740adb20..fbd710f02 100644 --- a/torchdata/dataloader2/dataloader2.py +++ b/torchdata/dataloader2/dataloader2.py @@ -9,12 +9,13 @@ from dataclasses import dataclass from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union -from torchdata.dataloader2.adapter import Adapter - -from torchdata.dataloader2.graph import DataPipe +import torch +from torch.utils.data.graph import DataPipe -from .error import PauseIteration -from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface +from torchdata.dataloader2.adapter import Adapter +from torchdata.dataloader2.error import PauseIteration +from torchdata.dataloader2.random import SeedGenerator +from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface T_co = TypeVar("T_co", 
covariant=True) SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe" @@ -107,11 +108,18 @@ def __init__( self.datapipe = adapter_fn(self.datapipe) self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe + self._seed_generator = SeedGenerator() + def __iter__(self) -> Iterator[T_co]: if self._terminated: raise Exception("Cannot iterate over the DataLoader as it has already been shut down") if self._reset_iter: + + # TODO(ejguan): Provide an exclusive API to seed DataLoader + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._seed_generator.seed(seed) + if not self._adapted and self.reading_service is not None: if self.reading_service_state is None: self.datapipe = self.reading_service.initialize(self.datapipe) @@ -122,7 +130,7 @@ def __iter__(self) -> Iterator[T_co]: self._adapted = True if self.reading_service is not None: - self.reading_service.initialize_iteration() + self.reading_service.initialize_iteration(self._seed_generator) self._datapipe_iter = iter(self.datapipe) self._reset_iter = False diff --git a/torchdata/dataloader2/random/__init__.py b/torchdata/dataloader2/random/__init__.py new file mode 100644 index 000000000..0b3e5aa0b --- /dev/null +++ b/torchdata/dataloader2/random/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchdata.dataloader2.random.seed_generator import SeedGenerator + + +__all__ = ["SeedGenerator"] diff --git a/torchdata/dataloader2/random/_philox.py b/torchdata/dataloader2/random/_philox.py new file mode 100644 index 000000000..b014eee7c --- /dev/null +++ b/torchdata/dataloader2/random/_philox.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
from typing import List, Optional, Tuple

# Note [Philox Engine implementation]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
# for details regarding the engine. Philox4x32-10 is used for the sake of
# performance, randomness and crush-resistance.
# The following code could be optimized into C++ bindings.

# Philox round constants (from the Random123 reference implementation)
kPhilox10A = 0x9E3779B9
kPhilox10B = 0xBB67AE85
kPhiloxSA = 0xD2511F53
kPhiloxSB = 0xCD9E8D57

MASK_32b = 0xFFFFFFFF
MASK_64b = 0xFFFFFFFFFFFFFFFF
HALF_UINT64 = 0x8000000000000000


def mulhilo32(a: int, b: int) -> Tuple[int, int]:
    """Return ``(low, high)`` 32-bit halves of the 64-bit product ``a * b``."""
    product = a * b
    return product & MASK_32b, (product >> 32) & MASK_32b


def single_round(key: List[int], ctr: List[int]) -> List[int]:
    """Apply one Philox4x32 round (multiply, permute, xor key) to ``ctr``."""
    lo0, hi0 = mulhilo32(kPhiloxSA, ctr[0])
    lo1, hi1 = mulhilo32(kPhiloxSB, ctr[2])
    res = [0] * 4
    res[0] = hi1 ^ ctr[1] ^ key[0]
    res[1] = lo1
    res[2] = hi0 ^ ctr[3] ^ key[1]
    res[3] = lo0
    return res


def philox_10_round(key: Tuple[int, int], ctr: List[int]) -> List[int]:
    """Run the full 10-round Philox4x32 bijection of ``ctr`` under ``key``.

    Nine rounds bump the round key by the Weyl constants; the tenth round is
    applied outside the loop without a final key update.
    """
    _key = list(key)
    _ctr = list(ctr)
    for _ in range(9):
        _ctr = single_round(_key, _ctr)
        _key[0] = (_key[0] + kPhilox10A) & MASK_32b
        _key[1] = (_key[1] + kPhilox10B) & MASK_32b
    return single_round(_key, _ctr)


class PhiloxEngine:
    r"""
    Philox is a counter-based PRNG with the following properties:
      - High performance
      - Statistically random (crush-resistant)
      - Each pass is a bijection of the 128-bit counter

    Used to generate new seeds or spawn parallel seeds for worker processes.

    Args:
        seed: Optional initial seed; when omitted the engine must be seeded
            via :meth:`seed` before ``generate``/``spawn`` may be called.
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        # (-1, -1) marks an unseeded engine; generate()/spawn() assert on it.
        self._seed: Tuple[int, int] = (-1, -1)
        # 128-bit counter as four 32-bit words, little-endian word order.
        self._ctr: List[int] = [0] * 4
        # One Philox pass yields two 64-bit seeds; the second is cached here
        # and handed out by the next generate() call.
        self._generated_seeds: Optional[List[int]] = None
        # Separate key used by spawn(), derived once per seed() call.
        self._spawn_seed: Tuple[int, int] = (-1, -1)
        if seed is not None:
            self.seed(seed)

    def _incr_ctr(self) -> None:
        """Advance the 128-bit counter by one, carrying across the 4 words."""
        for i in range(3):
            self._ctr[i] += 1
            if self._ctr[i] <= MASK_32b:
                return
            self._ctr[i] = 0
        self._ctr[3] += 1
        # if overflow (2^128) has occurred during addition, back to the initial counter
        if self._ctr[3] > MASK_32b:
            self._ctr[3] = 0
            self._incr_ctr()

    def seed(self, seed: int) -> "PhiloxEngine":
        """(Re-)seed the engine, resetting the counter and cached output.

        Returns ``self`` so calls can be chained. The seed is reduced to its
        uint64 bit pattern; note that Python's ``&`` already maps negative
        ints onto two's complement, so no separate sign fixup is required
        (the previous ``if seed < 0`` branch was unreachable dead code).
        """
        seed = seed & MASK_64b
        lo = seed & MASK_32b
        hi = (seed >> 32) & MASK_32b
        self._seed = (lo, hi)
        # Reset counter and cached seed
        self._ctr = [0] * 4
        self._generated_seeds = None
        # Derive the spawn key from counter zero, then skip that counter so
        # generate() never reuses it.
        self._spawn_seed = tuple(philox_10_round(self._seed, self._ctr)[:2])  # type: ignore[assignment]
        self._incr_ctr()
        return self

    def generate(self) -> int:
        """Return the next uint64 seed in the stream.

        Raises:
            AssertionError: if the engine has not been seeded.
        """
        assert self._seed != (-1, -1), "Please seed the PhiloxEngine before generating"
        if self._generated_seeds is None:
            # Fresh pass: produce four 32-bit words, emit the first pair now
            # and cache the second pair for the next call.
            self._generated_seeds = philox_10_round(self._seed, self._ctr)
            self._incr_ctr()
            res = self._generated_seeds[:2]
        else:
            res = self._generated_seeds[2:]
            self._generated_seeds = None
        return (res[1] << 32) + res[0]

    def spawn(self, index: int) -> int:
        """Deterministically derive a uint64 seed for worker ``index``.

        Each Philox pass over the spawn key yields two 64-bit values, so
        adjacent indices (2k, 2k+1) share one pass and pick halves by parity.

        Raises:
            AssertionError: if the engine is unseeded or ``index`` is negative.
        """
        assert self._seed != (-1, -1), "Please seed the PhiloxEngine before spawning"
        assert index >= 0

        offset = index % 2
        val = index if offset == 0 else index - 1

        # Use the (even) index itself as the 128-bit counter, 32 bits per word.
        ctr = []
        for _ in range(4):
            ctr.append(val & MASK_32b)
            val = val >> 32

        res = philox_10_round(self._spawn_seed, ctr)[offset * 2 : offset * 2 + 2]
        return (res[1] << 32) + res[0]
from ._philox import PhiloxEngine


class SeedGenerator:
    r"""
    Deterministic, randomized seed factory used by ``DataLoader2``.

    Given a user-provided base seed, ``SeedGenerator`` hands out a
    reproducible stream of random seeds, backed internally by the
    counter-based Philox PRNG. Independent sub-generators for worker
    processes are derived via :meth:`spawn`.

    Args:
        seed: The base seed to generate random seeds
    """

    def __init__(self, seed=None) -> None:
        # NOTE(review): ``_rng`` is reached into by tests elsewhere in the
        # repo (e.g. ``dl._seed_generator._rng``) — keep the name stable.
        self._rng: PhiloxEngine = PhiloxEngine() if seed is None else PhiloxEngine(seed=seed)

    def seed(self, seed: int) -> None:
        r"""
        Re-seed the underlying engine, resetting the seed stream.
        """
        self._rng.seed(seed)

    def generate(self) -> int:
        r"""
        Generate one uint64 random seed
        """
        return self._rng.generate()

    def spawn(self, process_id: int) -> "SeedGenerator":
        r"""
        Spawn a sub-SeedGenerator based on the provided process_id
        """
        return SeedGenerator(self._rng.spawn(process_id))
As the single + source of randomness, it will governs the determinism for all of random operations + with the graph of DataPipes. + Example: - MultiProcessingReadingService starts prefetching items from the graph. + MultiProcessingReadingService starts setting worker seeds per process and prefetching + items from the graph. """ pass @@ -189,7 +196,8 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: return _IterateQueueDataPipes(self.datapipes) # type: ignore[return-value] - def initialize_iteration(self) -> None: + def initialize_iteration(self, seed_generator: SeedGenerator) -> None: + # TODO(615): Set seeds for MPRS: 1 shared seed and 1 different seed per worker pass def __del__(self): @@ -260,6 +268,9 @@ def finalize(self) -> None: self.dl_._iterator = None +_HALF_UINT64 = 0x8000000000000000 + + class DistributedReadingService(ReadingServiceInterface): r""" ``DistributedReadingSerivce`` handles distributed sharding on the graph of ``DataPipe`` and @@ -301,26 +312,31 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: self._datapipe = datapipe return datapipe - def initialize_iteration(self) -> None: + def initialize_iteration(self, seed_generator: SeedGenerator) -> None: r""" Shares the same seed from rank 0 to other ranks across the distributed processes and apply the random seed to the graph of ``DataPipe``. """ - # TODO: Seed Generator should be moved to DataLoader2 after the API - # change of initialize_iteration is landed. 
- seed = self._share_seed() - _seed_generator = torch.Generator() - _seed_generator.manual_seed(seed) + shared_seed = self._share_seed(seed_generator.generate()) + seed_generator.seed(shared_seed) + + dp_graph_rng = torch.Generator() + dp_graph_rng.manual_seed(seed_generator.generate()) assert self._datapipe is not None + # TODO(ejguan): Set the same seed to random ops before sharding + # but different seeds to random ops after sharding self._datapipe = torch.utils.data.graph_settings.apply_random_seed( self._datapipe, - _seed_generator, + dp_graph_rng, ) - def _share_seed(self): - shared_seed = torch.empty((), dtype=torch.int64).random_() + def _share_seed(self, seed: int): + # Convert uint64 to int64 to prevent overflow for integer Tensor + seed -= _HALF_UINT64 + shared_seed = torch.tensor(seed, dtype=torch.int64) dist.broadcast(shared_seed, src=0, group=self._pg) - return shared_seed.item() + # Revert int64 back to uint64 + return shared_seed.item() + _HALF_UINT64 def finalize(self) -> None: r"""