Skip to content

Commit

Permalink
DataLoader2 initial support for randomness control (#801)
Browse files Browse the repository at this point in the history
Summary:
Fixes #885

Add the support for DataLoader2 to control randomness over the pipeline:
- Implement SeedGenerator
  - `spawn` to generate sub-SeedGenerators for distributed workers
  - `generate_seed` to generate unique seeds
  - `generate_shared_seed` to generate distributed shared seeds
- Change API of ReadingService to take seed generator from DataLoader2. Then, the SeedGenerator of DataLoader2 becomes the source of truth of randomness within the whole data pipeline.

A separate PR will be added for online doc regarding determinism.

Last step for #885

Pull Request resolved: #801

Reviewed By: NivekT

Differential Revision: D38947827

Pulled By: ejguan

fbshipit-source-id: 2f852b89cb1d638e1b9222df838786eb8855afa4
  • Loading branch information
ejguan authored and facebook-github-bot committed Jan 17, 2023
1 parent 139d558 commit bbe7a8c
Show file tree
Hide file tree
Showing 18 changed files with 564 additions and 164 deletions.
88 changes: 72 additions & 16 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

from torchdata.dataloader2.graph import DataPipe, replace_dp, traverse_dps
from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_datapipes_seed, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingRoundRobinDispatcher
from torchdata.datapipes.map import SequenceWrapper

Expand Down Expand Up @@ -369,36 +370,29 @@ def _worker_init_fn(datapipe, worker_info):
return datapipe

@staticmethod
def _worker_reset_fn(datapipe, worker_info):
worker_seed_generator = torch.Generator()
worker_seed_generator.manual_seed(123)
torch.utils.data.graph_settings.apply_random_seed(
datapipe,
worker_seed_generator,
)
def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator):
    # Re-seed every DataPipe in this worker's graph with a fixed seed so
    # each epoch replays the exact same random order.
    worker_seed_generator.seed(123)
    pipes = list_dps(traverse_dps(datapipe))
    set_datapipes_seed(pipes, seed_generator=worker_seed_generator, distributed_shared=True)
    return datapipe

@mp_ctx_parametrize
def test_worker_fns(self, ctx):
dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle()
torch.manual_seed(123)
exp = list(dp)

rs = PrototypeMultiProcessingReadingService(
num_workers=1,
num_workers=2,
multiprocessing_context=ctx,
worker_init_fn=self._worker_init_fn,
worker_reset_fn=self._worker_reset_fn,
)
dl = DataLoader2(dp, reading_service=rs)

# Test worker_init_fn to shard the DataPipe graph
res1 = list(dl)
self.assertEqual(exp, res1)

# Test worker_reset_fn to set the same random seed across epochs
res1 = list(dl)
res2 = list(dl)
self.assertEqual(exp, res2)
self.assertEqual(res1, res2)

@mp_ctx_parametrize
def test_single_branch_non_replicable(self, ctx):
Expand Down Expand Up @@ -536,6 +530,68 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))

@mp_ctx_parametrize
def test_multi_worker_determinism(self, ctx):
    """With a sharding_filter pipeline and 2 workers, two consecutive epochs
    must replay identically under the same seed (set either through
    ``torch.manual_seed`` or ``DataLoader2.seed``) and differ under a new seed.
    """
    pipe: IterDataPipe = IterableWrapper(range(100)).shuffle().sharding_filter().batch(2)

    loader = DataLoader2(
        pipe,
        reading_service=PrototypeMultiProcessingReadingService(
            num_workers=2,
            multiprocessing_context=ctx,
        ),
    )

    def two_epochs():
        # Materialize two back-to-back epochs so epoch-to-epoch seeding is covered.
        return list(loader) + list(loader)

    # Global torch seed is the source of randomness.
    torch.manual_seed(123)
    baseline = two_epochs()

    torch.manual_seed(123)
    self.assertEqual(baseline, two_epochs())

    torch.manual_seed(321)
    self.assertNotEqual(baseline, two_epochs())

    # DataLoader2's own seed API must behave the same way.
    loader.seed(123)
    baseline = two_epochs()

    loader.seed(123)
    self.assertEqual(baseline, two_epochs())

    loader.seed(321)
    self.assertNotEqual(baseline, two_epochs())

