Officially graduate PrototypeMPRS to MPRS (#1009)
Summary:
Pull Request resolved: #1009

Fixes: #932

- Convert all references from `PrototypeMPRS` to `MPRS`
- Remove usage of the old `MPRS` implementation, which the graduated class replaces (see the migration sketch below)
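
For downstream code, the graduation amounts to an import rename. A minimal migration sketch, assuming a torchdata build that includes this commit (the toy pipeline is illustrative, not part of this diff):

```python
# Before (pre-#1009): the prototype name.
# from torchdata.dataloader2 import PrototypeMultiProcessingReadingService

# After: the same reading service under its graduated name.
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10)).shuffle().sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(dp, reading_service=rs)
print(sorted(dl))  # all 10 items, recombined from both workers
```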

Reviewed By: wenleix, NivekT

Differential Revision: D43245136

fbshipit-source-id: 2195a7aae64b8ded5483f5b37383a4fa2d94f13e
ejguan authored and facebook-github-bot committed Feb 14, 2023
1 parent b43c07d commit 9707a13
Showing 10 changed files with 39 additions and 152 deletions.
4 changes: 2 additions & 2 deletions benchmarks/cloud/aws_s3.py
@@ -14,7 +14,7 @@

import pandas as pd
import psutil
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper


@@ -64,7 +64,7 @@ def check_and_output_speed(prefix: str, create_dp_fn: Callable, n_prefetch: int,
dp = create_dp_fn()

rs_type = "DataLoader2 w/ tar archives"
new_rs = PrototypeMultiProcessingReadingService(
new_rs = MultiProcessingReadingService(
num_workers=n_workers, worker_prefetch_cnt=n_prefetch, main_prefetch_cnt=n_prefetch
)
dl: DataLoader2 = DataLoader2(dp, reading_service=new_rs)
1 change: 0 additions & 1 deletion docs/source/dataloader2.rst
@@ -32,7 +32,6 @@ ReadingService

DistributedReadingService
MultiProcessingReadingService
PrototypeMultiProcessingReadingService
SequentialReadingService

Each ``ReadingService`` takes the ``DataPipe`` graph and rewrites it to enable features such as dynamic sharding, shared random seeds, and snapshotting for multi-/distributed processes. For more detail about these features, please refer to `the documentation <reading_service.html>`_.
6 changes: 3 additions & 3 deletions docs/source/dlv2_tutorial.rst
@@ -24,11 +24,11 @@ Here is an example of a ``DataPipe`` graph:
Multiprocessing
----------------

``PrototypeMultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.
``MultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes.

.. code:: python

    rs = PrototypeMultiProcessingReadingService(num_workers=4)
    rs = MultiProcessingReadingService(num_workers=4)
    dl = DataLoader2(datapipe, reading_service=rs)
    for epoch in range(10):
        dl.seed(epoch)
@@ -58,7 +58,7 @@ Multiprocessing + Distributed

.. code:: python

    mp_rs = PrototypeMultiProcessingReadingService(num_workers=4)
    mp_rs = MultiProcessingReadingService(num_workers=4)
    dist_rs = DistributedReadingService()
    rs = SequentialReadingService(dist_rs, mp_rs)
4 changes: 2 additions & 2 deletions docs/source/reading_service.rst
@@ -11,7 +11,7 @@ Features
Dynamic Sharding
^^^^^^^^^^^^^^^^

Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService``, which shard the pipeline based on the information of the corresponding multiprocessing and distributed workers. TorchData offers two types of ``DataPipe`` that let users define the sharding place within the pipeline.
Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService``, which shard the pipeline based on the information of the corresponding multiprocessing and distributed workers. TorchData offers two types of ``DataPipe`` that let users define the sharding place within the pipeline.

- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph and, at the ``sharding_filter``, skips the samples that do not belong to it (see the sketch below).
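
A minimal sketch of this placement, assuming the graduated ``MultiProcessingReadingService`` from this commit (the toy pipeline is illustrative, not part of the documentation):

.. code:: python

    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    # Everything before ``sharding_filter`` is replicated into each worker; the
    # filter then keeps only the elements that belong to the current worker.
    dp = IterableWrapper(range(8)).sharding_filter()
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    assert sorted(dl) == list(range(8))  # shards are disjoint and cover the data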

@@ -121,7 +121,7 @@ Determinism

In ``DataLoader2``, a ``SeedGenerator`` becomes the single source of randomness, and each ``ReadingService`` accesses it via ``initialize_iteration()`` to generate corresponding random seeds for random ``DataPipe`` operations.

In order to make sure that the dataset shards are mutually exclusive and collectively exhaustive across multiprocessing workers and distributed nodes, ``PrototypeMultiProcessingReadingService`` and ``DistributedReadingService`` help ``DataLoader2`` synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, each ``ReadingService`` generates unique random states based on the distributed rank and worker process id, in order to perform different random transformations.
In order to make sure that the dataset shards are mutually exclusive and collectively exhaustive across multiprocessing workers and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` help ``DataLoader2`` synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, each ``ReadingService`` generates unique random states based on the distributed rank and worker process id, in order to perform different random transformations.
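
This seeding contract can be read off a pipeline directly; a minimal sketch, in which the ``random_augment`` helper is hypothetical:

.. code:: python

    import random

    from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    def random_augment(x):
        # Hypothetical transform: it runs after ``sharding_filter``, so each
        # worker applies it with its own rank/worker-id derived random state.
        return x + random.random()

    # ``shuffle`` sits before ``sharding_filter``, so its seed is synchronized
    # and every worker sees the same global order, keeping the shards mutually
    # exclusive and collectively exhaustive.
    dp = IterableWrapper(range(100)).shuffle().sharding_filter().map(random_augment)
    dl = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
    dl.seed(0)  # fix the epoch seed to make the whole run reproducible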

Graph Mode
^^^^^^^^^^^
70 changes: 14 additions & 56 deletions test/dataloader2/test_dataloader2.py
@@ -27,7 +27,6 @@
DataLoader2,
DistributedReadingService,
MultiProcessingReadingService,
PrototypeMultiProcessingReadingService,
ReadingServiceInterface,
SequentialReadingService,
)
@@ -120,16 +119,6 @@ def test_dataloader2_reading_service(self) -> None:
self.assertEqual(batch, expected_batch)
expected_batch += 1

def test_dataloader2_multi_process_reading_service(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = MultiProcessingReadingService()
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)

expected_batch = 0
for batch in iter(data_loader):
self.assertEqual(batch, expected_batch)
expected_batch += 1

def test_dataloader2_load_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
reading_service = TestReadingService()
@@ -165,7 +154,7 @@ def test_dataloader2_iterates_correctly(self) -> None:
None,
TestReadingService(),
MultiProcessingReadingService(num_workers=4),
PrototypeMultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0),
MultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0),
]
for reading_service in reading_services:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service)
@@ -232,18 +221,10 @@ def _get_no_reading_service():
def _get_mp_reading_service():
return MultiProcessingReadingService(num_workers=2)

@staticmethod
def _get_proto_reading_service():
return PrototypeMultiProcessingReadingService(num_workers=2)

@staticmethod
def _get_mp_reading_service_zero_workers():
return MultiProcessingReadingService(num_workers=0)

@staticmethod
def _get_proto_reading_service_zero_workers():
return PrototypeMultiProcessingReadingService(num_workers=0)

def _collect_data(self, datapipe, reading_service_gen):
dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen())
result = []
@@ -265,9 +246,7 @@ def test_dataloader2_batch_collate(self) -> None:

reading_service_generators = (
self._get_mp_reading_service,
self._get_proto_reading_service,
self._get_mp_reading_service_zero_workers,
self._get_proto_reading_service_zero_workers,
)
for reading_service_gen in reading_service_generators:
actual = self._collect_data(dp, reading_service_gen=reading_service_gen)
@@ -279,27 +258,6 @@ def test_dataloader2_shuffle(self) -> None:
pass


class DataLoader2IntegrationTest(TestCase):
@staticmethod
def _get_mp_reading_service():
return MultiProcessingReadingService(num_workers=2)

def test_lazy_load(self):
source_dp = IterableWrapper([(i, i) for i in range(10)])
map_dp = source_dp.to_map_datapipe()

reading_service_generators = (self._get_mp_reading_service,)
for reading_service_gen in reading_service_generators:
dl: DataLoader2 = DataLoader2(datapipe=map_dp, reading_service=reading_service_gen())
# Lazy loading
dp = dl.datapipe
self.assertTrue(dp._map is None)
it = iter(dl)
self.assertEqual(list(it), list(range(10)))
# Lazy loading in multiprocessing
self.assertTrue(map_dp._map is None)


@unittest.skipIf(
TEST_WITH_TSAN,
"Fails with TSAN with the following error: starting new threads after multi-threaded "
@@ -382,7 +340,7 @@ def is_replicable(self):
return False


class PrototypeMultiProcessingReadingServiceTest(TestCase):
class MultiProcessingReadingServiceTest(TestCase):
@staticmethod
def _worker_init_fn(datapipe, worker_info):
datapipe = datapipe.sharding_filter()
@@ -403,7 +361,7 @@ def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator
def test_worker_fns(self, ctx):
dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle()

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=2,
multiprocessing_context=ctx,
worker_init_fn=self._worker_init_fn,
@@ -448,7 +406,7 @@ def _assert_deterministic_dl_res(dl, exp):
sf_dp = single_br_dp.sharding_filter()
replace_dp(graph, single_br_dp, sf_dp)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism and dynamic sharding
# _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -462,7 +420,7 @@ def _assert_deterministic_dl_res(dl, exp):
sf_dp = map_dp.sharding_filter()
replace_dp(graph, map_dp, sf_dp)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -476,7 +434,7 @@ def _assert_deterministic_dl_res(dl, exp):
round_robin_dispatcher = ShardingRoundRobinDispatcher(map_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, map_dp, round_robin_dispatcher)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 4 for i in range(10)])
@@ -518,7 +476,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
replace_dp(graph, branch1_dp, sf1_dp)
replace_dp(graph, branch2_dp, sf2_dp)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism and dynamic sharding
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
@@ -533,7 +491,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
sf_dp = branch2_dp.sharding_filter()
replace_dp(graph, branch2_dp, sf_dp)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
@@ -547,7 +505,7 @@ def _assert_deterministic_dl_res(dl, exp1, exp2):
non_replicable_dp2 = ShardingRoundRobinDispatcher(branch2_dp, SHARDING_PRIORITIES.MULTIPROCESSING)
replace_dp(graph, branch2_dp, non_replicable_dp2)
dl = DataLoader2(
end_dp, reading_service=PrototypeMultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx)
)
# Determinism for non-replicable pipeline
_assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10)))
Expand All @@ -558,7 +516,7 @@ def test_multi_worker_determinism(self, ctx):
dp = dp.shuffle().sharding_filter()
dp = dp.batch(2)

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=2,
multiprocessing_context=ctx,
)
@@ -589,7 +547,7 @@ def test_dispatching_worker_determinism(self, ctx):
dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.batch(2)

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=2,
multiprocessing_context=ctx,
)
@@ -625,7 +583,7 @@ def test_non_replicable_datapipe(self, ctx) -> None:
dp = dp.batch(2)
non_rep_dp = NonReplicableDataPipe(dp)

rs = PrototypeMultiProcessingReadingService(
rs = MultiProcessingReadingService(
num_workers=2,
multiprocessing_context=ctx,
)
@@ -775,7 +733,7 @@ def _make_dispatching_dp(data_length):

@staticmethod
def _make_rs(num_workers, ctx):
mp_rs = PrototypeMultiProcessingReadingService(
mp_rs = MultiProcessingReadingService(
num_workers=num_workers,
multiprocessing_context=ctx,
)
@@ -850,7 +808,7 @@ def test_sequential_reading_service_dispatching_dp(self, ctx):
self.assertNotEqual(result[1][rank][1], result[3][rank][1])


instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)
instantiate_parametrized_tests(MultiProcessingReadingServiceTest)
instantiate_parametrized_tests(SequentialReadingServiceTest)


4 changes: 2 additions & 2 deletions test/dataloader2/test_random.py
@@ -15,7 +15,7 @@
import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.dataloader2.graph.settings import set_graph_random_seed
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper
@@ -40,7 +40,7 @@ def test_proto_rs_determinism(self, num_workers):

data_source = IterableWrapper(exp)
dp = data_source.shuffle().sharding_filter().map(_random_fn)
rs = PrototypeMultiProcessingReadingService(num_workers=num_workers)
rs = MultiProcessingReadingService(num_workers=num_workers)
dl = DataLoader2(dp, reading_service=rs)

# No seed
2 changes: 1 addition & 1 deletion test/test_distributed.py
@@ -23,7 +23,7 @@
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torch.utils.data import DataLoader

from torchdata.dataloader2 import DataLoader2, DistributedReadingService, PrototypeMultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, DistributedReadingService
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.distributed import PrefetchTimeoutError

11 changes: 1 addition & 10 deletions test/test_graph.py
@@ -16,7 +16,7 @@
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
from torchdata.dataloader2 import DataLoader2, ReadingServiceInterface
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
from torchdata.dataloader2.graph.utils import _find_replicable_branches
from torchdata.dataloader2.random import SeedGenerator
@@ -254,15 +254,6 @@ def test_reading_service(self) -> None:

self.assertEqual(res, list(dl))

@unittest.skipIf(IS_WINDOWS, "Fork is required for lambda")
def test_multiprocessing_reading_service(self) -> None:
_, (*_, dp) = self._get_datapipes() # pyre-ignore
rs = MultiProcessingReadingService(2, persistent_workers=True, multiprocessing_context="fork")
dl = DataLoader2(dp, reading_service=rs)
d1 = list(dl)
d2 = list(dl)
self.assertEqual(d1, d2)


def insert_round_robin_sharding(graph, datapipe):
dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)