Implement InProcessReadingService (#1139)
Summary:
Fixes #1107
Fixes #720
Fixes #616

### Changes

- Implement `InProcessReadingService` (open to suggestions on the name)
  - Controls shuffle and sharding (sharding is effectively a no-op with a single process)
  - Adds support for `pause`/`resume`/`limit` (see the usage sketch after this list)
- ~Make `InProcessReadingService` the default `reading_service` for `DataLoader2`.~
  - ~Then `reading_service` would always have a value, and the `reading_service is None` handling could be removed.~
- Modify `MultiProcessingReadingService`
  - When `num_workers=0`, raise a warning
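
For context, a minimal usage sketch of the new reading service, pieced together from the tests added in this PR (the `range(10)` pipeline and the element counts are illustrative, not part of the commit):

```python
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, InProcessReadingService
from torchdata.datapipes.iter import IterableWrapper

# Single-process pipeline; with one "shard", sharding_filter passes every element through.
dp = IterableWrapper(range(10)).shuffle().sharding_filter()

# The tests also construct InProcessReadingService(5); that positional argument is
# presumably a prefetch count.
rs = InProcessReadingService()
dl = DataLoader2(dp, reading_service=rs)

# Pause/resume mid-epoch (the tests call the underscore-prefixed DataLoader2 methods).
res = []
for i, x in enumerate(dl):
    res.append(x)
    if i == 2:
        dl._pause()
        dl._resume()
assert sorted(res) == list(range(10))
dl.shutdown()

# Limit: cap how many elements a mini-epoch yields, then clear the limit and drain the rest.
dl = DataLoader2(dp, reading_service=InProcessReadingService())
it: DataLoader2Iterator = iter(dl)
it.limit(3)
first = list(it)   # 3 elements; afterwards the iterator stops yielding
it.limit(None)     # clear the limit
it.resume()
rest = list(it)    # the remaining 7 elements
assert sorted(first + rest) == list(range(10))
dl.shutdown()
```

With this change, `MultiProcessingReadingService(num_workers=0)` now raises a warning, and the in-process service is the intended replacement for the zero-worker path (hence the helper swap in `test_dataloader2.py` below).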

Pull Request resolved: #1139

Reviewed By: NivekT

Differential Revision: D45184167

Pulled By: ejguan

fbshipit-source-id: fc7821e8e695920a674145d17fbced86a95226c5
ejguan authored and facebook-github-bot committed May 31, 2023
1 parent bc0650a commit 8f9d123
Showing 6 changed files with 314 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/source/dataloader2.rst
@@ -31,6 +31,7 @@ ReadingService
    :template: class_method_template.rst

    DistributedReadingService
    InProcessReadingService
    MultiProcessingReadingService
    SequentialReadingService

7 changes: 4 additions & 3 deletions test/dataloader2/test_dataloader2.py
@@ -25,6 +25,7 @@
from torchdata.dataloader2 import (
    DataLoader2,
    DistributedReadingService,
    InProcessReadingService,
    MultiProcessingReadingService,
    ReadingServiceInterface,
    SequentialReadingService,
@@ -221,8 +222,8 @@ def _get_mp_reading_service():
        return MultiProcessingReadingService(num_workers=2)

    @staticmethod
    def _get_mp_reading_service_zero_workers():
        return MultiProcessingReadingService(num_workers=0)
    def _get_in_process_reading_service():
        return InProcessReadingService()

    def _collect_data(self, datapipe, reading_service_gen):
        dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen())
@@ -245,7 +246,7 @@ def test_dataloader2_batch_collate(self) -> None:

        reading_service_generators = (
            self._get_mp_reading_service,
            self._get_mp_reading_service_zero_workers,
            self._get_in_process_reading_service,
        )
        for reading_service_gen in reading_service_generators:
            actual = self._collect_data(dp, reading_service_gen=reading_service_gen)
185 changes: 165 additions & 20 deletions test/dataloader2/test_mprs.py
@@ -12,7 +12,12 @@

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.dataloader2 import (
    DataLoader2,
    DataLoader2Iterator,
    InProcessReadingService,
    MultiProcessingReadingService,
)
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


@@ -31,6 +36,164 @@ def _add_one(x: int) -> int:
dp_parametrize = parametrize("dp", test_dps)


class TestInProcessReadingService(TestCase):
    r"""
    This tests specific functionalities of InProcessReadingService, notably
    `pause`, `resume`, `snapshot`.
    """

    @dp_parametrize
    def test_reading_service_pause_resume(self, dp) -> None:

        # Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
        # properly pauses and resumes
        rs1 = InProcessReadingService()
        dl1: DataLoader2 = DataLoader2(dp, reading_service=rs1)
        res = []
        for i, x in enumerate(dl1):
            res.append(x)
            if i in {2, n_elements - 2}:
                dl1._pause()
                dl1._resume()

        self.assertEqual(list(range(n_elements)), sorted(res))
        dl1.shutdown()

        rs2 = InProcessReadingService(5)
        dl2: DataLoader2 = DataLoader2(dp, reading_service=rs2)
        res = []
        for i, x in enumerate(dl2):
            res.append(x)
            if i in {2, n_elements - 2}:
                dl2._pause()
                dl2._resume()

        self.assertEqual(list(range(n_elements)), sorted(res))
        dl2.shutdown()

    @dp_parametrize
    def test_reading_service_pause_stop_yield(self, dp) -> None:

        # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
        rs = InProcessReadingService(5)
        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            if i in {2}:
                dl._pause()
        self.assertEqual(3, len(res))
        dl.shutdown()

    @dp_parametrize
    def test_reading_service_limit(self, dp) -> None:

        rs = InProcessReadingService(5)

        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        cumulative_res = []
        n_limit = 3

        it: DataLoader2Iterator = iter(dl)
        it.limit(n_limit)
        for x in it:
            res.append(x)
        # Functional Test: Verify that the number of elements yielded equals to the specified limit
        self.assertEqual(n_limit, len(res))  # 3
        cumulative_res.extend(res)

        # Functional Test: Calling `next` after `limit` will trigger `StopIteration`
        with self.assertRaises(StopIteration):
            next(it)

        # Functional Test: Verify that `limit` persists without the need to set it again
        it.resume()
        res = []
        for x in it:
            res.append(x)
        self.assertEqual(n_limit, len(res))  # 3
        cumulative_res.extend(res)

        # Functional Test: Clear the `limit` and yield the rest of the elements
        it.limit(None)
        it.resume()
        res = []
        for x in it:
            res.append(x)
        self.assertEqual(n_elements - 2 * n_limit, len(res))  # 4

        cumulative_res.extend(res)
        self.assertEqual(list(range(n_elements)), sorted(cumulative_res))

        # Functional Test: Setting `limit` to a different value during after each mini-epoch
        dl2: DataLoader2 = DataLoader2(double_pause_dp, reading_service=rs)
        res = []
        it2: DataLoader2Iterator = iter(dl2)
        it2.limit(3)
        for x in it2:
            res.append(x)

        # Limit can be set before `resume`
        it2.limit(4)
        it2.resume()
        for x in it2:
            res.append(x)
        self.assertEqual(7, len(res))

        # Limit can also be set after `resume`, but before the next `for` loop
        it2.resume()
        it2.limit(2)
        for x in it2:
            res.append(x)
        self.assertEqual(9, len(res))

    def test_initial_epoch_checkpointing(self):
        dp = IterableWrapper(range(20)).shuffle()
        rs = InProcessReadingService(5)

        # Functional Test: Saving state before iterator is created
        dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        initial_state = dl.state_dict()
        it1 = iter(dl)

        restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()
        self.assertEqual(list(it1), list(restored_dl))

        dl.shutdown()
        restored_dl.shutdown()

        # Functional Test: Saving state after iterator is created
        dl = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        it1 = iter(dl)
        initial_state = dl.state_dict()

        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()
        self.assertEqual(list(it1), list(restored_dl))

        dl.shutdown()
        restored_dl.shutdown()

        # Functional Test: Saving state after iterator is created and began iterating
        dl = DataLoader2(datapipe=dp, reading_service=rs)
        dl.seed(1)
        it1 = iter(dl)
        temp = next(it1)  # Starts iterating
        initial_state = dl.state_dict()

        restored_dl = DataLoader2.from_state(initial_state, rs)  # type: ignore[arg-type]
        restored_dl._restore_checkpoint_beginning_of_epoch()

        self.assertEqual([temp] + list(it1), list(restored_dl))  # Note skipping over 1st element from actual result

        dl.shutdown()
        restored_dl.shutdown()


def _non_dispatching_dp(n_elements=1000):
    dp = IterableWrapper(list(range(n_elements))).shuffle()
    dp = dp.sharding_filter()
@@ -97,25 +260,6 @@ def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
        _ = list(dl)
        dl.shutdown()

    @mp_ctx_parametrize
    def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

        # Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
        # with `num_workers = 0`
        rs0 = MultiProcessingReadingService(
            num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
        )
        dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
        res0 = []
        for i, x in enumerate(dl0):
            res0.append(x)
            if i in {2}:
                with self.assertRaisesRegex(RuntimeError, r"pause"):
                    dl0._pause()
                with self.assertRaisesRegex(RuntimeError, r"resume"):
                    dl0._resume()
        dl0.shutdown()

    @mp_ctx_parametrize
    @dp_parametrize
    @parametrize(
@@ -394,6 +538,7 @@ def test_initial_epoch_checkpointing(self):
# pass


instantiate_parametrized_tests(TestInProcessReadingService)
instantiate_parametrized_tests(TestMultiProcessingReadingService)


54 changes: 50 additions & 4 deletions test/dataloader2/test_random.py
@@ -15,7 +15,7 @@
import torch

from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.dataloader2 import DataLoader2, InProcessReadingService, MultiProcessingReadingService
from torchdata.dataloader2.graph.settings import set_graph_random_seed
from torchdata.dataloader2.random import SeedGenerator
from torchdata.datapipes.iter import IterableWrapper
@@ -26,15 +26,15 @@ def _random_fn(data):
    Used to validate the randomness of subprocess-local RNGs are set deterministically.
    """
    py_random_num = random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32)
    np_random_num = np.random.randint(0, 2 ** 32, dtype=np.uint32)
    torch_random_num = torch.randint(0, 2 ** 32, size=[]).item()
    return (data, py_random_num, np_random_num, torch_random_num)


class DeterminismTest(TestCase):
    @unittest.skipIf(IS_WINDOWS, "Remove when https://github.com/pytorch/data/issues/857 is fixed")
    @parametrize("num_workers", [0, 8])
    def test_proto_rs_determinism(self, num_workers):
    @parametrize("num_workers", [1, 8])
    def test_mprs_determinism(self, num_workers):
        data_length = 64
        exp = list(range(data_length))

@@ -110,6 +110,52 @@ def _get_dp_seeds_after_setting(worker_id, seed=123):
        self.assertNotEqual(ss_0_123, ss_0_321)
        self.assertNotEqual(ds_0_123, ds_0_321)

    def test_sprs_determinism(self):
        data_length = 64
        exp = list(range(data_length))

        data_source = IterableWrapper(exp)
        dp = data_source.shuffle().sharding_filter().map(_random_fn)
        rs = InProcessReadingService()
        dl = DataLoader2(dp, reading_service=rs)

        # No seed
        res = []
        for d, *_ in dl:
            res.append(d)
        self.assertEqual(sorted(res), exp)

        # Shuffle with seed
        results = []
        for _ in range(2):
            res = []
            ran_res = []
            torch.manual_seed(123)
            random.seed(123)
            np.random.seed(123)
            for d, *ran_nums in dl:
                res.append(d)
                ran_res.append(ran_nums)
            self.assertEqual(sorted(res), exp)
            results.append((res, ran_res))
        # Same seed generate the same order of data and the same random state
        self.assertEqual(results[0], results[1])

        # Different seed
        res = []
        ran_res = []
        torch.manual_seed(321)
        random.seed(321)
        np.random.seed(321)
        for d, *ran_nums in dl:
            res.append(d)
            ran_res.append(ran_nums)
        self.assertEqual(sorted(res), exp)
        # Different shuffle order
        self.assertNotEqual(results[0][0], res)
        # Different subprocess-local random state
        self.assertNotEqual(results[0][1], ran_res)


instantiate_parametrized_tests(DeterminismTest)

2 changes: 2 additions & 0 deletions torchdata/dataloader2/__init__.py
@@ -10,6 +10,7 @@
from torchdata.dataloader2.reading_service import (
    CheckpointableReadingServiceInterface,
    DistributedReadingService,
    InProcessReadingService,
    MultiProcessingReadingService,
    PrototypeMultiProcessingReadingService,
    ReadingServiceInterface,
@@ -22,6 +23,7 @@
    "DataLoader2",
    "DataLoader2Iterator",
    "DistributedReadingService",
    "InProcessReadingService",
    "MultiProcessingReadingService",
    "PauseIteration",
    "PrototypeMultiProcessingReadingService",