Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3/n] DataLoader2 initial support for randomness control #801

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 72 additions & 16 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

from torchdata.dataloader2.graph import DataPipe, replace_dp, traverse_dps
from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_datapipes_seed, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingRoundRobinDispatcher
from torchdata.datapipes.map import SequenceWrapper

Expand Down Expand Up @@ -369,36 +370,29 @@ def _worker_init_fn(datapipe, worker_info):
return datapipe

@staticmethod
def _worker_reset_fn(datapipe, worker_info):
worker_seed_generator = torch.Generator()
worker_seed_generator.manual_seed(123)
torch.utils.data.graph_settings.apply_random_seed(
datapipe,
worker_seed_generator,
)
def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator):
    # Per-epoch reset hook: re-seed every DataPipe in this worker's graph with a
    # fixed seed (123) so that consecutive epochs produce identical output,
    # which the calling test asserts on.
    graph = traverse_dps(datapipe)
    dps = list_dps(graph)
    worker_seed_generator.seed(123)
    # NOTE(review): distributed_shared=True presumably makes all workers draw
    # from the same shared seed stream — confirm against set_datapipes_seed docs.
    set_datapipes_seed(dps, seed_generator=worker_seed_generator, distributed_shared=True)
    return datapipe

@mp_ctx_parametrize
def test_worker_fns(self, ctx):
dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle()
torch.manual_seed(123)
exp = list(dp)

rs = PrototypeMultiProcessingReadingService(
num_workers=1,
num_workers=2,
multiprocessing_context=ctx,
worker_init_fn=self._worker_init_fn,
worker_reset_fn=self._worker_reset_fn,
)
dl = DataLoader2(dp, reading_service=rs)

# Test worker_init_fn to shard the DataPipe graph
res1 = list(dl)
self.assertEqual(exp, res1)

# Test worker_reset_fn to set the same random seed across epochs
res1 = list(dl)
res2 = list(dl)
self.assertEqual(exp, res2)
self.assertEqual(res1, res2)

@mp_ctx_parametrize
def test_single_branch_non_replicable(self, ctx):
Expand Down Expand Up @@ -536,6 +530,68 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))

@mp_ctx_parametrize
def test_multi_worker_determinism(self, ctx):
    """Two epochs must repeat under the same seed and diverge under a new one,
    both via the global ``torch.manual_seed`` and via ``DataLoader2.seed``.
    """
    pipe: IterDataPipe = IterableWrapper(range(100)).shuffle().sharding_filter().batch(2)

    reading_service = PrototypeMultiProcessingReadingService(
        num_workers=2,
        multiprocessing_context=ctx,
    )
    dl = DataLoader2(pipe, reading_service=reading_service)

    # The global torch seed controls DataLoader2's randomness.
    torch.manual_seed(123)
    baseline = list(dl) + list(dl)

    torch.manual_seed(123)
    self.assertEqual(baseline, list(dl) + list(dl))

    torch.manual_seed(321)
    self.assertNotEqual(baseline, list(dl) + list(dl))

    # DataLoader2's own seed API gives the same control without global state.
    dl.seed(123)
    baseline = list(dl) + list(dl)

    dl.seed(123)
    self.assertEqual(baseline, list(dl) + list(dl))

    dl.seed(321)
    self.assertNotEqual(baseline, list(dl) + list(dl))

@mp_ctx_parametrize
def test_dispatching_worker_determinism(self, ctx):
    """Same determinism contract as the multi-worker test, but elements are
    distributed to workers through round-robin dispatching.
    """
    pipe: IterDataPipe = (
        IterableWrapper(range(100))
        .shuffle()
        .sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
        .batch(2)
    )

    dl = DataLoader2(
        pipe,
        reading_service=PrototypeMultiProcessingReadingService(
            num_workers=2,
            multiprocessing_context=ctx,
        ),
    )

    # torch.manual_seed drives the loader's randomness.
    torch.manual_seed(123)
    expected = list(dl) + list(dl)

    torch.manual_seed(123)
    self.assertEqual(expected, list(dl) + list(dl))

    torch.manual_seed(321)
    self.assertNotEqual(expected, list(dl) + list(dl))

    # DataLoader2.seed provides equivalent control.
    dl.seed(123)
    expected = list(dl) + list(dl)

    dl.seed(123)
    self.assertEqual(expected, list(dl) + list(dl))

    dl.seed(321)
    self.assertNotEqual(expected, list(dl) + list(dl))


instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)

Expand Down
10 changes: 5 additions & 5 deletions test/dataloader2/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.dataloader2.graph.settings import _set_worker_seed_for_dp_graph
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.dataloader2.graph.settings import set_graph_random_seed
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper


Expand Down Expand Up @@ -90,9 +91,8 @@ def _get_dp_seeds_after_setting(worker_id, seed=123):
dp3_ = dp3.sharding_filter()
dp4 = dp1.zip(dp2, dp3_).shuffle()

rng = torch.Generator()
rng.manual_seed(seed)
_set_worker_seed_for_dp_graph(dp4, rng, worker_id)
sg = SeedGenerator(seed).spawn(worker_id)
set_graph_random_seed(dp4, sg)

# same seeds, different seeds
return (dp0._seed, dp3._seed), (dp2._seed, dp4._seed)
Expand Down
4 changes: 2 additions & 2 deletions test/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from unittest import TestCase

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import Shuffle
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.datapipes.iter import IterableWrapper


class AdapterTest(TestCase):
Expand Down
6 changes: 5 additions & 1 deletion test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.utils.dispatch import (
_DummyIterDataPipe,
find_lca_non_replicable_dp,
Expand Down Expand Up @@ -63,7 +64,8 @@ def initialize(self, datapipe: IterDataPipe) -> IterDataPipe:

return list(graph.values())[0][0]

def initialize_iteration(self) -> None:
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
    # Updated ReadingService hook: now receives the per-iteration SeedGenerator
    # from DataLoader2. Seeding with a constant makes each iteration start from
    # the same random state, so repeated list(dl) epochs are comparable.
    seed_generator.seed(123)
    for dp in self.adaptors:
        # Flag each adaptor so the test can assert this hook actually ran.
        dp.started = True

Expand Down Expand Up @@ -248,6 +250,8 @@ def test_reading_service(self) -> None:
for new_dp in rs.adaptors:
self.assertFalse(new_dp.started)

self.assertEqual(res, list(dl))

@unittest.skipIf(IS_WINDOWS, "Fork is required for lambda")
def test_multiprocessing_reading_service(self) -> None:
_, (*_, dp) = self._get_datapipes() # pyre-ignore
Expand Down
109 changes: 109 additions & 0 deletions test/test_seed_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.random._philox import PhiloxEngine


class TestPhilox(unittest.TestCase):
    """Unit tests for the Philox counter-based PRNG engine."""

    def test_philox_engine_generate(self):
        """generate() demands a seed, is reproducible per seed, and differs across seeds."""
        engine = PhiloxEngine()
        with self.assertRaisesRegex(AssertionError, "Please provide seed"):
            engine.generate()

        engine.seed(123)
        first = [engine.generate() for _ in range(10)]

        # Seeding through the constructor yields the identical stream.
        engine = PhiloxEngine(seed=123)
        second = [engine.generate() for _ in range(10)]
        self.assertEqual(first, second)

        # Re-seeding resets the stream to its start.
        engine.seed(123)
        third = [engine.generate() for _ in range(10)]
        self.assertEqual(second, third)

        # A different seed must produce a different stream.
        engine = PhiloxEngine(seed=321)
        other = [engine.generate() for _ in range(10)]
        self.assertNotEqual(first, other)

    def test_philox_engine_spawn(self):
        """spawn() validates its index and is deterministic per (seed, index)."""
        engine = PhiloxEngine()
        with self.assertRaisesRegex(AssertionError, "Expected a non-negative value"):
            engine.spawn(-1)
        with self.assertRaisesRegex(AssertionError, "Please provide seed"):
            engine.spawn(0)

        engine.seed(123)
        seeds_a = [engine.spawn(i)._seed for i in range(10)]

        # Constructor seeding spawns the same children.
        engine = PhiloxEngine(seed=123)
        seeds_b = [engine.spawn(i)._seed for i in range(10)]
        self.assertEqual(seeds_a, seeds_b)

        # Children spawned at the same index generate identical streams...
        child_x = engine.spawn(1)
        child_y = engine.spawn(1)
        stream_x = [child_x.generate() for _ in range(10)]
        stream_y = [child_y.generate() for _ in range(10)]
        self.assertEqual(stream_x, stream_y)

        # ...while a different index gives a different stream.
        child_z = engine.spawn(2)
        stream_z = [child_z.generate() for _ in range(10)]
        self.assertNotEqual(stream_x, stream_z)

        # Re-seeding resets the spawn sequence.
        engine.seed(123)
        seeds_c = [engine.spawn(i)._seed for i in range(10)]
        self.assertEqual(seeds_b, seeds_c)

        # A different parent seed yields different children.
        engine = PhiloxEngine(seed=321)
        seeds_d = [engine.spawn(i)._seed for i in range(10)]
        self.assertNotEqual(seeds_a, seeds_d)


class TestSeedGenerator(unittest.TestCase):
    """Unit tests for SeedGenerator's seed generation and spawning behavior."""

    def test_seed_generator_generate(self):
        """generate_seed() repeats for the same seed and changes for a new one."""
        gen = SeedGenerator(123)
        baseline = [gen.generate_seed() for _ in range(10)]

        # Re-seeding restarts the sequence.
        gen.seed(123)
        repeated = [gen.generate_seed() for _ in range(10)]
        self.assertEqual(baseline, repeated)

        # A new seed gives a new sequence.
        gen.seed(321)
        different = [gen.generate_seed() for _ in range(10)]
        self.assertNotEqual(baseline, different)

    def test_seed_generator_spawn(self):
        """Spawned generators differ per id but agree on the shared stream."""
        parent = SeedGenerator(123)

        child_a = parent.spawn(1)
        child_b = parent.spawn(2)

        for _ in range(10):
            # Per-worker seeds diverge across ids...
            self.assertNotEqual(child_a.generate_seed(), child_b.generate_seed())
            # ...while the shared seed stream stays in lockstep.
            self.assertEqual(child_a.generate_shared_seed(), child_b.generate_shared_seed())

        # Spawning twice at the same id reproduces the same stream.
        twin_a = parent.spawn(1)
        twin_b = parent.spawn(1)
        for _ in range(10):
            self.assertEqual(twin_a.generate_seed(), twin_b.generate_seed())


# Allow running this test module directly: `python test_seed_generator.py`.
if __name__ == "__main__":
    unittest.main()
8 changes: 4 additions & 4 deletions torchdata/dataloader2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
# LICENSE file in the root directory of this source tree.


from .dataloader2 import DataLoader2, DataLoader2Iterator
from .error import PauseIteration
from .reading_service import (
from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator
from torchdata.dataloader2.error import PauseIteration
from torchdata.dataloader2.reading_service import (
CheckpointableReadingServiceInterface,
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
)
from .shuffle_spec import ShuffleSpec
from torchdata.dataloader2.shuffle_spec import ShuffleSpec

__all__ = [
"CheckpointableReadingServiceInterface",
Expand Down
11 changes: 9 additions & 2 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.utils import WorkerInfo


DEFAULT_NON_BLOCKING_SLEEP = 0.001

__all__ = [
Expand Down Expand Up @@ -258,13 +260,18 @@ def reset(self):
for dp in self.datapipes:
dp.reset_iterator()

def reset_epoch(self, reset_fn: Callable[[WorkerInfo, DataPipe], DataPipe]):
def reset_epoch(
self, reset_fn: Callable[[WorkerInfo, SeedGenerator, DataPipe], DataPipe], seed_generator: SeedGenerator
):
for dp in self.datapipes:
dp.protocol.discard_existing_request()
num_workers = len(self.datapipes)
for worker_id, dp in enumerate(self.datapipes):
worker_info = WorkerInfo(num_workers, worker_id)
dp.protocol.request_reset_epoch(partial(reset_fn, worker_info=worker_info))
worker_seed_generator = seed_generator.spawn(worker_id)
dp.protocol.request_reset_epoch(
partial(reset_fn, worker_info=worker_info, seed_generator=worker_seed_generator)
)
while True:
try:
dp.protocol.get_response_reset_epoch()
Expand Down
Loading