Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Forward worker exceptions and exit with it #1003

Closed
wants to merge 13 commits into from
29 changes: 27 additions & 2 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@

mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods())

EXCEPTION_ITERATION_NUM = 7


class _ReadingServiceWrapper:
def __init__(self, dp):
Expand All @@ -79,6 +81,18 @@ def return_one():
return 1


class MakeMistakeDataPipe(IterDataPipe):
    """DataPipe that deliberately fails partway through iteration.

    Passes items from ``source_datapipe`` through unchanged until the item
    at position ``exc_iteration`` is reached, at which point it raises
    ``Exception("oops")`` instead of yielding. Used by the tests to verify
    that exceptions raised inside worker processes are propagated back to
    the main process.
    """

    def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
        self.source_datapipe = source_datapipe
        self.exc_iteration = exc_iteration

    def __iter__(self):
        count = 0
        for item in self.source_datapipe:
            # Fail *before* yielding the item at index ``exc_iteration``.
            if count == self.exc_iteration:
                raise Exception("oops")
            yield item
            count += 1


class TestReadingService(ReadingServiceInterface):
def initialize(self, dp: DataPipe) -> DataPipe:
return _ReadingServiceWrapper(dp) # type: ignore[return-value]
Expand All @@ -99,6 +113,19 @@ def test_dataloader2_shutdown(self) -> None:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
data_loader.shutdown()

def test_worker_exception_raised(self):
    """An exception raised in a worker surfaces as WorkerException in the main process."""
    dp = IterableWrapper(range(100)).sharding_filter()
    # MakeMistakeDataPipe raises at iteration EXCEPTION_ITERATION_NUM in each worker.
    dp = MakeMistakeDataPipe(dp)
    for worker_prefetch_cnt in [0, 5, 10]:
        for num_workers in [1, 4]:
            rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
            # NOTE(review): What happens if there is a DataPipe after Prefetcher?
            # Does it still work properly? e.g. dp.map().prefetch().map()
            dl = DataLoader2(dp, reading_service=rs)
            it = iter(dl)
            # Each worker yields EXCEPTION_ITERATION_NUM good items before its
            # failure point, so this many next() calls must succeed first.
            for i in range(EXCEPTION_ITERATION_NUM * num_workers):
                next(it)
            # The very next item hits the failing iteration in some worker.
            with self.assertRaises(communication.iter.WorkerException):
                next(it)

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
Expand Down Expand Up @@ -171,7 +198,6 @@ def test_dataloader2_iterates_correctly(self) -> None:
self.assertEqual(list(range(10)), actual)

def test_dataloader2_reset(self) -> None:

test_data_pipe = IterableWrapper(range(10))
reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)]

Expand Down Expand Up @@ -264,7 +290,6 @@ def test_dataloader2_shuffle(self) -> None:
"fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestDataLoader2EventLoop(TestCase):

# TODO: This needs fixing, see issue 624
# @skipIfNoDill
# def test_basic_threading(self):
Expand Down
11 changes: 11 additions & 0 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class TerminateRequired(Exception):
pass


class WorkerException(Exception):
    """Raised in the main process when a worker process fails with an exception.

    The worker-side exception is forwarded via a ``WorkerExceptionResponse``
    and re-raised here with ``raise ... from``, so the original failure is
    available as ``__cause__``.
    """


class NonBlocking(IterDataPipe):
not_available_hook = default_not_available_hook

Expand Down Expand Up @@ -212,6 +218,9 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
protocol.response_invalid_state()
yield True
break
except Exception as e:
protocol.response_worker_exception(e)
return
protocol.response_next(value)
yield True # Returns control
break
Expand Down Expand Up @@ -322,6 +331,8 @@ def __iter__(self):
raise communication.iter.InvalidStateResetRequired
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
if isinstance(response, communication.messages.WorkerExceptionResponse):
raise communication.iter.WorkerException(f"Exception from worker {idx}") from response.exception
if len(self.res_buffers[idx]) == 0: # Only request if buffer is empty
self.datapipes[idx].protocol.request_next()
yield response.value
Expand Down
5 changes: 5 additions & 0 deletions torchdata/dataloader2/communication/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,8 @@ class InvalidStateResponse(Response):
"""

pass


class WorkerExceptionResponse(Response):
    """Response message carrying an exception from a worker process back to the client."""

    def __init__(self, exception):
        # The exception instance raised inside the worker; the receiving side
        # chains it onto the WorkerException it raises.
        self.exception = exception
6 changes: 6 additions & 0 deletions torchdata/dataloader2/communication/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def response_invalid_state(self):
self.response_queue.put(communication.messages.InvalidStateResponse())
self._req_received = None

def response_worker_exception(self, exception):
if not self.have_pending_request():
raise Exception("Attempting to reply with pending request")
self.response_queue.put(communication.messages.WorkerExceptionResponse(exception))
self._req_received = None


class IterDataPipeQueueProtocolClient(ProtocolClient):
def request_reset_iterator(self):
Expand Down
9 changes: 8 additions & 1 deletion torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def thread_worker(prefetch_data: _PrefetchData):
except communication.iter.TerminateRequired:
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
except Exception as e:
prefetch_data.prefetch_buffer.append(e)
break
NivekT marked this conversation as resolved.
Show resolved Hide resolved
elif prefetch_data.stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
prefetch_data.run_prefetcher = False
else: # Buffer is full, waiting for main thread to consume items
Expand All @@ -93,7 +96,11 @@ def __iter__(self):

while prefetch_data.run_prefetcher:
if len(prefetch_data.prefetch_buffer) > 0:
yield prefetch_data.prefetch_buffer.popleft()
item = prefetch_data.prefetch_buffer.popleft()
if isinstance(item, Exception):
prefetch_data.run_prefetcher = False
raise item
yield item
else:
# TODO: Calculate sleep interval based on previous availability speed
if not prefetch_data.stop_iteration:
Expand Down