[ReadingService] Add round robin sharding to support non-replicable DataPipe for Multiprocessing #919

Closed · wants to merge 18 commits · Changes from 17 commits

1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -54,6 +54,7 @@
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.doctest",
"sphinx.ext.graphviz",
]

# Do not execute standard reST doctest blocks so that documentation can
115 changes: 111 additions & 4 deletions docs/source/dataloader2.rst
@@ -18,10 +18,6 @@ Note:
- :class:`torchdata.datapipes.map.SequenceWrapper`: ``torch.utils.data.Dataset``
- :class:`torchdata.datapipes.iter.IterableWrapper`: ``torch.utils.data.IterableDataset``

Both custom ``worker_init_fn`` and ``worker_reset_fn`` require the following two arguments:
- :class:`torchdata.dataloader2.utils.WorkerInfo`
- ``DataPipe``

ReadingService
---------------

@@ -38,6 +34,117 @@ ReadingService

Each ``ReadingService`` takes the ``DataPipe`` graph and rewrites it to achieve features such as dynamic sharding, shared random seeds, and snapshotting for multi-/distributed processes.

Dynamic Sharding
^^^^^^^^^^^^^^^^

Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService``, which shard the pipeline based on the information of the corresponding multiprocessing and distributed workers. TorchData offers two types of ``DataPipe`` that let users define where sharding takes place within the pipeline.

- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph and, at the point of the ``sharding_filter``, skips the data that does not belong to that worker.

- ``sharding_round_robin_dispatch``: When any non-replicable ``DataPipe`` (``sharding_round_robin_dispatch``) is present in the pipeline, a dispatching process is created to load data from the non-replicable ``DataPipe`` and distribute it to the subsequent worker processes in a round-robin manner.

The following is an example of a pipeline that uses both types of sharding strategies.

.. graphviz::

digraph Example {
subgraph cluster_replicable {
label="Replicable"
a -> b -> c -> d -> l;
color=blue;
}

subgraph cluster_non_replicable {
style=filled;
color=lightgrey;
node [style=filled,color=white];
label="Non-Replicable"
e -> f -> g -> k;
h -> i -> j -> k;
}

k -> l -> fullsync -> end;

a [label="DP1"];
b [label="shuffle"];
c [label="sharding_filter", color=blue];
d [label="DP4"];
e [label="DP2"];
f [label="shuffle"];
g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
h [label="DP3"];
i [label="shuffle"];
j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
k [label="DP5 (Lowest common ancestor)"];
l [label="DP6"];
fullsync;
end [shape=box];
}
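
A minimal sketch of how such a pipeline could be assembled in Python is shown below. ``DP1``–``DP6`` in the diagram stand for arbitrary ``DataPipe`` operations, the data here is a placeholder range, and the import path of ``SHARDING_PRIORITIES`` is assumed:

.. code:: python

    from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES  # assumed import path

    from torchdata.datapipes.iter import IterableWrapper, ShardingRoundRobinDispatcher

    # Replicable branch (DP1 -> shuffle -> sharding_filter in the diagram):
    # every worker holds its own replica of this branch and, at sharding_filter,
    # drops the elements that belong to other workers.
    dp1 = IterableWrapper(list(range(10))).shuffle().sharding_filter()

    # Non-replicable branches (DP2/DP3 -> shuffle -> sharding_round_robin_dispatch):
    # a single dispatching process reads them and hands the data to the worker
    # processes in a round-robin fashion.
    dp2 = ShardingRoundRobinDispatcher(
        IterableWrapper(list(range(10))).shuffle(), SHARDING_PRIORITIES.MULTIPROCESSING
    )
    dp3 = ShardingRoundRobinDispatcher(
        IterableWrapper(list(range(10))).shuffle(), SHARDING_PRIORITIES.MULTIPROCESSING
    )

    # DP5 is the lowest common ancestor of the two non-replicable branches;
    # DP6 joins it with the replicable branch before fullsync.
    dp5 = dp2.zip(dp3)
    datapipe = dp1.zip(dp5)  # DP6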

When multiprocessing takes place, the graph becomes:

.. graphviz::

digraph Example {
subgraph cluster_worker_0 {
label="Worker 0"
a0 -> b0 -> c0 -> d0 -> l0;
m0 -> l0;
color=blue;
}

subgraph cluster_worker_1 {
label="Worker 1"
a1 -> b1 -> c1 -> d1 -> l1;
m1 -> l1;
color=blue;
}

subgraph cluster_non_replicable {
style=filled;
color=lightgrey;
node [style=filled,color=white];
label="Non-Replicable"
e -> f -> g -> k;
h -> i -> j -> k;
k -> round_robin_demux;
}

round_robin_demux -> m0;
round_robin_demux -> m1;
l0 -> n;
l1 -> n;
n -> fullsync -> end;

a0 [label="DP1"];
b0 [label="shuffle"];
c0 [label="sharding_filter", color=blue];
d0 [label="DP4"];
a1 [label="DP1"];
b1 [label="shuffle"];
c1 [label="sharding_filter", color=blue];
d1 [label="DP4"];
e [label="DP2"];
f [label="shuffle"];
g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
h [label="DP3"];
i [label="shuffle"];
j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
k [label="DP5 (Lowest common ancestor)"];
fullsync;
l0 [label="DP6"];
l1 [label="DP6"];
m0 [label="Client"]
m1 [label="Client"]
n [label="Client"]
end [shape=box];
}

``Client`` in the graph is a ``DataPipe`` that sends requests to and receives responses from the multiprocessing queues.
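
As a rough end-to-end sketch (assuming ``datapipe`` is the graph built above, and that ``DataLoader2`` and ``PrototypeMultiProcessingReadingService`` are importable from ``torchdata.dataloader2``, as used by the tests in this PR), the pipeline is driven through the multiprocessing reading service like this:

.. code:: python

    from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService

    # num_workers=2 spawns two worker processes; the dispatching process for the
    # non-replicable branches is created behind the scenes and feeds the workers
    # through the Client DataPipes shown in the graph above.
    rs = PrototypeMultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(datapipe, reading_service=rs)

    for batch in dl:
        pass  # consume the data

    dl.shutdown()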

Graph Mode
^^^^^^^^^^

This also allows an easier transition of data-preprocessing pipelines from research to production. After the ``DataPipe`` graph is created and validated with the ``ReadingServices``, a different ``ReadingService`` that configures and connects to a production service/infrastructure such as ``AIStore`` can be provided to :class:`DataLoader2` as a drop-in replacement. The ``ReadingService`` could potentially search the graph, find ``DataPipe`` operations that can be delegated to the production service/infrastructure, and then modify the graph accordingly to achieve more performant execution.

The following are the interfaces for a custom ``ReadingService``.
134 changes: 130 additions & 4 deletions test/dataloader2/test_dataloader2.py
@@ -23,8 +23,8 @@
)
from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME

from torchdata.dataloader2.graph import DataPipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.dataloader2.graph import DataPipe, replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingRoundRobinDispatcher
from torchdata.datapipes.map import SequenceWrapper

try:
@@ -302,7 +302,7 @@ class TestDataLoader2EventLoop(TestCase):
#
# it = list(range(100))
# numbers_dp = IterableWrapper(it)
# (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp)
# (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline(numbers_dp)
#
# process.start()
# local_datapipe = communication.iter.QueueWrapper(
@@ -323,7 +323,7 @@ def clean_me(process, req_queue, res_queue):
input_len = 100
it = list(range(input_len))
numbers_dp = SequenceWrapper(it)
(process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(
(process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline(
numbers_dp
)

@@ -352,6 +352,10 @@ def clean_me(process, req_queue, res_queue):
clean_me(process, req_queue, res_queue)


def _x_mult_2(d):
return d * 2


class PrototypeMultiProcessingReadingServiceTest(TestCase):
@staticmethod
def _worker_init_fn(datapipe, worker_info):
@@ -389,6 +393,128 @@ def test_worker_fns(self):
res2 = list(dl)
self.assertEqual(exp, res2)

def test_single_branch_non_replicable(self):
r"""
For a single-branch pipeline with a non-replicable DataPipe, all ``sharding_filter``s
in the pipeline become non-replicable.
"""

def _make_dp():
single_br_dp = IterableWrapper(list(range(10))).shuffle()
map_dp = single_br_dp.map(_x_mult_2)
end_dp = map_dp.map(_x_mult_2).shuffle()
return single_br_dp, map_dp, end_dp

def _assert_deterministic_dl_res(dl, exp):
torch.manual_seed(123)
res = list(dl)
self.assertEqual(sorted(res), exp)
# Second epoch
torch.manual_seed(123)
self.assertEqual(list(dl), res)
# Different seed
torch.manual_seed(321)
self.assertNotEqual(list(dl), res)
# Properly shutdown
dl.shutdown()

# By default, all replicable
single_br_dp, _, end_dp = _make_dp()
graph = traverse_dps(end_dp)
sf_dp = single_br_dp.sharding_filter()
replace_dp(graph, single_br_dp, sf_dp)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism and dynamic sharding
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])

# Non-replicable before sharding_filter
# shuffle in dispatch process
single_br_dp, map_dp, end_dp = _make_dp()
graph = traverse_dps(end_dp)
round_robin_dispatcher = ShardingRoundRobinDispatcher(single_br_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, single_br_dp, round_robin_dispatcher)
sf_dp = map_dp.sharding_filter()
replace_dp(graph, map_dp, sf_dp)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])

# Non-replicable after sharding_filter
# shuffle in dispatch process
single_br_dp, map_dp, end_dp = _make_dp()
graph = traverse_dps(end_dp)
sf_dp = single_br_dp.sharding_filter()
replace_dp(graph, single_br_dp, sf_dp)
round_robin_dispatcher = ShardingRoundRobinDispatcher(map_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, map_dp, round_robin_dispatcher)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])

def test_multi_branch_non_replicable(self):
r"""
For a multi-branch pipeline with a non-replicable DataPipe on one branch,
all ``sharding_filter``s on the other branches should remain replicable.
"""

def _make_dp():
branch1_dp = IterableWrapper(list(range(10))).shuffle()
branch2_dp = IterableWrapper(list(range(10))).shuffle()
map_dp = branch1_dp.map(_x_mult_2)
end_dp = map_dp.zip(branch2_dp)
return branch1_dp, map_dp, branch2_dp, end_dp

def _assert_deterministic_dl_res(dl, exp1, exp2):
torch.manual_seed(123)
res = list(dl)
res1, res2 = list(zip(*res))
self.assertEqual(sorted(res1), exp1)
self.assertEqual(sorted(res2), exp2)
# Second epoch
torch.manual_seed(123)
self.assertEqual(list(dl), res)
# Different seed
torch.manual_seed(321)
self.assertNotEqual(list(dl), res)
# Properly shutdown
dl.shutdown()

# By default, all replicable
branch1_dp, _, branch2_dp, end_dp = _make_dp()
graph = traverse_dps(end_dp)
sf1_dp = branch1_dp.sharding_filter()
sf2_dp = branch2_dp.sharding_filter()
replace_dp(graph, branch1_dp, sf1_dp)
replace_dp(graph, branch2_dp, sf2_dp)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism and dynamic sharding
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))

# Non-replicable on one branch
# shuffle in dispatch process
branch1_dp, _, branch2_dp, end_dp = _make_dp()
graph = traverse_dps(end_dp)
non_replicable_dp = ShardingRoundRobinDispatcher(branch1_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, branch1_dp, non_replicable_dp)
# The other branch should have a sharding_filter to keep the data evenly sharded
sf_dp = branch2_dp.sharding_filter()
replace_dp(graph, branch2_dp, sf_dp)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))

# Non-replicable on both branches
# shuffle in dispatch process
branch1_dp, _, branch2_dp, end_dp = _make_dp()
graph = traverse_dps(end_dp)
non_replicable_dp1 = ShardingRoundRobinDispatcher(branch1_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, branch1_dp, non_replicable_dp1)
non_replicable_dp2 = ShardingRoundRobinDispatcher(branch2_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, branch2_dp, non_replicable_dp2)
dl = DataLoader2(end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2))
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))


if __name__ == "__main__":
unittest.main()