@mp_ctx_parametrize
def test_dispatching_worker_determinism(self, ctx):
    """Same determinism contract as the sharding_filter case, but for a
    round-robin dispatching (non-replicable) pipeline: identical seeds give
    identical epochs, a new seed gives a different stream.
    """
    pipe: IterDataPipe = (
        IterableWrapper(range(100))
        .shuffle()
        .sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
        .batch(2)
    )

    loader = DataLoader2(
        pipe,
        reading_service=PrototypeMultiProcessingReadingService(
            num_workers=2,
            multiprocessing_context=ctx,
        ),
    )

    def two_epochs():
        # Two consecutive epochs, so cross-epoch seeding is exercised too.
        return list(loader) + list(loader)

    # Global torch seed is the source of randomness.
    torch.manual_seed(123)
    baseline = two_epochs()

    torch.manual_seed(123)
    self.assertEqual(baseline, two_epochs())

    torch.manual_seed(321)
    self.assertNotEqual(baseline, two_epochs())

    # DataLoader2's own seed API must behave the same way.
    loader.seed(123)
    baseline = two_epochs()

    loader.seed(123)
    self.assertEqual(baseline, two_epochs())

    loader.seed(321)
    self.assertNotEqual(baseline, two_epochs())


instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)

Expand Down
10 changes: 5 additions & 5 deletions test/dataloader2/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.dataloader2.graph.settings import _set_worker_seed_for_dp_graph
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.dataloader2.graph.settings import set_graph_random_seed
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper


Expand Down Expand Up @@ -90,9 +91,8 @@ def _get_dp_seeds_after_setting(worker_id, seed=123):
dp3_ = dp3.sharding_filter()
dp4 = dp1.zip(dp2, dp3_).shuffle()

rng = torch.Generator()
rng.manual_seed(seed)
_set_worker_seed_for_dp_graph(dp4, rng, worker_id)
sg = SeedGenerator(seed).spawn(worker_id)
set_graph_random_seed(dp4, sg)

# same seeds, different seeds
return (dp0._seed, dp3._seed), (dp2._seed, dp4._seed)
Expand Down
4 changes: 2 additions & 2 deletions test/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from unittest import TestCase

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import Shuffle
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.datapipes.iter import IterableWrapper


