Enable SequentialReadingService to support MP + Distributed (#985)
Summary:
Fixes #911

### Changes

- Remove distributed code from PrototypeMPRS
- Fix a bug where `blocking_request_get` was not forwarded to the worker process
- Enable `SequentialReadingService` to combine the Distributed and MP ReadingServices (a short sketch of the dispatching case follows this list)
- Add tests for `SequentialReadingService`
- Add tutorial for `SequentialReadingService`
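
For illustration, here is a rough sketch of the dispatching (non-replicable) pipeline that the combined ReadingServices now cover. It mirrors `_make_dispatching_dp` in the new tests below rather than introducing any new API, and the comment reflects my reading of the round-robin dispatch semantics:

```python
from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper


def _transform(x):
    return x * 2


dp = IterableWrapper(range(1000)).shuffle().sharding_filter()
# Ops up to and including the dispatch run in a single dispatching process on
# each rank; dispatched items are then consumed round-robin by the MP workers,
# which execute the remaining ops (here, the map).
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.map(_transform)
```

Such a `dp` can then be passed to `DataLoader2` with the `SequentialReadingService` shown in the tutorial below.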

Pull Request resolved: #985

Reviewed By: wenleix, NivekT

Differential Revision: D43009426

Pulled By: ejguan

fbshipit-source-id: 5668f6e0ea606846732b770f4430d399254b42bc
ejguan authored and facebook-github-bot committed Feb 7, 2023
1 parent 01fc762 commit 89be152
Showing 5 changed files with 246 additions and 98 deletions.
18 changes: 18 additions & 0 deletions docs/source/dlv2_tutorial.rst
@@ -50,3 +50,21 @@ Distributed
    for d in dl:
        model(d)
    dl.shutdown()
Multiprocessing + Distributed
------------------------------

``SequentialReadingService`` can be used to combine both ``ReadingServices``, achieving multiprocessing and distributed training at the same time.

.. code:: python

    mp_rs = PrototypeMultiProcessingReadingService(num_workers=4)
    dist_rs = DistributedReadingService()
    rs = SequentialReadingService(dist_rs, mp_rs)
    dl = DataLoader2(datapipe, reading_service=rs)
    for epoch in range(10):
        dl.seed(epoch)
        for d in dl:
            model(d)
    dl.shutdown()
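
As with ``DistributedReadingService`` on its own, this combined setup assumes the default process group has already been initialized on each rank before ``DataLoader2`` is iterated — the tests added in this commit call ``dist.init_process_group("gloo", ...)`` first. A minimal sketch, assuming an env-var rendezvous (e.g. launched via ``torchrun``):

.. code:: python

    import torch.distributed as dist

    # Once per rank, before creating or iterating DataLoader2.
    # Backend and rendezvous details depend on your environment.
    dist.init_process_group("gloo")
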
223 changes: 222 additions & 1 deletion test/dataloader2/test_dataloader2.py
@@ -7,20 +7,29 @@
import multiprocessing as mp
import os
import pickle
import queue
import random
import socket
import unittest

from unittest import TestCase

import numpy as np

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize

from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES

from torchdata.dataloader2 import (
    communication,
    DataLoader2,
    DistributedReadingService,
    MultiProcessingReadingService,
    PrototypeMultiProcessingReadingService,
    ReadingServiceInterface,
    SequentialReadingService,
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

@@ -42,6 +51,14 @@
    HAS_DILL = False

skipIfNoDill = unittest.skipIf(not HAS_DILL, "no dill")

if dist.is_available():
    HAS_DIST = True
else:
    HAS_DIST = False

skipIfNoDistributed = unittest.skipIf(not HAS_DIST, "no torch.distributed")

TEST_WITH_TSAN = os.getenv("PYTORCH_TEST_WITH_TSAN", "0") == "1"

mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods())
@@ -574,7 +591,7 @@ def test_multi_worker_determinism(self, ctx):

    @mp_ctx_parametrize
    def test_dispatching_worker_determinism(self, ctx):
        dp: IterDataPipe = IterableWrapper(range(100))
        dp: IterDataPipe = IterableWrapper(range(101))
        dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
        dp = dp.batch(2)

@@ -635,7 +652,211 @@ def test_non_replicable_datapipe(self, ctx) -> None:
        self.assertNotEqual(res, list(dl) + list(dl))


TEST_MASTER_ADDR = "127.0.0.1"
DEFAULT_WORLD_SIZE = 2


def _get_open_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    return str(port)


class TerminateSignal:
    pass


def _launch_distributed_training(world_size, *args, fn):
    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
    os.environ["MASTER_PORT"] = _get_open_port()
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    ps = []
    for rank in range(world_size):
        p = ctx.Process(
            target=fn,
            args=(
                rank,
                world_size,
                q,
                *args,
            ),
        )
        p.start()
        ps.append(p)
    res = []
    while True:
        try:
            d = q.get()
            if isinstance(d, TerminateSignal):
                break
            res.append(d)
        except queue.Empty:
            continue
    for p in ps:
        p.join()
    return res


def _dist_one_epoch(dl):
    res = []
    for d in dl:
        res.append(d)
        # Simulate training synchronization
        dist.barrier()
    return res


def _finalize_distributed_queue(rank, q):
    r"""
    Synchronize all distributed processes to guarantee all data have been put into
    the Multiprocessing Queue.
    """
    pg = dist.new_group(backend="gloo")
    end_tensor = torch.tensor([rank], dtype=torch.int64)
    dist.all_reduce(end_tensor, group=pg)
    if rank == 0:
        q.put(TerminateSignal())

    dist.destroy_process_group(pg)


def _random_fn(data):
    r"""
    Used to validate that subprocess-local RNGs are seeded deterministically.
    """
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


def _dist_training_fn(rank, world_size, q, dp_fn, rs_fn, num_workers, ctx):
    # Use gloo
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Uneven shards
    data_length = world_size * num_workers * 10 + 1
    dp = dp_fn(data_length)
    rs = rs_fn(num_workers, ctx)
    dl = DataLoader2(dp, reading_service=rs)

    # No seed
    res = _dist_one_epoch(dl)
    q.put((0, rank, res))

    # Shuffle with seed
    for epoch in range(2):
        dl.seed(123)
        res = _dist_one_epoch(dl)
        q.put((epoch + 1, rank, res))

    # Different seed
    dl.seed(321)
    res = _dist_one_epoch(dl)
    q.put((3, rank, res))

    _finalize_distributed_queue(rank, q)

    dl.shutdown()


@skipIfNoDistributed
class SequentialReadingServiceTest(TestCase):
    @staticmethod
    def _make_dp(data_length):
        data_source = IterableWrapper(list(range(data_length)))
        dp = data_source.shuffle().sharding_filter().map(_random_fn)
        return dp

    @staticmethod
    def _make_dispatching_dp(data_length):
        data_source = IterableWrapper(list(range(data_length)))
        dp = data_source.shuffle().sharding_filter()
        dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_random_fn)
        return dp

    @staticmethod
    def _make_rs(num_workers, ctx):
        mp_rs = PrototypeMultiProcessingReadingService(
            num_workers=num_workers,
            multiprocessing_context=ctx,
        )
        dist_rs = DistributedReadingService()
        rs = SequentialReadingService(dist_rs, mp_rs)
        return rs

    @mp_ctx_parametrize
    def test_sequential_reading_service_normal_dp(self, ctx):
        world_size = DEFAULT_WORLD_SIZE
        num_workers = 2
        res = _launch_distributed_training(
            world_size,
            SequentialReadingServiceTest._make_dp,
            SequentialReadingServiceTest._make_rs,
            num_workers,
            ctx,
            fn=_dist_training_fn,
        )
        result = ({}, {}, {}, {})
        for epoch, rank, r in res:
            d, *ran_nums = list(zip(*r))
            result[epoch][rank] = (d, ran_nums)

        # Guarantee the same length per rank
        for rr in result:
            exp_len = num_workers * 10
            for _, (d, _) in rr.items():
                self.assertEqual(len(d), exp_len)

        # The same seed generates the same order of data and the same random state
        self.assertEqual(result[1], result[2])

        # Different seeds
        for rank in range(world_size):
            # Different shuffle order
            self.assertNotEqual(result[1][rank][0], result[3][rank][0])
            # Different subprocess-local random state
            self.assertNotEqual(result[1][rank][1], result[3][rank][1])

    @mp_ctx_parametrize
    def test_sequential_reading_service_dispatching_dp(self, ctx):
        world_size = DEFAULT_WORLD_SIZE
        num_workers = 2
        res = _launch_distributed_training(
            world_size,
            SequentialReadingServiceTest._make_dispatching_dp,
            SequentialReadingServiceTest._make_rs,
            num_workers,
            ctx,
            fn=_dist_training_fn,
        )
        result = ({}, {}, {}, {})
        for epoch, rank, r in res:
            d, *ran_nums = list(zip(*r))
            result[epoch][rank] = (d, ran_nums)

        # Guarantee the same length per rank
        for rr in result:
            exp_len = num_workers * 10
            for _, (d, _) in rr.items():
                self.assertEqual(len(d), exp_len)

        # The same seed generates the same order of data and the same random state
        self.assertEqual(result[1], result[2])

        # Different seeds
        for rank in range(world_size):
            # Different shuffle order
            self.assertNotEqual(result[1][rank][0], result[3][rank][0])
            # Different subprocess-local random state
            self.assertNotEqual(result[1][rank][1], result[3][rank][1])


instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)
instantiate_parametrized_tests(SequentialReadingServiceTest)


if __name__ == "__main__":
69 changes: 0 additions & 69 deletions test/test_distributed.py
@@ -125,45 +125,6 @@ def _finalize_distributed_queue(rank, q):
    dist.destroy_process_group(pg)


def _random_fn(data):
    r"""
    Used to validate that subprocess-local RNGs are seeded deterministically.
    """
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


def _test_proto_distributed_training(rank, world_size, backend, q, num_workers):
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    # Balanced data
    data_length = world_size * 8
    if num_workers > 0:
        data_length *= num_workers

    data_source = IterableWrapper(list(range(data_length)))
    dp = data_source.shuffle().sharding_filter().map(_random_fn)
    rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
    dl = DataLoader2(dp, reading_service=rs)

    # No seed
    res = _dist_iterate_one_epoch(dl, seed=None)
    q.put((0, rank, res))

    # Shuffle with seed
    for epoch in range(2):
        res = _dist_iterate_one_epoch(dl, seed=123)
        q.put((epoch + 1, rank, res))

    # Different seed
    res = _dist_iterate_one_epoch(dl, seed=321)
    q.put((3, rank, res))

    _finalize_distributed_queue(rank, q)
    dl.shutdown()


class DistributedTest(TestCase):
    @staticmethod
    def _test_fullsync(rank, world_size, backend, q):
@@ -285,36 +246,6 @@ def test_elastic_training_dl1(self, backend) -> None:
            ],
        )

    @unittest.skipIf(IS_WINDOWS, "Remove when https://github.com/pytorch/data/issues/857 is fixed")
    @backend_parametrize
    @parametrize("num_workers", [0, 8])
    def test_proto_rs_dl2(self, backend, num_workers) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count()
        res = launch_distributed_training(backend, world_size, num_workers, fn=_test_proto_distributed_training)
        result = ({}, {}, {}, {})
        for epoch, rank, r in res:
            d, *ran_nums = list(zip(*r))
            result[epoch][rank] = (d, ran_nums)
        # The same seed generates the same order of data and the same random state
        self.assertEqual(result[1], result[2])
        # Different seeds
        for rank in range(world_size):
            # Different shuffle order
            self.assertNotEqual(result[1][rank][0], result[3][rank][0])
            # Different subprocess-local random state
            self.assertNotEqual(result[1][rank][1], result[3][rank][1])

        # Mutually exclusive and collectively exhaustive with/without seed
        data_length = world_size * 8
        if num_workers > 0:
            data_length *= num_workers
        exp = list(range(data_length))
        for res in result:
            concat_res = []
            for r in res.values():
                concat_res.extend(r[0])
            self.assertEqual(sorted(concat_res), exp)


instantiate_parametrized_tests(DistributedTest)

2 changes: 1 addition & 1 deletion torchdata/dataloader2/communication/eventloop.py
@@ -92,7 +92,7 @@ def _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_
    return pipe_type.DataPipeBehindQueues(
        source_datapipe,
        protocol_type(req_queue, res_queue),
        blocking_request_get=True,
        blocking_request_get=blocking_request_get,
    )


