Skip to content

Commit

Permalink
Add SeedGenerator to DataLoader2 and enable re-seeding for ReadingSer…
Browse files Browse the repository at this point in the history
…vice (#801)

Summary:
Pull Request resolved: #801

Add the initial support for DataLoader2 to control randomness over the pipeline:
- Implement `SeedGenerator`
- Change API of `ReadingService` to take seed generator from DataLoader2

Differential Revision: D38947827

fbshipit-source-id: fab10a21fecf76e9b5f5c2296fbf930c3af14d2d
  • Loading branch information
ejguan authored and facebook-github-bot committed Oct 4, 2022
1 parent 9ad8efb commit e0db329
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 27 deletions.
4 changes: 2 additions & 2 deletions test/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from unittest import TestCase

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import Shuffle
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.datapipes.iter import IterableWrapper


class AdapterTest(TestCase):
Expand Down
7 changes: 5 additions & 2 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from _utils._common_utils_for_test import IS_WINDOWS
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface, SeedGenerator
from torchdata.dataloader2.graph import find_dps, remove_dp, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, Mapper

Expand Down Expand Up @@ -47,7 +47,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:

return list(graph.values())[0][0]

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
seed_generator.seed(123)
for dp in self.adaptors:
dp.started = True

Expand Down Expand Up @@ -181,10 +182,12 @@ def test_reading_service(self) -> None:

rs = TempReadingService()
dl = DataLoader2(dp, reading_service=rs)
dl._seed_generator.seed(0)

self.assertTrue(len(rs.adaptors) == 0)

it = iter(dl)
self.assertEqual(dl._seed_generator._rng._seed[0], 123)
for new_dp in rs.adaptors:
self.assertTrue(new_dp.started)

Expand Down
103 changes: 103 additions & 0 deletions test/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random._philox import PhiloxEngine


class TestRandom(unittest.TestCase):
def test_philox_engine_generate(self):
prng = PhiloxEngine()

# Raise Error without a provided seed
with self.assertRaises(AssertionError):
prng.generate()

# Same seed
prng.seed(123)
s0 = [prng.generate() for _ in range(10)]

prng = PhiloxEngine(seed=123)
s1 = [prng.generate() for _ in range(10)]

self.assertEqual(s0, s1)

# Reset
prng.seed(123)
s2 = [prng.generate() for _ in range(10)]

self.assertEqual(s1, s2)

# Different seeds
prng = PhiloxEngine(seed=321)
s3 = [prng.generate() for _ in range(10)]

self.assertNotEqual(s0, s3)

def test_philox_engine_spawn(self):
prng = PhiloxEngine()

# Raise Error without a provided seed
with self.assertRaises(AssertionError):
prng.spawn(0)

# Same seed
prng.seed(123)
s0 = [prng.spawn(i) for i in range(10)]

prng = PhiloxEngine(seed=123)
s1 = [prng.spawn(i) for i in range(10)]

self.assertEqual(s0, s1)

# Reset
prng.seed(123)
s2 = [prng.spawn(i) for i in range(10)]

self.assertEqual(s1, s2)

# Different seeds
prng = PhiloxEngine(seed=321)
s3 = [prng.spawn(i) for i in range(10)]

self.assertNotEqual(s0, s3)

def test_seed_generator(self):
# Generate seeds
sg = SeedGenerator(123)

# Same seed
s0 = [sg.generate() for _ in range(10)]

# Reset
sg.seed(123)
s1 = [sg.generate() for _ in range(10)]

self.assertEqual(s0, s1)

# Different Seeds
sg.seed(321)
s2 = [sg.generate() for _ in range(10)]

self.assertNotEqual(s0, s2)

# Spawn new Seed Generators
sg1 = sg.spawn(1)
sg2 = sg.spawn(2)

for _ in range(10):
self.assertNotEqual(sg1.generate(), sg2.generate())

sg1_1 = sg.spawn(1)
sg1_2 = sg.spawn(1)
for _ in range(10):
self.assertEqual(sg1_1.generate(), sg1_2.generate())


if __name__ == "__main__":
unittest.main()
10 changes: 6 additions & 4 deletions torchdata/dataloader2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
# LICENSE file in the root directory of this source tree.


from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.reading_service import (
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from .shuffle_spec import ShuffleSpec
from torchdata.dataloader2.shuffle_spec import ShuffleSpec

__all__ = [
"DataLoader2",
Expand All @@ -23,6 +24,7 @@
"PauseIteration",
"PrototypeMultiProcessingReadingService",
"ReadingServiceInterface",
"SeedGenerator",
"ShuffleSpec",
]

Expand Down
20 changes: 14 additions & 6 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torchdata.dataloader2.adapter import Adapter

from torchdata.dataloader2.graph import DataPipe
import torch
from torch.utils.data.graph import DataPipe

from .error import PauseIteration
from .reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface
from torchdata.dataloader2.adapter import Adapter
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface

T_co = TypeVar("T_co", covariant=True)
SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe"
Expand Down Expand Up @@ -107,11 +108,18 @@ def __init__(
self.datapipe = adapter_fn(self.datapipe)
self._datapipe_before_reading_service_adapt: DataPipe = self.datapipe

self._seed_generator = SeedGenerator()

def __iter__(self) -> Iterator[T_co]:
if self._terminated:
raise Exception("Cannot iterate over the DataLoader as it has already been shut down")

if self._reset_iter:

# TODO(ejguan): Provide an exclusive API to seed DataLoader
seed = int(torch.empty((), dtype=torch.int64).random_().item())
self._seed_generator.seed(seed)

if not self._adapted and self.reading_service is not None:
if self.reading_service_state is None:
self.datapipe = self.reading_service.initialize(self.datapipe)
Expand All @@ -122,7 +130,7 @@ def __iter__(self) -> Iterator[T_co]:
self._adapted = True

if self.reading_service is not None:
self.reading_service.initialize_iteration()
self.reading_service.initialize_iteration(self._seed_generator)

self._datapipe_iter = iter(self.datapipe)
self._reset_iter = False
Expand Down
10 changes: 10 additions & 0 deletions torchdata/dataloader2/random/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata.dataloader2.random.seed_generator import SeedGenerator


__all__ = ["SeedGenerator"]
122 changes: 122 additions & 0 deletions torchdata/dataloader2/random/_philox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple

# Note [Philox Engine implementation]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf for details regarding the engine.
# Using Philox4×32-10 for the sake of performance, randomness and crush-resistance.
# The following code could be optimized into C++ bindings

# Philox Constants
kPhilox10A = 0x9E3779B9
kPhilox10B = 0xBB67AE85
kPhiloxSA = 0xD2511F53
kPhiloxSB = 0xCD9E8D57

MASK_32b = 0xFFFFFFFF
MASK_64b = 0xFFFFFFFFFFFFFFFF
HALF_UINT64 = 0x8000000000000000


def mulhilo32(a: int, b: int) -> Tuple[int, int]:
product = a * b
return product & MASK_32b, (product >> 32) & MASK_32b


def single_round(key: List[int], ctr: List[int]) -> List[int]:
lo0, hi0 = mulhilo32(kPhiloxSA, ctr[0])
lo1, hi1 = mulhilo32(kPhiloxSB, ctr[2])
res = [0] * 4
res[0] = hi1 ^ ctr[1] ^ key[0]
res[1] = lo1
res[2] = hi0 ^ ctr[3] ^ key[1]
res[3] = lo0
return res


def philox_10_round(key: Tuple[int, int], ctr: List[int]) -> List[int]:
_key = list(key)
_ctr = list(ctr)
for _ in range(9):
_ctr = single_round(_key, _ctr)
_key[0] = (_key[0] + kPhilox10A) & MASK_32b
_key[1] = (_key[1] + kPhilox10B) & MASK_32b
return single_round(_key, _ctr)


class PhiloxEngine:
r"""
Philox is a counter-based RNG with a certain properties:
- High performance
- Statistiacl random
- Crush-resistance Bijection
Generate new seeds or spawn parallel seeds for worker processes.
"""

def __init__(self, seed: Optional[int] = None) -> None:
self._seed: Tuple[int, int] = (-1, -1)
self._ctr: List[int] = [0] * 4
self._generated_seeds: Optional[List[int]] = None
self._spawn_seed: Optional[Tuple[int, int]] = None
if seed is not None:
self.seed(seed)

def _incr_ctr(self) -> None:
for i in range(3):
self._ctr[i] += 1
if self._ctr[i] <= MASK_32b:
return
self._ctr[i] = 0
self._ctr[3] += 1
# if overflow (2^128) has occurred during addition, back to the initial counter
if self._ctr[3] > MASK_32b:
self._ctr[3] = 0
self._incr_ctr()

def seed(self, seed: int) -> "PhiloxEngine":
seed = seed & MASK_64b
# Convert seed from int64 to uint64
if seed < 0:
seed = seed + HALF_UINT64
lo = seed & MASK_32b
hi = (seed >> 32) & MASK_32b
self._seed = (lo, hi)
# Reset counter and cached seed
self._ctr: List[int] = [0] * 4
self._generated_seeds = None
# Generate the spawn seed
self._spawn_seed = tuple(philox_10_round(self._seed, self._ctr)[:2])
self._incr_ctr()
return self

def generate(self) -> int:
assert self._seed != (-1, -1)
if self._generated_seeds is None:
self._generated_seeds = philox_10_round(self._seed, self._ctr)
self._incr_ctr()
res = self._generated_seeds[:2]
else:
res = self._generated_seeds[2:]
self._generated_seeds = None
return (res[1] << 32) + res[0]

def spawn(self, index: int) -> int:
assert self._seed != (-1, -1)
assert index >= 0

offset = index % 2
val = index if offset == 0 else index - 1

ctr = []
for _ in range(4):
ctr.append(val & MASK_32b)
val = val >> 32

res = philox_10_round(self._spawn_seed, ctr)[offset * 2 : offset * 2 + 2]
return (res[1] << 32) + res[0]
38 changes: 38 additions & 0 deletions torchdata/dataloader2/random/seed_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._philox import PhiloxEngine


class SeedGenerator:
r"""
SeedGenerator is used to generate seeds in a deterministic and randomized manner
based on a user-provided initial seed. Internally, it utilizes a counter-based PRNG
called Philox to generate random seeds.
Args:
seed: The base seed to generate random seeds
"""

def __init__(self, seed=None) -> None:
self._rng: PhiloxEngine = PhiloxEngine()
if seed is not None:
self._rng.seed(seed)

def seed(self, seed: int) -> None:
self._rng.seed(seed)

def generate(self) -> int:
r"""
Generate one uint64 random seed
"""
return self._rng.generate()

def spawn(self, process_id: int) -> "SeedGenerator":
r"""
Generate one uint64 random seed based on the provided process_id
"""
return SeedGenerator(self._rng.spawn(process_id))
Loading

0 comments on commit e0db329

Please sign in to comment.