diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py
index c870833c0..3e9bc5c70 100644
--- a/test/dataloader2/test_dataloader2.py
+++ b/test/dataloader2/test_dataloader2.py
@@ -252,6 +252,7 @@ def _collect_data(self, datapipe, reading_service_gen):
             result.append(row)
         for row in dl:
             result.append(row)
+        dl.shutdown()
         return result

     @staticmethod
diff --git a/test/dataloader2/test_proto_multi_rs.py b/test/dataloader2/test_proto_multi_rs.py
new file mode 100644
index 000000000..52c5c239e
--- /dev/null
+++ b/test/dataloader2/test_proto_multi_rs.py
@@ -0,0 +1,291 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import multiprocessing as mp
+import unittest
+from unittest import TestCase
+
+from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
+from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
+from torchdata.datapipes.iter import IterableWrapper
+
+
+def _add_one(x: int) -> int:
+    return x + 1
+
+
+# Test DataPipes
+n_elements = 10
+dp1 = IterableWrapper(range(n_elements)).shuffle().sharding_filter()
+double_pause_dp = dp1.prefetch().prefetch()
+test_dps = [dp1, double_pause_dp]
+
+
+mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods())
+dp_parametrize = parametrize("dp", test_dps)
+
+
+class TestPrototypeMultiProcessingReadingService(TestCase):
+    r"""
+    This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
+    `pause`, `resume`, `snapshot`.
+    """
+
+    @mp_ctx_parametrize
+    def test_reading_service_pause_resume_0_worker(self, ctx) -> None:
+
+        # Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
+        #                  with `num_workers = 0`
+        rs0 = PrototypeMultiProcessingReadingService(
+            num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
+        )
+        dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
+        res0 = []
+        for i, x in enumerate(dl0):
+            res0.append(x)
+            if i in {2}:
+                with self.assertRaisesRegex(RuntimeError, r"pause"):
+                    dl0._pause()
+                with self.assertRaisesRegex(RuntimeError, r"resume"):
+                    dl0._resume()
+        dl0.shutdown()
+
+    @mp_ctx_parametrize
+    @dp_parametrize
+    @parametrize(
+        "n_workers,worker_prefetch_cnt,main_prefetch_cnt",
+        [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 0), (2, 0, 2), (2, 2, 2)],
+    )
+    def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
+
+        # Functional Test: Testing various configurations of DataPipe/ReadingService to ensure the pipeline
+        #                  properly pauses and resumes
+        rs = PrototypeMultiProcessingReadingService(
+            num_workers=n_workers,
+            worker_prefetch_cnt=worker_prefetch_cnt,
+            main_prefetch_cnt=main_prefetch_cnt,
+            multiprocessing_context=ctx,
+        )
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
+        res = []
+        for i, x in enumerate(dl):
+            res.append(x)
+            if i in {2, n_elements - 2}:
+                dl._pause()
+                dl._resume()
+
+        self.assertEqual(
+            list(range(n_elements)),
+            sorted(res),
+            msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
+            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, "
+            f"main_prefetch_cnt = {rs.main_prefetch_cnt}",
+        )
+        dl.shutdown()
+
+    @mp_ctx_parametrize
+    @dp_parametrize
+    @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(2, 0, 1), (2, 1, 0), (2, 0, 0)])
+    def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
+
+        # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
+        rs = PrototypeMultiProcessingReadingService(
+            num_workers=n_workers,
+            worker_prefetch_cnt=worker_prefetch_cnt,
+            main_prefetch_cnt=main_prefetch_cnt,
+            multiprocessing_context=ctx,
+        )
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
+        res = []
+        for i, x in enumerate(dl):
+            res.append(x)
+            if i in {2}:
+                dl._pause()
+        self.assertEqual(
+            3,
+            len(res),
+            msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
+            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+        )
+        dl.shutdown()
+
+    @dp_parametrize
+    @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
+    def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
+
+        rs = PrototypeMultiProcessingReadingService(
+            num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
+        )
+
+        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
+        res = []
+        cumulative_res = []
+        n_limit = 3
+
+        it: DataLoader2Iterator = iter(dl)
+        it.limit(n_limit)
+        for x in it:
+            res.append(x)
+        # Functional Test: Verify that the number of elements yielded equals the specified limit
+        self.assertEqual(
+            n_limit,
+            len(res),  # 3
+            msg=f"The test is failing with default multiprocessing method, "
+            f"num_workers = {rs.num_workers}, "
+            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+        )
+        cumulative_res.extend(res)
+
+        # Functional Test: Calling `next` after `limit` will trigger `StopIteration`
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        # Functional Test: Verify that `limit` persists without the need to set it again
+        it.resume()
+        res = []
+        for x in it:
+            res.append(x)
+        self.assertEqual(
+            n_limit,
+            len(res),  # 3
+            msg=f"The test is failing with default multiprocessing method, "
+            f"num_workers = {rs.num_workers}, "
+            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+        )
+        cumulative_res.extend(res)
+
+        # Functional Test: Clear the `limit` and yield the rest of the elements
+        it.limit(None)
+        it.resume()
+        res = []
+        for x in it:
+            res.append(x)
+        self.assertEqual(
+            n_elements - 2 * n_limit,
+            len(res),  # 4
+            msg=f"The test is failing with default multiprocessing method, "
+            f"num_workers = {rs.num_workers}, "
+            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+        )
+
+        cumulative_res.extend(res)
+        self.assertEqual(list(range(n_elements)), sorted(cumulative_res))
+
+        # Functional Test: Setting `limit` to a different value after each mini-epoch
+        dl2: DataLoader2 = DataLoader2(double_pause_dp, reading_service=rs)
+        res = []
+        it2: DataLoader2Iterator = iter(dl2)
+        it2.limit(3)
+        for x in it2:
+            res.append(x)
+
+        # Limit can be set before `resume`
+        it2.limit(4)
+        it2.resume()
+        for x in it2:
+            res.append(x)
+        self.assertEqual(7, len(res))
+
+        # Limit can also be set after `resume`, but before the next `for` loop
+        it2.resume()
+        it2.limit(2)
+        for x in it2:
+            res.append(x)
+        self.assertEqual(9, len(res))
+
+    # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding
+    #       Currently, using sharding_round_robin raises a warning
+    # def test_round_robin_dispatching_pause_limit(self):
+    #     source_dp = IterableWrapper(range(20))
+    #     dp = source_dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
+    #     dp = dp.map(_add_one)
+
+    # TODO: This doesn't work with `num_workers > 1`
+    #       TODO: Try checking if `dp_list`'s elements are _IterateQueueDP or QueueWrapper, we can safely assume
+    #             those DPs belong to a dispatching process and only do pause if worker_id == 0
+    #             There might still be a race condition, need to look into the messages
+
+    #     rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
+    #     rs2 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
+    #     rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
+    #     rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
+    #     rss = [rs1, rs2, rs3, rs4]
+
+    #     for n, rs in enumerate(rss):
+    #         dl = DataLoader2(dp, reading_service=rs)
+    #         res = []
+    #         # cumulative_res = []
+    #         n_limit = 3
+    #
+    #         it: DataLoader2Iterator = iter(dl)
+    #         it.limit(n_limit)  # The `pause` call here doesn't stop
+    #         for x in it:
+    #             res.append(x)
+    #
+    #         print()
+    #         print(res)
+    #
+    #         dl.shutdown()
+
+    #         # Functional Test: Verify that the number of elements yielded equals to the specified limit
+    #         # self.assertEqual(
+    #         #     n_limit,
+    #         #     len(res),  # 3
+    #         #     msg=f"The test is failing for rs{n + 1} with default multiprocessing method, "
+    #         #     f"num_workers = {rs.num_workers}, "
+    #         #     f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+    #         # )
+    #         cumulative_res.extend(res)
+    #
+    #         # Functional Test: Calling `next` after `limit` will trigger `StopIteration`
+    #         with self.assertRaisesRegex(StopIteration, "pause"):
+    #             next(it)
+    #
+    #         # Functional Test: Verify that `limit` persists without the need to set it again
+    #         it.resume()
+    #         res = []
+    #         for x in it:
+    #             res.append(x)
+    #         # self.assertEqual(
+    #         #     n_limit,
+    #         #     len(res),  # 3
+    #         #     msg=f"The test is failing for rs{n + 1} with default multiprocessing method, "
+    #         #     f"num_workers = {rs.num_workers}, "
+    #         #     f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+    #         # )
+    #         cumulative_res.extend(res)
+    #
+    #         # Functional Test: Clear the `limit` and yield the rest of the elements
+    #         it.limit(None)
+    #         it.resume()
+    #         res = []
+    #         for x in it:
+    #             res.append(x)
+    #         # self.assertEqual(
+    #         #     n_elements - 2 * n_limit,
+    #         #     len(res),  # 4
+    #         #     msg=f"The test is failing for rs{n + 1} with default multiprocessing method, "
+    #         #     f"num_workers = {rs.num_workers}, "
+    #         #     f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
+    #         # )
+    #
+    #         cumulative_res.extend(res)
+    #         self.assertEqual(list(range(n_elements)), sorted(cumulative_res))
+
+    # TODO: Implemented in an upcoming PR
+    # def test_reading_service_snapshot(self) -> None:
+    #     pass
+    #
+    # def test_dataloader2_snapshot(self) -> None:
+    #     pass
+
+
+instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py
index aaf123739..9748801b8 100644
--- a/torchdata/dataloader2/communication/iter.py
+++ b/torchdata/dataloader2/communication/iter.py
@@ -6,13 +6,15 @@

 import time
 import types
+import warnings
+from collections import deque
 from functools import partial
-from typing import Callable
+from typing import Callable, Deque, List

 from torch.utils.data import IterDataPipe
 from torchdata.dataloader2 import communication
-from torchdata.dataloader2.graph import DataPipe
+from torchdata.dataloader2.graph import DataPipe, list_dps, traverse_dps
 from torchdata.dataloader2.random import SeedGenerator
 from torchdata.dataloader2.utils import WorkerInfo

@@ -160,12 +162,43 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
             source_datapipe.reset_iterator()
             protocol.response_reset_iterator()

+        elif isinstance(request, communication.messages.PauseRequest):
+            dp_list = list_dps(traverse_dps(source_datapipe))
+            for dp in dp_list:
+                # TODO: Remove this condition after there is `pause` support for round-robin sharding
+                if isinstance(dp, QueueWrapper):
+                    warnings.warn("There is no support for `pause` with round-robin sharding at the moment.")
+                elif hasattr(dp, "pause") and callable(dp.pause):
+                    dp.pause()
+
+            protocol.response_pause()
+            yield True  # Return control
+
+        elif isinstance(request, communication.messages.ResumeRequest):
+            dp_list = list_dps(traverse_dps(source_datapipe))
+            for dp in reversed(dp_list):
+                # TODO: Remove this condition after there is `resume` support for round-robin sharding
+                if isinstance(dp, QueueWrapper):
+                    raise RuntimeError("There is no support for `resume` with round-robin sharding at the moment.")
+                elif hasattr(dp, "resume") and callable(dp.resume):
+                    dp.resume()
+            protocol.response_resume()
+            yield True  # Return control
+
         elif isinstance(request, communication.messages.TerminateRequest):
             forever = False
             protocol.response_terminate()

         elif isinstance(request, communication.messages.GetNextRequest):
             while forever:
+                if protocol.is_paused():
+                    protocol.response_stop_iteration()
+                    warnings.warn(
+                        "Cannot `GetNext` after `Pause` has been called. "
+                        "`Resume` must be called first before additional elements can be yielded."
+                    )
+                    yield True
+                    break
                 try:
                     value = source_datapipe.nonblocking_next()
                 except NotAvailable:
@@ -212,6 +245,26 @@ def reset_iterator(self):
                 if NonBlocking.not_available_hook is not None:
                     NonBlocking.not_available_hook()

+    def pause(self):
+        self.protocol.request_pause()
+        while True:
+            try:
+                self.protocol.get_response_pause()
+                break
+            except communication.protocol.EmptyQueue:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def resume(self):
+        self.protocol.request_resume()
+        while True:
+            try:
+                self.protocol.get_response_resume()
+                break
+            except communication.protocol.EmptyQueue:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
     def nonblocking_next(self):
         if self._stop_iteration:
             raise Exception("`next` or `nonblocking_next` called after receiving StopIteration")
@@ -238,11 +291,12 @@ class _IterateQueueDataPipes(IterDataPipe):

     def __init__(self, datapipes):
         # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
-        #                       into one class, which supports any number of queues.
-        self.datapipes = datapipes
-        for dp in self.datapipes:
-            if not isinstance(dp, QueueWrapper):
+        #                       into one class, which supports any number of queues.
+        for dp in datapipes:
+            if not isinstance(dp, communication.iter.QueueWrapper):
                 raise Exception("Source datapipes should be an instance of iter.QueueWrapper")
+        self.datapipes = datapipes
+        self.res_buffers: List[Deque] = [deque() for _ in range(len(datapipes))]

     def __iter__(self):
         total_pipes = len(self.datapipes)
@@ -255,7 +309,11 @@ def __iter__(self):
         while cnt_disabled_pipes < total_pipes:
             for idx in range(total_pipes):
                 if not disabled_pipe[idx]:
-                    response = self.datapipes[idx].protocol.get_response_next(block=True)
+                    # Check if buffer of the DataPipe is empty, if not, yield one before requesting next
+                    if len(self.res_buffers[idx]):
+                        response = self.res_buffers[idx].popleft()
+                    else:
+                        response = self.datapipes[idx].protocol.get_response_next(block=True)
                     if isinstance(response, communication.messages.StopIterationResponse):
                         disabled_pipe[idx] = True
                         cnt_disabled_pipes += 1
@@ -264,7 +322,8 @@ def __iter__(self):
                         raise communication.iter.InvalidStateResetRequired
                     if isinstance(response, communication.messages.TerminateResponse):
                         raise communication.iter.TerminateRequired
-                    self.datapipes[idx].protocol.request_next()
+                    if len(self.res_buffers[idx]) == 0:  # Only request if buffer is empty
+                        self.datapipes[idx].protocol.request_next()
                     yield response.value

     def reset(self):
@@ -292,3 +351,16 @@ def reset_epoch(
             except communication.protocol.EmptyQueue:
                 if NonBlocking.not_available_hook is not None:
                     NonBlocking.not_available_hook()
+
+    def request_pause(self):
+        # Store results of pending requests
+        for idx, dp in enumerate(self.datapipes):
+            if dp.protocol.waiting_for_response():
+                res = dp.protocol.get_response_next(block=True)
+                self.res_buffers[idx].append(res)
+        for dp in self.datapipes:
+            dp.pause()
+
+    def request_resume(self):
+        for dp in self.datapipes:
+            dp.resume()
diff --git a/torchdata/dataloader2/communication/messages.py b/torchdata/dataloader2/communication/messages.py
index 1e5c388af..f436e2778 100644
--- a/torchdata/dataloader2/communication/messages.py
+++ b/torchdata/dataloader2/communication/messages.py
@@ -36,6 +36,22 @@ class ResetEpochResponse(Response):
     pass


+class PauseRequest(Request):
+    pass
+
+
+class PauseResponse(Response):
+    pass
+
+
+class ResumeRequest(Request):
+    pass
+
+
+class ResumeResponse(Response):
+    pass
+
+
 class TerminateRequest(Request):
     pass

diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py
index 245ff0f19..38764b774 100644
--- a/torchdata/dataloader2/communication/protocol.py
+++ b/torchdata/dataloader2/communication/protocol.py
@@ -42,7 +42,7 @@ def request_sent(self, request=True):

     def request_served(self, result=None):
         if not self.waiting_for_response():
-            raise Exception("Expected no peding requests, but something got served", result)
+            raise Exception("Expected no pending requests, but something got served", result)
         self._req_sent = None

     def discard_existing_request(self):
@@ -50,18 +50,39 @@ def discard_existing_request(self):
         response = self.response_queue.get(block=True)
         self.request_served(response)

+    def request_pause(self):
+        if not self.can_take_request():
+            raise Exception("Can not `pause` while we are still waiting response for previous request")
+        request = communication.messages.PauseRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+
+    def request_resume(self):
+        if not self.can_take_request():
+            raise Exception("Can not `resume` while we are still waiting response for previous request")
+        request = communication.messages.ResumeRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+

 class ProtocolServer(Protocol):
     """
     ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
     """

+    # TODO(966): Update the exceptions raised in this class to be more specific
+
     _req_received = None
+    _paused = False  # When `True`, prevents `GetNext` in `DataPipeBehindQueues`.

     def __init__(self, request_queue, response_queue):
         self.request_queue = request_queue
         self.response_queue = response_queue
         self._req_received = None
+        self._paused = False
+
+    def is_paused(self):
+        return self._paused

     def have_pending_request(self):
         return self._req_received is not None
@@ -81,7 +102,7 @@ def response_terminate(self):
         if not self.have_pending_request():
             raise Exception("Attempting to reply with pending request")
         if not isinstance(self._req_received, communication.messages.TerminateRequest):
-            raise Exception("Replaying with terminate status to other type of message")
+            raise Exception("Replying with `terminate` status to other type of message")
         self.response_queue.put(communication.messages.TerminateResponse())
         self._req_received = None

@@ -89,10 +110,28 @@ def response_reset_epoch(self):
         if not self.have_pending_request():
             raise Exception("Attempting to reply with pending request")
         if not isinstance(self._req_received, communication.messages.ResetEpochRequest):
-            raise Exception("Replaying with reset epoch status to other type of message")
+            raise Exception("Replying with `reset_epoch` status to other type of message")
         self.response_queue.put(communication.messages.ResetEpochResponse())
         self._req_received = None

+    def response_pause(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        if not isinstance(self._req_received, communication.messages.PauseRequest):
+            raise Exception("Replying with `pause` status to other type of message")
+        self._paused = True
+        self.response_queue.put(communication.messages.PauseResponse())
+        self._req_received = None
+
+    def response_resume(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        if not isinstance(self._req_received, communication.messages.ResumeRequest):
+            raise Exception("Replying with `resume` status to other type of message")
+        self._paused = False
+        self.response_queue.put(communication.messages.ResumeResponse())
+        self._req_received = None
+

 class MapDataPipeQueueProtocolServer(ProtocolServer):
     def response_item(self, key, value):
@@ -235,6 +274,26 @@ def get_response_reset_epoch(self, block=False):
         if not isinstance(response, communication.messages.ResetEpochResponse):
             raise Exception("Invalid response received")

+    def get_response_pause(self, block=False):
+        try:
+            response = self.response_queue.get(block=block)
+        except EmptyException:
+            raise EmptyQueue("queue is empty")
+        self.request_served(response)
+
+        if not isinstance(response, communication.messages.PauseResponse):
+            raise Exception("Invalid response received when expecting `PauseResponse`")
+
+    def get_response_resume(self, block=False):
+        try:
+            response = self.response_queue.get(block=block)
+        except EmptyException:
+            raise EmptyQueue("queue is empty")
+        self.request_served(response)
+
+        if not isinstance(response, communication.messages.ResumeResponse):
+            raise Exception("Invalid response received when expecting `ResumeResponse`")
+
     def get_response_next(self, block=False, timeout=None):
         if not self.waiting_for_response():
             raise Exception("Can not expect any response without submitted request")
diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index 3072c5aa3..77227ce15 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.


+import warnings
+
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

@@ -46,13 +48,21 @@ class DataLoader2Iterator(Iterator[T_co]):
     def __init__(self, dataloader: "DataLoader2", iterator_id: int):
         self.dataloader = dataloader
         self.iterator_id = iterator_id
+        self.limit_counter: Optional[int] = None
+        self.limit_threshold: Optional[int] = None

     def __next__(self) -> T_co:
         if self.iterator_id == self.dataloader.valid_iterator_id:
             self.dataloader._reset_iter = True
             try:
-                return next(self.dataloader._datapipe_iter)  # type: ignore[arg-type]
-            except PauseIteration:
+                if self.dataloader._is_paused:
+                    raise PauseIteration("DataLoader2 has been paused. `resume` must be called before continuing.")
+                else:
+                    next_val = next(self.dataloader._datapipe_iter)  # type: ignore[arg-type]
+                    if self.limit_threshold is not None:
+                        self.limit_counter = self.limit_counter + 1  # type: ignore[operator]
+                    return next_val
+            except PauseIteration:  # This can be used for raising `StopIteration` without `finalize_iteration`
                 raise StopIteration
             except StopIteration:
                 if self.dataloader.reading_service is not None:
@@ -62,7 +72,15 @@ def __next__(self) -> T_co:
                 if self.dataloader:
                     self.dataloader.shutdown()
                 raise
-        else:
+            finally:
+                # Call `pause` if threshold is reached
+                if (
+                    not self.dataloader._is_paused
+                    and self.limit_threshold is not None
+                    and self.limit_counter >= self.limit_threshold  # type: ignore[operator]
+                ):
+                    self._pause()
+        else:  # `iterator_id` is not valid
             if self.dataloader.reading_service is not None:
                 self.dataloader.reading_service.finalize_iteration()
             raise RuntimeError(
@@ -73,6 +91,44 @@ def __next__(self) -> T_co:
                 "to comment on this issue: https://github.com/pytorch/data/issues/45."
             )

+    def _pause(self) -> None:
+        r"""
+        Pauses ``DataLoader2`` by halting its threads and ensuring that its state remains unchanged,
+        allowing ``DataLoader2`` to safely perform snapshotting and similar operations afterwards.
+
+        The ``limit_counter`` is also reset to ``0``.
+        """
+        self.dataloader._pause()
+        self.limit_counter = 0
+
+    def resume(self) -> None:
+        r"""
+        Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
+        """
+        self.dataloader._resume()
+        if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
+            self.dataloader._datapipe_iter.resume()  # type: ignore[attr-defined]
+
+    def limit(self, num_batches: Optional[int]) -> None:
+        """
+        Pauses ``DataLoader2`` from yielding additional batches after ``num_batches`` has been yielded. The count
+        begins after this method is invoked (i.e. previously yielded batches do not count towards the threshold).
+
+        While paused, ``DataLoader2``'s threads are halted and its state remains unchanged,
+        allowing ``DataLoader2`` to safely perform snapshotting and similar operations.
+        After ``DataLoader2`` is paused, ``resume()`` must be called before it can start yielding again.
+
+        Note:
+            ``limit_threshold`` persists after ``pause`` and ``resume``. Use ``.limit(None)`` to remove it.
+
+        Args:
+            num_batches: Number of batches after which the DataLoader2 will pause, use ``None`` to remove the limit
+        """
+        self.limit_counter = 0
+        self.limit_threshold = num_batches
+        if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
+            self.dataloader._datapipe_iter.limit()  # type: ignore[attr-defined]
+
     def __getattr__(self, name):
         """
         To delegate operations to ``dataloader._datapipe_iter``.
@@ -114,7 +170,7 @@ def __init__(
         self.datapipe = clone(wrap_datapipe_for_serialization(datapipe)) if datapipe is not None else None
         self._adapted: bool = False
         self._datapipe_iter: Optional[Iterator[T_co]] = None
-        self._reset_iter: bool = True  # Sets to `False` when __iter__ starts, and `True` when `StopIteration``
+        self._reset_iter: bool = True  # Sets to `False` when `__iter__` runs, and `True` when `__next__` is called
         # TODO(630): Some ReadingServices might want to validate adapters, we can add this feature
         if datapipe_adapter_fn is None:
             self.datapipe_adapter_fns = None
@@ -126,6 +182,7 @@ def __init__(
         self.reading_service_state: Optional[bytes] = None  # is not `None` when `load_state_dict` is called
         self._terminated: bool = False
         self.valid_iterator_id: Optional[int] = None
+        self._is_paused = False

         if self.datapipe is not None and self.datapipe_adapter_fns is not None:
             for adapter_fn in self.datapipe_adapter_fns:
@@ -283,3 +340,24 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             for adapter_fn in self.datapipe_adapter_fns:
                 self.datapipe = adapter_fn(self.datapipe)
         self._datapipe_before_reading_service_adapt = clone(self.datapipe)
+
+    def _pause(self):
+        if hasattr(self.reading_service, "_pause"):
+            self._is_paused = True
+            self.reading_service._pause()
+        # TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
+        elif self._datapipe_iter is None or not (
+            hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
+        ):
+            warnings.warn("ReadingService doesn't support pause.")
+
+    def _resume(self):
+        if hasattr(self.reading_service, "_resume"):
+            if not self._is_paused:
+                warnings.warn("Resume is called when `DataLoader2` is not paused. No operation is performed.")
+            else:
+                self.reading_service._resume()
+                self._is_paused = False
+        # TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
+        elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
+            warnings.warn("ReadingService doesn't support resume.")
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 8a93ddeb5..92b67c476 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -73,7 +73,7 @@ def initialize_iteration(

         Args:
             seed_generator: SeedGenerator object created and managed by DataLoader2. As the single
-                source of randomness, it will governs the determinism for all of random operations
+                source of randomness, it will govern the determinism for all of random operations
                 with the graph of DataPipes.
             iter_reset_fn: Optional reset function from the prior ``ReadingServcie``
                 when ``SequentialReadingService`` chains multiple ``ReadingServices``
@@ -371,11 +371,41 @@ def clean_me(process, req_queue, res_queue):
         self._worker_processes = []
         self._dispatch_process = None

+    def _pause(self):
+        """
+        Pauses DataPipes' activities such as prefetching, in order to collect state.
+        """
+        if self.main_prefetch_cnt > 0 and self.num_workers > 0:
+            # Stop prefetching of main loop first
+            self._main_prefetch_datapipe.pause()  # type: ignore[union-attr]
+        if self.num_workers > 0:
+            self._worker_consumer_datapipe.request_pause()  # type: ignore[union-attr]
+        else:
+            raise RuntimeError(
+                "If you would like to use `pause` with `PrototypeMultiProcessingReadingService`, "
+                "please use more than 0 workers."
+            )
+
+    def _resume(self):
+        """
+        Resumes DataPipes' activities. This is required to be called after `_pause` before
+        the DataLoader can keep yielding elements.
+        """
+        if self.num_workers > 0:
+            self._worker_consumer_datapipe.request_resume()  # type: ignore[union-attr]
+        else:
+            raise RuntimeError(
+                "If you would like to use `resume` with `PrototypeMultiProcessingReadingService`, "
+                "please use more than 0 workers."
+            )
+        if self.main_prefetch_cnt > 0 and self.num_workers > 0:
+            self._main_prefetch_datapipe.resume()  # type: ignore[union-attr]
+

 class MultiProcessingReadingService(ReadingServiceInterface):
     r"""
     ``MultiProcessingReadingService`` that utilizes ``torch.utils.data.DataLoader`` to
-    launch subprocesses for ``DataPipe`` graph. Please refers to documents of ``DataLoader``
+    launch subprocesses for ``DataPipe`` graph. Please refer to documents of ``DataLoader``
     in https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader for all arguments.

     Note:
@@ -455,7 +485,7 @@ def __init__(self, timeout: int = default_timeout_in_s):
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         r"""
         Launches the ``gloo``-backend distributed process group. Carries out distributed sharding
-        on the graph of ``DataPipe`` and returnes the graph attached with a ``FullSyncIterDataPipe``
+        on the graph of ``DataPipe`` and returns the graph attached with a ``FullSyncIterDataPipe``
         at the end.
         """
         if not (dist.is_available() and dist.is_initialized()):
diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py
index 46fb53165..2e6fe813c 100644
--- a/torchdata/datapipes/iter/util/distributed.py
+++ b/torchdata/datapipes/iter/util/distributed.py
@@ -229,3 +229,13 @@ def __setstate__(self, state):
         self._error = None
         self._sync_counter = torch.tensor([0], dtype=torch.int32)
         self._done_callback = False
+
+    def pause(self):
+        raise RuntimeError("`pause` is not supported for FullSync at the moment.")
+        # if self._executor is not None:
+        #     self._executor.shutdown()
+        #     self._executor = None
+
+    def resume(self):
+        raise RuntimeError("`resume` is not supported for FullSync at the moment.")
+        # self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py
index 8cd66f559..1dc18c2bb 100644
--- a/torchdata/datapipes/iter/util/prefetcher.py
+++ b/torchdata/datapipes/iter/util/prefetcher.py
@@ -13,8 +13,8 @@
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe

-PRODUCER_SLEEP_INTERVAL = 0.0001  # Interval between buffer fullfilment checks
-CONSUMER_SLEEP_INTERVAL = 0.0001  # Interval between checking items availablitity in buffer
+PRODUCER_SLEEP_INTERVAL = 0.0001  # Interval between buffer fulfillment checks
+CONSUMER_SLEEP_INTERVAL = 0.0001  # Interval between checking items availability in buffer


 class _PrefetchData:
@@ -23,6 +23,7 @@ def __init__(self, source_datapipe, buffer_size: int):
         self.prefetch_buffer: Deque = deque()
         self.buffer_size: int = buffer_size
         self.source_datapipe = source_datapipe
+        self.stop_iteration = False


 @functional_datapipe("prefetch")
@@ -30,7 +31,7 @@ class PrefetcherIterDataPipe(IterDataPipe):
     """
     Prefetches elements from the source DataPipe and puts them into a buffer (functional name: ``prefetch``).
     Prefetching performs the operations (e.g. I/O, computations) of the DataPipes up to this one ahead of time
-    and stores the result in the buffer, ready to be consume by the subsequent DataPipe. It has no effect aside
+    and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside
     from getting the sample ready ahead of time.

     This is used by ``PrototypeMultiProcessingReadingService`` when the arguments
@@ -57,54 +58,56 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
         self.thread: Optional[threading.Thread] = None

     @staticmethod
-    def thread_worker(prefetch_data):
+    def thread_worker(prefetch_data: _PrefetchData):
         # Lazily import to prevent circular import
         from torchdata.dataloader2 import communication

         itr = iter(prefetch_data.source_datapipe)
-        stop_iteration = False
-        while prefetch_data.run_prefetcher:
-            if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not stop_iteration:
-                try:
-                    item = next(itr)
-                    prefetch_data.prefetch_buffer.append(item)
-                except StopIteration:
-                    stop_iteration = True
-                except communication.iter.InvalidStateResetRequired:
-                    stop_iteration = True
-                except communication.iter.TerminateRequired:
+        while not prefetch_data.stop_iteration:
+            while prefetch_data.run_prefetcher:
+                if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not prefetch_data.stop_iteration:
+                    try:
+                        item = next(itr)
+                        prefetch_data.prefetch_buffer.append(item)
+                    except StopIteration:
+                        prefetch_data.stop_iteration = True
+                    except communication.iter.InvalidStateResetRequired:
+                        prefetch_data.stop_iteration = True
+                    except communication.iter.TerminateRequired:
+                        prefetch_data.run_prefetcher = False
+                        prefetch_data.stop_iteration = True
+                elif prefetch_data.stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
                     prefetch_data.run_prefetcher = False
-            elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
-                prefetch_data.run_prefetcher = False
-            else:  # Buffer is full, waiting for main thread to consume items
-                # TODO: Calculate sleep interval based on previous consumption speed
-                time.sleep(PRODUCER_SLEEP_INTERVAL)
+                else:  # Buffer is full, waiting for main thread to consume items
+                    # TODO: Calculate sleep interval based on previous consumption speed
+                    time.sleep(PRODUCER_SLEEP_INTERVAL)
+            time.sleep(PRODUCER_SLEEP_INTERVAL)

     def __iter__(self):
-        if self.buffer_size < 1:
-            yield from self.source_datapipe
-        else:
-            try:
-                prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
-                self.prefetch_data = prefetch_data
-                thread = threading.Thread(
-                    target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True
-                )
-                self.thread = thread
-                self.thread.start()
-                while prefetch_data.run_prefetcher:
-                    if len(prefetch_data.prefetch_buffer) > 0:
-                        yield prefetch_data.prefetch_buffer.popleft()
-                    else:
-                        # TODO: Calculate sleep interval based on previous availability speed
+        try:
+            prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
+            self.prefetch_data = prefetch_data
+            thread = threading.Thread(target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True)
+            self.thread = thread
+            self.thread.start()
+
+            while prefetch_data.run_prefetcher:
+                if len(prefetch_data.prefetch_buffer) > 0:
+                    yield prefetch_data.prefetch_buffer.popleft()
+                else:
+                    # TODO: Calculate sleep interval based on previous availability speed
+                    if not prefetch_data.stop_iteration:
                         time.sleep(CONSUMER_SLEEP_INTERVAL)
-            finally:
-                prefetch_data.run_prefetcher = False
-                thread.join()
+                    else:
+                        prefetch_data.run_prefetcher = False
+        finally:
+            prefetch_data.run_prefetcher = False
+            prefetch_data.stop_iteration = True
+            thread.join()

     def __getstate__(self):
         """
-        Getting state in threading enviroment requires next operations:
+        Getting state in threading environment requires next operations:
             1) Stopping of the producer thread.
             2) Saving buffer.
             3) Adding lazy restart of producer thread when __next__ is called again
@@ -122,5 +125,16 @@ def __setstate__(self, state):
     def reset(self):
         if self.thread is not None:
             self.prefetch_data.run_prefetcher = False
+            self.prefetch_data.stop_iteration = True
             self.thread.join()
             self.thread = None
+
+    def pause(self):
+        if self.thread is not None:
+            self.prefetch_data.run_prefetcher = False
+
+    def resume(self):
+        if self.thread is not None and (
+            not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
+        ):
+            self.prefetch_data.run_prefetcher = True
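
Example usage of the pause/resume/limit controls introduced by this patch, mirroring the new tests in
test/dataloader2/test_proto_multi_rs.py. This is an illustrative sketch rather than part of the patch; the
stop-after-limit and resume behavior shown here is assumed from those tests.

    from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
    from torchdata.datapipes.iter import IterableWrapper

    # Shuffled, sharded pipeline of 10 elements, read by 2 worker processes
    dp = IterableWrapper(range(10)).shuffle().sharding_filter()
    rs = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
    dl = DataLoader2(dp, reading_service=rs)

    it = iter(dl)
    it.limit(3)             # pause automatically after 3 batches have been yielded
    first_three = list(it)  # iteration stops once the limit is reached

    it.resume()             # the limit persists across resume
    next_three = list(it)   # yields 3 more batches

    it.limit(None)          # clear the limit
    it.resume()
    rest = list(it)         # yields the remaining elements

    dl.shutdown()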