Add process-local RNG tests
ejguan committed Oct 13, 2022
1 parent 2ac9b69 commit ead6eeb
Showing 3 changed files with 71 additions and 21 deletions.
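The tests below check that DataLoader2 seeds each worker's process-local RNGs (Python's random module, NumPy, and PyTorch) deterministically from a shared per-epoch seed. As a rough illustration of the behavior being tested, and not torchdata's actual implementation, a worker-side reset might look like the following sketch, where shared_seed and worker_id are hypothetical inputs:

import random

import numpy as np
import torch


def seed_worker_rngs(shared_seed: int, worker_id: int) -> None:
    # Derive a distinct but reproducible seed for this worker from the
    # shared per-epoch seed, then reset the three global RNGs that the
    # tests observe via _random_fn.
    worker_seed = (shared_seed + worker_id) % (2 ** 64)
    random.seed(worker_seed)
    np.random.seed(worker_seed % (2 ** 32))  # NumPy requires a 32-bit seed
    torch.manual_seed(worker_seed)

Calling this with the same (shared_seed, worker_id) pair in a fresh process should reproduce the same sequence of random draws.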
37 changes: 31 additions & 6 deletions test/dataloader2/test_random.py
@@ -5,52 +5,77 @@
# LICENSE file in the root directory of this source tree.


import random
import unittest

from unittest import TestCase

import numpy as np

import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


def _random_fn(data):
r"""
Used to validate that the subprocess-local RNGs are set deterministically.
"""
py_random_num = random.randint(0, 2 ** 32)
np_random_num = np.random.randint(0, 2 ** 32)
torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
return (data, py_random_num, np_random_num, torch_random_num)


class DeterminismTest(TestCase):
@parametrize("num_workers", [0, 8])
def test_proto_rs_determinism(self, num_workers):
data_length = 64
exp = list(range(data_length))

data_source = IterableWrapper(exp)
dp = data_source.shuffle().sharding_filter()
dp = data_source.shuffle().sharding_filter().map(_random_fn)
rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
dl = DataLoader2(dp, reading_service=rs)

# No seed
res = []
for d in dl:
for d, *_ in dl:
res.append(d)
self.assertEqual(sorted(res), exp)

# Shuffle with seed
results = []
for _ in range(2):
res = []
ran_res = []
torch.manual_seed(123)
for d in dl:
random.seed(123)
np.random.seed(123)
for d, *ran_nums in dl:
res.append(d)
ran_res.append(ran_nums)
self.assertEqual(sorted(res), exp)
results.append(res)
results.append((res, ran_res))
# The same seed generates the same order of data and the same random state
self.assertEqual(results[0], results[1])

# Different seed
res = []
ran_res = []
torch.manual_seed(321)
for d in dl:
random.seed(321)
np.random.seed(321)
for d, *ran_nums in dl:
res.append(d)
ran_res.append(ran_nums)
self.assertEqual(sorted(res), exp)
self.assertNotEqual(results[0], res)
# Different shuffle order
self.assertNotEqual(results[0][0], res)
# Different subprocess-local random state
self.assertNotEqual(results[0][1], ran_res)


instantiate_parametrized_tests(DeterminismTest)
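The core of the new test is that _random_fn tags each sample with one draw from every global RNG, so comparing the collected tuples across two identically seeded epochs checks both the shuffle order and the worker RNG state. A minimal self-contained sketch of that comparison, without DataLoader2 and with a hypothetical run_epoch helper standing in for one seeded worker:

import random

import numpy as np
import torch


def _random_fn(data):
    # Attach one draw from each global RNG to the sample.
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


def run_epoch(seed, items):
    # Seed the three global RNGs, then map _random_fn over the items,
    # mimicking what a single identically seeded worker would produce.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    return [_random_fn(x) for x in items]


assert run_epoch(123, range(8)) == run_epoch(123, range(8))  # same seed, same tuples
assert run_epoch(123, range(8)) != run_epoch(321, range(8))  # different seed, different draws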
42 changes: 34 additions & 8 deletions test/test_distributed.py
@@ -7,12 +7,16 @@

import os
import queue
import random
import socket
import sys
import unittest

from functools import partial
from unittest import TestCase

import numpy as np

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@@ -46,8 +50,6 @@ def abs_path(path):


def _get_open_port():
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
port = s.getsockname()[1]
@@ -94,8 +96,13 @@ def launch_distributed_training(backend, world_size, *args, fn):


def _dist_iterate_one_epoch(dl, seed=None):
r"""
Iterate one full epoch of the DataLoader, seeding the global RNGs first if a seed is provided.
"""
if seed is not None:
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
res = []
for d in dl:
res.append(d)
@@ -105,13 +112,27 @@ def _dist_iterate_one_epoch(dl, seed=None):


def _finalize_distributed_queue(rank, q):
r"""
Synchronize all distributed processes to guarantee that all data has been
put into the multiprocessing queue.
"""
pg = dist.new_group(backend="gloo")
end_tensor = torch.tensor([rank], dtype=torch.int64)
dist.all_reduce(end_tensor, group=pg)
if rank == 0:
q.put(TerminateSignal())


def _random_fn(data):
r"""
Used to validate that the subprocess-local RNGs are set deterministically.
"""
py_random_num = random.randint(0, 2 ** 32)
np_random_num = np.random.randint(0, 2 ** 32)
torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
return (data, py_random_num, np_random_num, torch_random_num)


def _test_proto_distributed_training(rank, world_size, backend, q, num_workers):
dist.init_process_group(backend, rank=rank, world_size=world_size)
# Balanced data
@@ -120,7 +141,7 @@ def _test_proto_distributed_training(rank, world_size, backend, q, num_workers):
data_length *= num_workers

data_source = IterableWrapper(list(range(data_length)))
dp = data_source.shuffle().sharding_filter()
dp = data_source.shuffle().sharding_filter().map(_random_fn)
rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
dl = DataLoader2(dp, reading_service=rs)

@@ -274,11 +295,16 @@ def test_proto_rs_dl2(self, backend, num_workers) -> None:
res = launch_distributed_training(backend, world_size, num_workers, fn=_test_proto_distributed_training)
result = ({}, {}, {}, {})
for epoch, rank, r in res:
result[epoch][rank] = r
# Same Seed
d, *ran_nums = list(zip(*r))
result[epoch][rank] = (d, ran_nums)
# The same seed generates the same order of data and the same random state
self.assertEqual(result[1], result[2])
# Different Seeds
self.assertNotEqual(result[1], result[3])
# Different seeds
for rank in range(world_size):
# Different shuffle order
self.assertNotEqual(result[1][rank][0], result[3][rank][0])
# Different subprocess-local random state
self.assertNotEqual(result[1][rank][1], result[3][rank][1])

# Mutually exclusive and collectively exhaustive with/without seed
data_length = world_size * 8
@@ -288,7 +314,7 @@ def test_proto_rs_dl2(self, backend, num_workers) -> None:
for res in result:
concat_res = []
for r in res.values():
concat_res.extend(r)
concat_res.extend(r[0])
self.assertEqual(sorted(concat_res), exp)


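The final assertions above verify that, with or without a seed, the per-rank outputs are mutually exclusive and collectively exhaustive: every element of the dataset shows up on exactly one rank. A small sketch of that check with hypothetical per-rank results:

# Hypothetical outputs for world_size = 2 and data_length = 8.
result = {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}
data_length = 8

concat_res = []
for r in result.values():
    concat_res.extend(r)

# No element is duplicated across ranks and none is missing.
assert sorted(concat_res) == list(range(data_length))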
13 changes: 6 additions & 7 deletions torchdata/dataloader2/reading_service.py
@@ -262,26 +262,25 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
)
self.datapipes.append(local_datapipe)

end_datapipe = _IterateQueueDataPipes(self.datapipes)
self.end_datapipe = end_datapipe
return end_datapipe
self.end_datapipe = _IterateQueueDataPipes(self.datapipes) # type: ignore[assignment]
return self.end_datapipe # type: ignore[return-value]

def initialize_iteration(self) -> None:
shared_seed = _generate_random_seed()
if self._pg is not None:
dist.broadcast(shared_seed, src=0, group=self._pg)
shared_seed = shared_seed.item()
shared_seed_int: int = shared_seed.item() # type: ignore[assignment]
_seed_generator = torch.Generator()
_seed_generator.manual_seed(shared_seed)
_seed_generator.manual_seed(shared_seed_int)
torch.utils.data.graph_settings.apply_random_seed(
self.end_datapipe,
self.end_datapipe, # type: ignore[arg-type]
_seed_generator,
)

# Multiprocessing (num_workers > 0)
if isinstance(self.end_datapipe, _IterateQueueDataPipes):
# Send the shared seed to subprocesses
self.end_datapipe.reset_epoch(shared_seed)
self.end_datapipe.reset_epoch(shared_seed_int)
# In-process (num_workers == 0)
else:
# Technically speaking, we should call `_process_reset_fn` to reset global RNGs
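In initialize_iteration, the reading service draws a fresh seed, broadcasts rank 0's value so every rank shuffles in the same order, and hands it to a torch.Generator that reseeds the datapipe graph. A stripped-down sketch of just that broadcast-and-seed step, assuming an already initialized gloo process group; make_shared_generator is a hypothetical name and does not reflect torchdata's exact helpers:

import torch
import torch.distributed as dist


def make_shared_generator(pg=None) -> torch.Generator:
    # Draw a random 64-bit seed locally, then broadcast rank 0's value so
    # every rank seeds its generator, and therefore its shuffle order,
    # identically.
    shared_seed = torch.randint(0, 2 ** 62, (1,), dtype=torch.int64)
    if pg is not None:
        dist.broadcast(shared_seed, src=0, group=pg)
    seed_generator = torch.Generator()
    seed_generator.manual_seed(int(shared_seed.item()))
    return seed_generator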

