forked from pytorch/data
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SeedGenerator to DataLoader2 and enable re-seeding for ReadingSer…
…vice (pytorch#801) Summary: Pull Request resolved: pytorch#801 Add the initial support for DataLoader2 to control randomness over the pipeline: - Implement `SeedGenerator` - Change API of `ReadingService` to take seed generator from DataLoader2 Reviewed By: NivekT Differential Revision: D38947827 fbshipit-source-id: 21761db17cab2f1c9ef89058b6a53f53abe0590f
- Loading branch information
1 parent
b373bd2
commit d36a3ed
Showing
9 changed files
with
329 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
from torchdata.dataloader2.random import SeedGenerator | ||
from torchdata.dataloader2.random._philox import PhiloxEngine | ||
|
||
|
||
class TestRandom(unittest.TestCase): | ||
def test_philox_engine_generate(self): | ||
prng = PhiloxEngine() | ||
|
||
# Raise Error without a provided seed | ||
with self.assertRaises(AssertionError): | ||
prng.generate() | ||
|
||
# Same seed | ||
prng.seed(123) | ||
s0 = [prng.generate() for _ in range(10)] | ||
|
||
prng = PhiloxEngine(seed=123) | ||
s1 = [prng.generate() for _ in range(10)] | ||
|
||
self.assertEqual(s0, s1) | ||
|
||
# Reset | ||
prng.seed(123) | ||
s2 = [prng.generate() for _ in range(10)] | ||
|
||
self.assertEqual(s1, s2) | ||
|
||
# Different seeds | ||
prng = PhiloxEngine(seed=321) | ||
s3 = [prng.generate() for _ in range(10)] | ||
|
||
self.assertNotEqual(s0, s3) | ||
|
||
def test_philox_engine_spawn(self): | ||
prng = PhiloxEngine() | ||
|
||
# Raise Error without a provided seed | ||
with self.assertRaises(AssertionError): | ||
prng.spawn(0) | ||
|
||
# Same seed | ||
prng.seed(123) | ||
s0 = [prng.spawn(i) for i in range(10)] | ||
|
||
prng = PhiloxEngine(seed=123) | ||
s1 = [prng.spawn(i) for i in range(10)] | ||
|
||
self.assertEqual(s0, s1) | ||
|
||
# Reset | ||
prng.seed(123) | ||
s2 = [prng.spawn(i) for i in range(10)] | ||
|
||
self.assertEqual(s1, s2) | ||
|
||
# Different seeds | ||
prng = PhiloxEngine(seed=321) | ||
s3 = [prng.spawn(i) for i in range(10)] | ||
|
||
self.assertNotEqual(s0, s3) | ||
|
||
def test_seed_generator(self): | ||
# Generate seeds | ||
sg = SeedGenerator(123) | ||
|
||
# Same seed | ||
s0 = [sg.generate() for _ in range(10)] | ||
|
||
# Reset | ||
sg.seed(123) | ||
s1 = [sg.generate() for _ in range(10)] | ||
|
||
self.assertEqual(s0, s1) | ||
|
||
# Different Seeds | ||
sg.seed(321) | ||
s2 = [sg.generate() for _ in range(10)] | ||
|
||
self.assertNotEqual(s0, s2) | ||
|
||
# Spawn new Seed Generators | ||
sg1 = sg.spawn(1) | ||
sg2 = sg.spawn(2) | ||
|
||
for _ in range(10): | ||
self.assertNotEqual(sg1.generate(), sg2.generate()) | ||
|
||
sg1_1 = sg.spawn(1) | ||
sg1_2 = sg.spawn(1) | ||
for _ in range(10): | ||
self.assertEqual(sg1_1.generate(), sg1_2.generate()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from torchdata.dataloader2.random.seed_generator import SeedGenerator | ||
|
||
|
||
__all__ = ["SeedGenerator"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List, Optional, Tuple | ||
|
||
# Note [Philox Engine implementation] | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf for details regarding the engine. | ||
# Using Philox4×32-10 for the sake of performance, randomness and crush-resistance. | ||
# The following code could be optimized into C++ bindings | ||
|
||
# Philox Constants | ||
kPhilox10A = 0x9E3779B9 | ||
kPhilox10B = 0xBB67AE85 | ||
kPhiloxSA = 0xD2511F53 | ||
kPhiloxSB = 0xCD9E8D57 | ||
|
||
MASK_32b = 0xFFFFFFFF | ||
MASK_64b = 0xFFFFFFFFFFFFFFFF | ||
HALF_UINT64 = 0x8000000000000000 | ||
|
||
|
||
def mulhilo32(a: int, b: int) -> Tuple[int, int]: | ||
product = a * b | ||
return product & MASK_32b, (product >> 32) & MASK_32b | ||
|
||
|
||
def single_round(key: List[int], ctr: List[int]) -> List[int]: | ||
lo0, hi0 = mulhilo32(kPhiloxSA, ctr[0]) | ||
lo1, hi1 = mulhilo32(kPhiloxSB, ctr[2]) | ||
res = [0] * 4 | ||
res[0] = hi1 ^ ctr[1] ^ key[0] | ||
res[1] = lo1 | ||
res[2] = hi0 ^ ctr[3] ^ key[1] | ||
res[3] = lo0 | ||
return res | ||
|
||
|
||
def philox_10_round(key: Tuple[int, int], ctr: List[int]) -> List[int]: | ||
_key = list(key) | ||
_ctr = list(ctr) | ||
for _ in range(9): | ||
_ctr = single_round(_key, _ctr) | ||
_key[0] = (_key[0] + kPhilox10A) & MASK_32b | ||
_key[1] = (_key[1] + kPhilox10B) & MASK_32b | ||
return single_round(_key, _ctr) | ||
|
||
|
||
class PhiloxEngine: | ||
r""" | ||
Philox is a counter-based RNG with a certain properties: | ||
- High performance | ||
- Statistiacl random | ||
- Crush-resistance Bijection | ||
Generate new seeds or spawn parallel seeds for worker processes. | ||
""" | ||
|
||
def __init__(self, seed: Optional[int] = None) -> None: | ||
self._seed: Tuple[int, int] = (-1, -1) | ||
self._ctr: List[int] = [0] * 4 | ||
self._generated_seeds: Optional[List[int]] = None | ||
self._spawn_seed: Tuple[int, int] = (-1, -1) | ||
if seed is not None: | ||
self.seed(seed) | ||
|
||
def _incr_ctr(self) -> None: | ||
for i in range(3): | ||
self._ctr[i] += 1 | ||
if self._ctr[i] <= MASK_32b: | ||
return | ||
self._ctr[i] = 0 | ||
self._ctr[3] += 1 | ||
# if overflow (2^128) has occurred during addition, back to the initial counter | ||
if self._ctr[3] > MASK_32b: | ||
self._ctr[3] = 0 | ||
self._incr_ctr() | ||
|
||
def seed(self, seed: int) -> "PhiloxEngine": | ||
seed = seed & MASK_64b | ||
# Convert seed from int64 to uint64 | ||
if seed < 0: | ||
seed = seed + HALF_UINT64 | ||
lo = seed & MASK_32b | ||
hi = (seed >> 32) & MASK_32b | ||
self._seed = (lo, hi) | ||
# Reset counter and cached seed | ||
self._ctr = [0] * 4 | ||
self._generated_seeds = None | ||
# Generate the spawn seed | ||
self._spawn_seed = tuple(philox_10_round(self._seed, self._ctr)[:2]) # type: ignore[assignment] | ||
self._incr_ctr() | ||
return self | ||
|
||
def generate(self) -> int: | ||
assert self._seed != (-1, -1) | ||
if self._generated_seeds is None: | ||
self._generated_seeds = philox_10_round(self._seed, self._ctr) | ||
self._incr_ctr() | ||
res = self._generated_seeds[:2] | ||
else: | ||
res = self._generated_seeds[2:] | ||
self._generated_seeds = None | ||
return (res[1] << 32) + res[0] | ||
|
||
def spawn(self, index: int) -> int: | ||
assert self._seed != (-1, -1) | ||
assert index >= 0 | ||
|
||
offset = index % 2 | ||
val = index if offset == 0 else index - 1 | ||
|
||
ctr = [] | ||
for _ in range(4): | ||
ctr.append(val & MASK_32b) | ||
val = val >> 32 | ||
|
||
res = philox_10_round(self._spawn_seed, ctr)[offset * 2 : offset * 2 + 2] | ||
return (res[1] << 32) + res[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from ._philox import PhiloxEngine | ||
|
||
|
||
class SeedGenerator: | ||
r""" | ||
SeedGenerator is used to generate seeds in a deterministic and randomized manner | ||
based on a user-provided initial seed. Internally, it utilizes a counter-based PRNG | ||
called Philox to generate random seeds. | ||
Args: | ||
seed: The base seed to generate random seeds | ||
""" | ||
|
||
def __init__(self, seed=None) -> None: | ||
self._rng: PhiloxEngine = PhiloxEngine() | ||
if seed is not None: | ||
self._rng.seed(seed) | ||
|
||
def seed(self, seed: int) -> None: | ||
self._rng.seed(seed) | ||
|
||
def generate(self) -> int: | ||
r""" | ||
Generate one uint64 random seed | ||
""" | ||
return self._rng.generate() | ||
|
||
def spawn(self, process_id: int) -> "SeedGenerator": | ||
r""" | ||
Spawn a sub-SeedGenerator based on the provided process_id | ||
""" | ||
return SeedGenerator(self._rng.spawn(process_id)) |
Oops, something went wrong.