diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1a179cfe0..3391a3dcc 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -54,6 +54,7 @@
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
     "sphinx.ext.doctest",
+    "sphinx.ext.graphviz",
 ]
 
 # Do not execute standard reST doctest blocks so that documentation can
diff --git a/docs/source/dataloader2.rst b/docs/source/dataloader2.rst
index 8cb076dd8..2de3791e7 100644
--- a/docs/source/dataloader2.rst
+++ b/docs/source/dataloader2.rst
@@ -37,12 +37,110 @@ Each ``ReadingServices`` would take the ``DataPipe`` graph and modify it to achi
 Dynamic Sharding
 ^^^^^^^^^^^^^^^^
 
-Dynamic sharding will take place at the place of ``sharding_filter`` within the pipeline. It's carried out by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` based on the corresponding multiprocessing and distributed workers.
-
-There is a special case that non-shardable ``DataPipe`` (``datapipe.is_shardable() == False``) is presented in the graph. In that case, a certain part of ``DataPipe`` cannot be sent to multiprocessing workers. Based on the existing use cases, there are two typical non-shardable ``DataPipes``:
-- Non-shardable data source like loading data from a remote resource that only accept a single client. When multiprocessing takes place, the lowest common ancestor of non-shardable data source will be sent to a non-sharding process and transfer data from the non-shardable process to worker processes in the round-robin manner.
-- Non-shardable ``DataPipe`` that needs to be placed in the main process like ``fullsync``. And, this type of ``DataPipe`` is normally appended at the end of the pipeline and reading data from multiprocessing workers.
-- Please let us know if you have new examples about non-shardable ``DataPipe``.
+Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService``, which shard the pipeline based on information about the corresponding multiprocessing and distributed workers. TorchData offers two types of ``DataPipe`` that let users define where sharding takes place within the pipeline.
+
+- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph and, at the point of the ``sharding_filter``, skips the samples that do not belong to it.
+
+- ``sharding_round_robin_dispatch``: When any non-replicable ``DataPipe`` (``sharding_round_robin_dispatch``) is present in the pipeline, a dispatching process is created to load data from the non-replicable ``DataPipe`` and distribute it to the subsequent worker processes.
+
+The following is an example of a pipeline that uses both types of sharding strategies.
+
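+For concreteness, such a pipeline might be assembled roughly as shown below; the diagram that follows depicts the same two-branch structure. This is a minimal sketch: the ``IterableWrapper`` sources and the final ``zip`` stand in for the ``DP1``/``DP2``/``DP5`` nodes in the diagram, and the ``DP3`` branch and ``fullsync`` are omitted.
+
+.. code:: python
+
+    from torchdata.datapipes.iter import IterableWrapper
+    from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
+
+    # Replicable branch: every worker holds a replica of this part of the
+    # graph and keeps only its own share of the samples at sharding_filter.
+    replicable_dp = IterableWrapper(range(100)).shuffle().sharding_filter()
+
+    # Non-replicable branch: this part of the graph stays in a single
+    # dispatching process, which hands samples to the workers round-robin.
+    non_replicable_dp = (
+        IterableWrapper(range(100))
+        .shuffle()
+        .sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
+    )
+
+    # Combine both branches into a single pipeline (the lowest common ancestor).
+    pipeline = replicable_dp.zip(non_replicable_dp)
+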
+.. graphviz::
+
+    digraph Example {
+        subgraph cluster_replicable {
+            label="Replicable"
+            a -> b -> c -> d -> l;
+            color=blue;
+        }
+
+        subgraph cluster_non_replicable {
+            style=filled;
+            color=lightgrey;
+            node [style=filled,color=white];
+            label="Non-Replicable"
+            e -> f -> g -> k;
+            h -> i -> j -> k;
+        }
+
+        k -> l -> fullsync -> end;
+
+        a [label="DP1"];
+        b [label="shuffle"];
+        c [label="sharding_filter", color=blue];
+        d [label="DP4"];
+        e [label="DP2"];
+        f [label="shuffle"];
+        g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
+        h [label="DP3"];
+        i [label="shuffle"];
+        j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
+        k [label="DP5 (Lowest common ancestor)"];
+        l [label="DP6"];
+        fullsync;
+        end [shape=box];
+    }
+
+When multiprocessing takes place, the graph becomes:
+
+.. graphviz::
+
+    digraph Example {
+        subgraph cluster_worker_0 {
+            label="Worker 0"
+            a0 -> b0 -> c0 -> d0 -> l0;
+            m0 -> l0;
+            color=blue;
+        }
+
+        subgraph cluster_worker_1 {
+            label="Worker 1"
+            a1 -> b1 -> c1 -> d1 -> l1;
+            m1 -> l1;
+            color=blue;
+        }
+
+        subgraph cluster_non_replicable {
+            style=filled;
+            color=lightgrey;
+            node [style=filled,color=white];
+            label="Non-Replicable"
+            e -> f -> g -> k;
+            h -> i -> j -> k;
+            k -> round_robin_demux;
+        }
+
+        round_robin_demux -> m0;
+        round_robin_demux -> m1;
+        l0 -> n;
+        l1 -> n;
+        n -> fullsync -> end;
+
+        a0 [label="DP1"];
+        b0 [label="shuffle"];
+        c0 [label="sharding_filter", color=blue];
+        d0 [label="DP4"];
+        a1 [label="DP1"];
+        b1 [label="shuffle"];
+        c1 [label="sharding_filter", color=blue];
+        d1 [label="DP4"];
+        e [label="DP2"];
+        f [label="shuffle"];
+        g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
+        h [label="DP3"];
+        i [label="shuffle"];
+        j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white];
+        k [label="DP5 (Lowest common ancestor)"];
+        fullsync;
+        l0 [label="DP6"];
+        l1 [label="DP6"];
+        m0 [label="Client"];
+        m1 [label="Client"];
+        n [label="Client"];
+        end [shape=box];
+    }
+
+``Client`` in the graph is a ``DataPipe`` that sends requests to and receives responses from multiprocessing queues.
 
 Graph Mode
 ^^^^^^^^^^
diff --git a/torchdata/datapipes/iter/__init__.pyi.in b/torchdata/datapipes/iter/__init__.pyi.in
index 9cf91b10d..855185324 100644
--- a/torchdata/datapipes/iter/__init__.pyi.in
+++ b/torchdata/datapipes/iter/__init__.pyi.in
@@ -12,6 +12,7 @@ from torchdata._constants import default_timeout_in_s
 from torchdata.datapipes.map import MapDataPipe
 from torch.utils.data import DataChunk, IterableDataset, default_collate
 from torch.utils.data.datapipes._typing import _DataPipeMeta
+from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
 from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, Hashable
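
For reference, a pipeline like the sketch above could be run as follows. This is a minimal sketch against the prototype-era API referenced in the docs (``DataLoader2`` and ``PrototypeMultiProcessingReadingService`` from ``torchdata.dataloader2``); ``pipeline`` is the hypothetical two-branch datapipe from the earlier sketch, and ``num_workers=2`` mirrors the two-worker diagram.

.. code:: python

    from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService

    # Two multiprocessing workers: each receives a replica of the replicable
    # branch, while a single dispatching process feeds the
    # sharding_round_robin_dispatch branch to the workers round-robin.
    rs = PrototypeMultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(pipeline, reading_service=rs)

    for sample in dl:
        ...  # consume (replicable_item, non_replicable_item) pairs

    dl.shutdown()  # tear down worker and dispatching processes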