[PrototypeRS] Adding 'pause' and 'resume' operations to halt DataPipes #879

Closed · wants to merge 45 commits

Commits (45, all by NivekT):
d22790b  Nov 3, 2022   [PrototypeRS] Adding 'full stop' to halt DataPipes
9efb6b3  Nov 7, 2022   Update on "[PrototypeRS] Adding 'full stop' to halt DataPipes"
50bea6c  Nov 10, 2022  Update on "[PrototypeRS] Adding 'full stop' to halt DataPipes"
a853244  Nov 11, 2022  Update on "[PrototypeRS] Adding 'full stop' to halt DataPipes"
c88fce0  Nov 14, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
856e8e7  Nov 16, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
261fb9d  Nov 16, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
a78d5d6  Nov 16, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
26c3055  Nov 17, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
e44c4a9  Dec 3, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
b7eab9d  Dec 3, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
2aae34f  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
9faa78d  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
bd7dc6e  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
250dec3  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
ba66c1c  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
d5a6fa9  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
3471a34  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
865ee90  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
ab24eb9  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
c0c7dec  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
c483e96  Dec 5, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
031af83  Dec 7, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
ac017f2  Dec 7, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
eb6f3fb  Dec 7, 2022   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
80c65e0  Dec 14, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
ea7952f  Dec 15, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
b960db3  Dec 17, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
08db7ea  Dec 18, 2022  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
cb6f017  Jan 3, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
b70b3a7  Jan 13, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
c57c4c4  Jan 18, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
314cc4c  Jan 18, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
d209b6a  Jan 18, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
84d5346  Jan 24, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
29a4fd0  Feb 2, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
a1ecb98  Feb 6, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
dc72f20  Feb 7, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
76219c7  Feb 8, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
e5b2544  Feb 8, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
caeb103  Feb 8, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
491cb41  Feb 8, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
00cbd0e  Feb 9, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
7506f01  Feb 9, 2023   Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
01bb5c7  Feb 10, 2023  Update on "[PrototypeRS] Adding 'pause' and 'resume' operations to ha…
19 changes: 19 additions & 0 deletions torchdata/dataloader2/communication/iter.py
@@ -9,6 +9,7 @@

from torch.utils.data import IterDataPipe
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import traverse_dps

DEFAULT_NON_BLOCKING_SLEEP = 0.001

@@ -137,12 +138,30 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
            source_datapipe.reset_iterator()
            protocol.response_reset_iterator()

        elif isinstance(request, communication.messages.FullStopRequest):
            graph = traverse_dps(source_datapipe)
            for dp, _ in graph.values():
                if hasattr(dp, "full_stop") and callable(dp.full_stop):
                    dp.full_stop()
            protocol.response_full_stop()
NivekT (Contributor, Author) commented:
Is there any potential issue with traversing through the graph and calling .full_stop() on all of them?


        elif isinstance(request, communication.messages.ResumeRequest):
            graph = traverse_dps(source_datapipe)
            for dp, _ in graph.values():
                if hasattr(dp, "resume") and callable(dp.resume):
                    dp.resume()
            protocol.response_resume()

        elif isinstance(request, communication.messages.TerminateRequest):
            forever = False
            protocol.response_terminate()

        elif isinstance(request, communication.messages.GetNextRequest):
            while forever:
                if protocol._full_stop:
                    raise RuntimeError(
                        "Cannot `GetNext` after `FullStop` has been called. "
                        "`Resume` must be called first."
                    )
                try:
                    value = source_datapipe.nonblocking_next()
                except NotAvailable:
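For reference, the handler above dispatches by duck typing: it walks every DataPipe that `traverse_dps` returns and invokes `full_stop`/`resume` only where the method exists. A minimal, self-contained sketch of that pattern (the `_Prefetcher`, `_Mapper`, and `_traverse` names below are illustrative stand-ins, not torchdata classes):

class _Prefetcher:
    def __init__(self, source=None):
        self.source = source
        self.running = True

    def full_stop(self):
        # A real prefetcher would also stop and join its worker thread.
        self.running = False

    def resume(self):
        self.running = True


class _Mapper:
    def __init__(self, source=None):
        self.source = source  # defines neither full_stop nor resume


def _traverse(pipe):
    # Stand-in for traverse_dps: yield every pipe reachable from `pipe`.
    while pipe is not None:
        yield pipe
        pipe = getattr(pipe, "source", None)


def _broadcast(pipe, op):
    # Same guard as the request handler: call `op` only where it is defined.
    for dp in _traverse(pipe):
        fn = getattr(dp, op, None)
        if callable(fn):
            fn()


chain = _Prefetcher(_Mapper(_Prefetcher()))
_broadcast(chain, "full_stop")   # both _Prefetchers stop; _Mapper is skipped
_broadcast(chain, "resume")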
16 changes: 16 additions & 0 deletions torchdata/dataloader2/communication/messages.py
@@ -36,6 +36,22 @@ class ResetEpochResponse(Response):
    pass


class FullStopRequest(Request):
    pass


class FullStopResponse(Response):
    pass


class ResumeRequest(Request):
    pass


class ResumeResponse(Response):
    pass


class TerminateRequest(Request):
    pass
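These message classes carry no payload; they exist so the two sides of a queue pair can dispatch on type, as the `isinstance` chain in `DataPipeBehindQueues` does. A self-contained sketch of one round trip, using a plain `queue.Queue` in a single process rather than the multiprocessing queues the real protocol runs on:

import queue

class Request: ...
class Response: ...
class FullStopRequest(Request): ...
class FullStopResponse(Response): ...
class ResumeRequest(Request): ...
class ResumeResponse(Response): ...

req_q: queue.Queue = queue.Queue()
res_q: queue.Queue = queue.Queue()

def serve_one():
    # Branch on the request's type and reply with the matching response.
    request = req_q.get()
    if isinstance(request, FullStopRequest):
        res_q.put(FullStopResponse())
    elif isinstance(request, ResumeRequest):
        res_q.put(ResumeResponse())

req_q.put(FullStopRequest())
serve_one()
assert isinstance(res_q.get(), FullStopResponse)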
32 changes: 31 additions & 1 deletion torchdata/dataloader2/communication/protocol.py
@@ -42,26 +42,42 @@ def request_sent(self, request=True):

    def request_served(self, result=None):
        if not self.waiting_for_response():
-            raise Exception("Expected no peding requests, but something got served", result)
+            raise Exception("Expected no pending requests, but something got served", result)
        self._req_sent = None

    def discard_existing_request(self):
        if self.waiting_for_response():
            response = self.response_queue.get(block=True)
            self.request_served(response)

    def request_full_stop(self):
        if not self.can_take_request():
            raise Exception("Can not full stop while we are still waiting for a response to the previous request")
        request = communication.messages.FullStopRequest()
        self.request_queue.put(request)
        self.request_sent(request)

    def request_resume(self):
        if not self.can_take_request():
            raise Exception("Can not resume while we are still waiting for a response to the previous request")
        request = communication.messages.ResumeRequest()
        self.request_queue.put(request)
        self.request_sent(request)


class ProtocolServer(Protocol):
    """
    ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
    """

    _req_received = None
    _full_stop = False  # When `True`, prevents `GetNext` in `DataPipeBehindQueues`.

    def __init__(self, request_queue, response_queue):
        self.request_queue = request_queue
        self.response_queue = response_queue
        self._req_received = None
        self._full_stop = False

    def have_pending_request(self):
        return self._req_received is not None
@@ -93,6 +109,20 @@ def response_reset_epoch(self):
        self.response_queue.put(communication.messages.ResetEpochResponse())
        self._req_received = None

    def response_full_stop(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply without a pending request")
        self._full_stop = True
        self.response_queue.put(communication.messages.FullStopResponse())
        self._req_received = None

    def response_resume(self):
        if not self.have_pending_request():
            raise Exception("Attempting to reply without a pending request")
        self._full_stop = False
        self.response_queue.put(communication.messages.ResumeResponse())
        self._req_received = None


class MapDataPipeQueueProtocolServer(ProtocolServer):
    def response_item(self, key, value):
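Two invariants are worth calling out here: the client allows at most one outstanding request at a time (`can_take_request`), and the server's `_full_stop` flag is what later blocks `GetNext`. A minimal sketch of the client-side invariant, with `_Client` as an illustrative stand-in for the real protocol client:

class _Client:
    def __init__(self):
        self._req_sent = None

    def can_take_request(self):
        return self._req_sent is None

    def request(self, name):
        if not self.can_take_request():
            raise Exception("Cannot send a new request while awaiting a response")
        self._req_sent = name  # the real client also puts a message on request_queue

    def response_received(self):
        self._req_sent = None  # the real client drains response_queue first


c = _Client()
c.request("full_stop")
try:
    c.request("resume")   # rejected: full_stop has not been acknowledged yet
except Exception:
    pass
c.response_received()     # FullStopResponse arrives
c.request("resume")       # now allowed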
42 changes: 39 additions & 3 deletions torchdata/dataloader2/reading_service.py
@@ -120,11 +120,12 @@ class _IterateQueueDataPipes(IterDataPipe):

    def __init__(self, datapipes):
        # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
        # into one class, which supports any number of queues.
-        self.datapipes = datapipes
-        for dp in self.datapipes:
+        for dp in datapipes:
            if not isinstance(dp, communication.iter.QueueWrapper):
                raise Exception("Source datapipes should be an instance of iter.QueueWrapper")
+        self.datapipes = datapipes
+        self.res_buffers = [[] for _ in range(len(datapipes))]

    def __iter__(self):
        total_pipes = len(self.datapipes)
@@ -137,6 +138,9 @@ def __iter__(self):
        while cnt_disabled_pipes < total_pipes:
            for idx in range(total_pipes):
                if not disabled_pipe[idx]:
                    # Check if buffer of the DataPipe is empty before requesting next
                    while len(self.res_buffers[idx]):
                        yield self.res_buffers[idx].pop(0)  # pop(0) keeps FIFO order
                    response = self.datapipes[idx].protocol.get_response_next(block=True)
                    if isinstance(response, communication.messages.StopIterationResponse):
                        disabled_pipe[idx] = True
@@ -164,6 +168,18 @@ def reset_epoch(self, *args):
        for dp in self.datapipes:
            dp.protocol.request_reset_epoch(*args)

    def request_full_stop(self):
        # Store results of pending requests
        for idx, dp in enumerate(self.datapipes):
            res = dp.protocol.get_response_next(block=True)
            self.res_buffers[idx].append(res)
        for dp in self.datapipes:
            dp.protocol.request_full_stop()

    def request_resume(self):
        for dp in self.datapipes:
            dp.protocol.request_resume()


class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
    r"""
@@ -347,6 +363,26 @@ def clean_me(process, req_queue, res_queue):
            dist.destroy_process_group(self._pg)
            self._pg = None

    def full_stop(self):
        """
        Fully stop the DataPipes' activities, such as prefetching, in order to collect state.
        """
        if self.prefetch_mainloop > 0:
            # Stop prefetching first
            self.end_datapipe.full_stop()  # type: ignore[union-attr]
            end_datapipe: DataPipe = self.end_datapipe.source_datapipe
        else:
            end_datapipe = self.end_datapipe
        end_datapipe.request_full_stop()

    def resume(self):
        if self.prefetch_mainloop > 0:
            self.end_datapipe.resume()  # type: ignore[union-attr]
            end_datapipe: DataPipe = self.end_datapipe.source_datapipe
        else:
            end_datapipe = self.end_datapipe
        end_datapipe.request_resume()


class MultiProcessingReadingService(ReadingServiceInterface):
    r"""
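A hedged usage sketch of the two new ReadingService methods, e.g. to collect state mid-epoch. The pipeline and constructor arguments below are assumptions based on the surrounding code and the torchdata API of this era, not an exact reference:

from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(1000)).shuffle().sharding_filter()
rs = PrototypeMultiProcessingReadingService(num_workers=2, prefetch_mainloop=10)
dl = DataLoader2(dp, reading_service=rs)

it = iter(dl)
first_half = [next(it) for _ in range(100)]

rs.full_stop()   # halt main-loop prefetching first, then ask each worker to stop
# ... nothing is being fetched now; safe to collect or snapshot state ...
rs.resume()      # restart the main-loop prefetcher, then ask the workers to resume

second_half = [next(it) for _ in range(100)]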
8 changes: 8 additions & 0 deletions torchdata/datapipes/iter/util/prefetch.py
@@ -223,3 +223,11 @@ def __setstate__(self, state):
        self._error = None
        self._sync_counter = torch.tensor([0], dtype=torch.int32)
        self._done_callback = False

    def full_stop(self):
        if self._executor is not None:
            self._executor.shutdown()
            self._executor = None

    def resume(self):
        self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout)
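The fullsync change follows a shutdown-and-recreate pattern: pausing tears the executor down, and resuming builds a fresh one (note that `resume` above rebuilds it from a fresh `iter(self.datapipe)`). A generic sketch of the pattern, with `ThreadPoolExecutor` standing in for `_PrefetchExecutor`:

from concurrent.futures import ThreadPoolExecutor

class _PausableWorker:
    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)

    def full_stop(self):
        if self._executor is not None:
            self._executor.shutdown()  # waits for in-flight work, then frees the thread
            self._executor = None

    def resume(self):
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=1)


w = _PausableWorker()
w._executor.submit(print, "prefetching")
w.full_stop()   # no executor, so no background work can be scheduled
w.resume()
w._executor.submit(print, "prefetching again")
w.full_stop()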
21 changes: 18 additions & 3 deletions torchdata/datapipes/iter/util/prefetcher.py
@@ -32,12 +32,12 @@ class PrefetcherIterDataPipe(IterDataPipe):
"""
Prefetches elements from the source DataPipe and puts them into a buffer (functional name: ``prefetch``).
Prefetching performs the operations (e.g. I/O, computations) of the DataPipes up to this one ahead of time
and stores the result in the buffer, ready to be consume by the subsequent DataPipe. It has no effect aside
and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside
from getting the sample ready ahead of time.

This is used by ``PrototypeMultiProcessingReadingService`` when the arguments
``prefetch_worker`` (for prefetching at each worker process) or
``prefetch_mainloop`` (for prefetching at the moain loop) are greater than 0.
``prefetch_mainloop`` (for prefetching at the main loop) are greater than 0.

Beyond the built-in use cases, this can be useful to put after I/O DataPipes that have
expensive I/O operations (e.g. takes a long time to request a file from a remote server).
Expand Down Expand Up @@ -104,7 +104,7 @@ def __iter__(self):

    def __getstate__(self):
        """
-        Getting state in threading enviroment requires next operations:
+        Getting state in a threading environment requires the following operations:
        1) Stopping of the producer thread.
        2) Saving buffer.
        3) Adding lazy restart of producer thread when __next__ is called again
@@ -123,3 +123,18 @@ def reset(self):
        if self.thread is not None:
            self.prefetch_data.run_prefetcher = False
            self.thread.join()
            self.thread = None

    def full_stop(self):
        if self.thread is not None:
            # Note: the content of the buffer still exists in `prefetch_data.prefetch_buffer`
            self.prefetch_data.run_prefetcher = False
            self.thread.join()
            self.thread = None

    def resume(self):
        self.thread = threading.Thread(
            target=PrefetcherIterDataPipe.thread_worker, args=(self.prefetch_data,), daemon=True
        )
        self.prefetch_data.run_prefetcher = True
        self.thread.start()
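The prefetcher pauses by clearing a flag that its worker loop checks and then joining the thread; the buffered items live in `prefetch_data`, so resuming simply starts a new thread over the same state. A standalone sketch of that idiom (names mirror the diff, but the worker body is illustrative):

import threading
import time
from collections import deque

class _PrefetchData:
    # Shared state that outlives the worker thread; the buffer is what
    # survives a full_stop/resume cycle.
    def __init__(self):
        self.run_prefetcher = True
        self.prefetch_buffer = deque()

def _thread_worker(data):
    i = 0
    while data.run_prefetcher:          # exits promptly once the flag is cleared
        data.prefetch_buffer.append(i)  # stand-in for pulling from a source pipe
        i += 1
        time.sleep(0.01)

data = _PrefetchData()
t = threading.Thread(target=_thread_worker, args=(data,), daemon=True)
t.start()
time.sleep(0.05)

data.run_prefetcher = False   # full_stop: clear the flag, then join
t.join()
buffered = len(data.prefetch_buffer)   # buffered items survive the pause

data.run_prefetcher = True    # resume: same shared state, brand-new thread
t = threading.Thread(target=_thread_worker, args=(data,), daemon=True)
t.start()
time.sleep(0.02)
data.run_prefetcher = False
t.join()
assert len(data.prefetch_buffer) >= buffered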