class AdapterTest(TestCase):
Expand Down
6 changes: 5 additions & 1 deletion test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.utils.dispatch import (
_DummyIterDataPipe,
find_lca_non_replicable_dp,
Expand Down Expand Up @@ -63,7 +64,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:

return list(graph.values())[0][0]

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
    # Pin the per-iteration seed so every epoch is deterministic, then
    # mark each tracked adaptor DataPipe as started for later assertions.
    seed_generator.seed(123)
    for adaptor_dp in self.adaptors:
        adaptor_dp.started = True

Expand Down Expand Up @@ -248,6 +250,8 @@ def test_reading_service(self) -> None:
for new_dp in rs.adaptors:
self.assertFalse(new_dp.started)

self.assertEqual(res, list(dl))

@unittest.skipIf(IS_WINDOWS, "Fork is required for lambda")
def test_multiprocessing_reading_service(self) -> None:
_, (*_, dp) = self._get_datapipes() # pyre-ignore
Expand Down
109 changes: 109 additions & 0 deletions test/test_seed_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random._philox import PhiloxEngine


class TestPhilox(unittest.TestCase):
    """Unit tests for the counter-based Philox engine backing SeedGenerator."""

    def test_philox_engine_generate(self):
        """An unseeded engine must refuse to generate; a seeded one must be
        fully reproducible for the same seed and diverge for a different one.
        """
        engine = PhiloxEngine()
        with self.assertRaisesRegex(AssertionError, "Please provide seed"):
            engine.generate()

        engine.seed(123)
        first = [engine.generate() for _ in range(10)]

        # A fresh engine constructed with the same seed replays the stream.
        engine = PhiloxEngine(seed=123)
        second = [engine.generate() for _ in range(10)]
        self.assertEqual(first, second)

        # Re-seeding in place resets the stream as well.
        engine.seed(123)
        third = [engine.generate() for _ in range(10)]
        self.assertEqual(second, third)

        # A different seed must produce a different stream.
        engine = PhiloxEngine(seed=321)
        other = [engine.generate() for _ in range(10)]
        self.assertNotEqual(first, other)

    def test_philox_engine_spawn(self):
        """``spawn`` must validate its argument, require a prior seed, and be
        deterministic per (seed, worker_id) pair.
        """
        engine = PhiloxEngine()
        with self.assertRaisesRegex(AssertionError, "Expected a non-negative value"):
            engine.spawn(-1)
        with self.assertRaisesRegex(AssertionError, "Please provide seed"):
            engine.spawn(0)

        engine.seed(123)
        first = [engine.spawn(i)._seed for i in range(10)]

        # Same parent seed -> same spawned child seeds.
        engine = PhiloxEngine(seed=123)
        second = [engine.spawn(i)._seed for i in range(10)]
        self.assertEqual(first, second)

        # Spawning the same id twice yields children with identical streams...
        child_a = engine.spawn(1)
        child_b = engine.spawn(1)
        stream_a = [child_a.generate() for _ in range(10)]
        stream_b = [child_b.generate() for _ in range(10)]
        self.assertEqual(stream_a, stream_b)

        # ...while a different id diverges.
        child_c = engine.spawn(2)
        stream_c = [child_c.generate() for _ in range(10)]
        self.assertNotEqual(stream_a, stream_c)

        # Re-seeding the parent resets the spawned seeds too.
        engine.seed(123)
        third = [engine.spawn(i)._seed for i in range(10)]
        self.assertEqual(second, third)

        # A different parent seed changes every child seed.
        engine = PhiloxEngine(seed=321)
        other = [engine.spawn(i)._seed for i in range(10)]
        self.assertNotEqual(first, other)


class TestSeedGenerator(unittest.TestCase):
    """Unit tests for SeedGenerator: reproducible seed streams and spawning."""

    def test_seed_generator_generate(self):
        """The seed stream must replay for the same seed and change for a new one."""
        gen = SeedGenerator(123)
        first = [gen.generate_seed() for _ in range(10)]

        # Re-seeding with the same value resets the stream.
        gen.seed(123)
        second = [gen.generate_seed() for _ in range(10)]
        self.assertEqual(first, second)

        # A different seed produces a different stream.
        gen.seed(321)
        other = [gen.generate_seed() for _ in range(10)]
        self.assertNotEqual(first, other)

    def test_seed_generator_spawn(self):
        """Spawned generators must have distinct per-worker seed streams but a
        common shared-seed stream; spawning the same id twice is deterministic.
        """
        parent = SeedGenerator(123)

        child1 = parent.spawn(1)
        child2 = parent.spawn(2)

        for _ in range(10):
            # Per-worker seeds differ across children...
            self.assertNotEqual(child1.generate_seed(), child2.generate_seed())
            # ...but shared seeds agree across the whole group.
            self.assertEqual(child1.generate_shared_seed(), child2.generate_shared_seed())

        # Spawning the same worker id twice yields identical seed streams.
        sibling_a = parent.spawn(1)
        sibling_b = parent.spawn(1)
        for _ in range(10):
            self.assertEqual(sibling_a.generate_seed(), sibling_b.generate_seed())


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions torchdata/dataloader2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
# LICENSE file in the root directory of this source tree.


from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from .shuffle_spec import ShuffleSpec
from torchdata.dataloader2.shuffle_spec import ShuffleSpec

__all__ = [
"CheckpointableReadingServiceInterface",
Expand Down
11 changes: 9 additions & 2 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.utils import WorkerInfo


DEFAULT_NON_BLOCKING_SLEEP = 0.001

__all__ = [
Expand Down Expand Up @@ -258,13 +260,18 @@ def reset(self):
for dp in self.datapipes:
dp.reset_iterator()

def reset_epoch(self, reset_fn: Callable[[WorkerInfo, DataPipe], DataPipe]):
def reset_epoch(
self, reset_fn: Callable[[WorkerInfo, SeedGenerator, DataPipe], DataPipe], seed_generator: SeedGenerator
):
for dp in self.datapipes:
dp.protocol.discard_existing_request()
num_workers = len(self.datapipes)
for worker_id, dp in enumerate(self.datapipes):
worker_info = WorkerInfo(num_workers, worker_id)
dp.protocol.request_reset_epoch(partial(reset_fn, worker_info=worker_info))
worker_seed_generator = seed_generator.spawn(worker_id)
dp.protocol.request_reset_epoch(
partial(reset_fn, worker_info=worker_info, seed_generator=worker_seed_generator)
)
while True:
try:
dp.protocol.get_response_reset_epoch()
Expand Down
Loading

0 comments on commit bbe7a8c

Please sign in to comment.