From f794b4981af0a860b8d03943326909eb1bbe28f6 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Thu, 9 Feb 2023 12:45:08 -0800 Subject: [PATCH 01/10] raise worker exceptions and exit well --- test/dataloader2/test_dataloader2.py | 31 +++++++++++++++++++ torchdata/dataloader2/communication/iter.py | 20 ++++++++++-- .../dataloader2/communication/messages.py | 5 +++ .../dataloader2/communication/protocol.py | 6 ++++ torchdata/dataloader2/reading_service.py | 4 +++ torchdata/datapipes/iter/util/prefetcher.py | 8 ++++- 6 files changed, 71 insertions(+), 3 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 8d7d84448..d600f25e3 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -46,6 +46,8 @@ mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods()) +DATAPIPE_ITERATION = 7 + class _ReadingServiceWrapper: def __init__(self, dp): @@ -62,6 +64,18 @@ def __next__(self): def return_one(): return 1 +class MakeMistakeDataPipe(IterDataPipe): + def __init__(self, source_datapipe, iteration=DATAPIPE_ITERATION): + self.source_datapipe = source_datapipe + # TODO generate random num within 100 to raise expection at + self.iteration = iteration + + def __iter__(self): + for i, x in enumerate(self.source_datapipe): + if i == self.iteration: + raise Exception("oops") + yield x + class TestReadingService(ReadingServiceInterface): def initialize(self, dp: DataPipe) -> DataPipe: @@ -83,6 +97,23 @@ def test_dataloader2_shutdown(self) -> None: data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) data_loader.shutdown() + def test_passing_errors(self): + dp = IterableWrapper(range(100)).sharding_filter() + dp = MakeMistakeDataPipe(dp) + for worker_prefetch_cnt in [0,5,10]: + print("== worker_prefetch_cnt ", worker_prefetch_cnt) + # TODO test with multiple workers + rs = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=worker_prefetch_cnt) + dl = DataLoader2(dp, reading_service=rs) + it = iter(dl) + for i in range(DATAPIPE_ITERATION): + item = next(it) + print("item ", item) + with self.assertRaises(communication.iter.WorkerException): + item = next(it) + # print("=== expect exception") + # next(it) + def test_dataloader2_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 8c02a93d1..9fb1d4272 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -57,6 +57,17 @@ class TerminateRequired(Exception): pass +class WorkerException(Exception): + def __init__(self, exception): + self.exception = exception + + def __str__(self): + return self.exception + + def __traceback__(self): + return self.exception.__traceback__ + + class NonBlocking(IterDataPipe): not_available_hook = default_not_available_hook @@ -112,7 +123,6 @@ def reset_iterator(self): ) return validated_datapipe - def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False): """ Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``. @@ -167,6 +177,10 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False): protocol.response_invalid_state() yield True break + except Exception as e: + # print("==== setting resp") + protocol.response_worker_exception(str(e)) + return protocol.response_next(value) yield True # Returns control break @@ -216,7 +230,6 @@ def nonblocking_next(self): raise NotAvailable return response.value - class _IterateQueueDataPipes(IterDataPipe): r""" Takes in ``QueueWrapper``s and iterates through them in a round-robin manner to get batches one-by-one. @@ -252,6 +265,9 @@ def __iter__(self): raise communication.iter.InvalidStateResetRequired if isinstance(response, communication.messages.TerminateResponse): raise communication.iter.TerminateRequired + if isinstance(response, communication.messages.WorkerExceptionResponse): + # print("====raising", idx, total_pipes) + raise communication.iter.WorkerException(f"Exception from worker {idx}: {response.exception}") self.datapipes[idx].protocol.request_next() yield response.value diff --git a/torchdata/dataloader2/communication/messages.py b/torchdata/dataloader2/communication/messages.py index 1e5c388af..6d1eee25e 100644 --- a/torchdata/dataloader2/communication/messages.py +++ b/torchdata/dataloader2/communication/messages.py @@ -92,3 +92,8 @@ class InvalidStateResponse(Response): """ pass + + +class WorkerExceptionResponse(Response): + def __init__(self, exception): + self.exception = exception diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py index 245ff0f19..26634c1a0 100644 --- a/torchdata/dataloader2/communication/protocol.py +++ b/torchdata/dataloader2/communication/protocol.py @@ -192,6 +192,12 @@ def response_invalid_state(self): self.response_queue.put(communication.messages.InvalidStateResponse()) self._req_received = None + def response_worker_exception(self, exception): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.WorkerExceptionResponse(exception)) + self._req_received = None + class IterDataPipeQueueProtocolClient(ProtocolClient): def request_reset_iterator(self): diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index 8f7bbd297..2ab46e58c 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -264,8 +264,11 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: for worker_id in range(self.num_workers): worker_info = WorkerInfo(self.num_workers, worker_id) # Dispatching process for non-replicable DataPipes exists + print("=====in initialize PMPRS disp req queue for ", worker_id) dispatching_req_queue = self._dispatch_process[1][worker_id] if self._dispatch_process is not None else None + print("=====in initialize PMPRS disp req queue for ", worker_id) dispatching_res_queue = self._dispatch_process[2][worker_id] if self._dispatch_process is not None else None + print("=====in initialize PMPRS disp req queue for ", worker_id) call_on_process_init = partial( process_init_fn, worker_info=worker_info, @@ -280,6 +283,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: ) process.daemon = True process.start() + print("=====in initialize PMPRS proc started ", process) self._worker_processes.append((process, req_queue, res_queue)) # These queues are independent local_datapipe = communication.iter.QueueWrapper( communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue) diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 8ec13fbc3..8505f51ca 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -74,6 +74,8 @@ def thread_worker(prefetch_data): stop_iteration = True except communication.iter.TerminateRequired: prefetch_data.run_prefetcher = False + except Exception as e: + prefetch_data.prefetch_buffer.append(e) elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0: prefetch_data.run_prefetcher = False else: # Buffer is full, waiting for main thread to consume items @@ -93,7 +95,11 @@ def __iter__(self): self.thread.start() while prefetch_data.run_prefetcher: if len(prefetch_data.prefetch_buffer) > 0: - yield prefetch_data.prefetch_buffer.popleft() + item = prefetch_data.prefetch_buffer.popleft() + if isinstance(item, Exception): + prefetch_data.run_prefetcher = False + raise item + yield item else: # TODO: Calculate sleep interval based on previous availability speed time.sleep(CONSUMER_SLEEP_INTERVAL) From 5e31929ddc2022f3724db1e148a90f971d473eb5 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Thu, 9 Feb 2023 12:55:40 -0800 Subject: [PATCH 02/10] update const name --- test/dataloader2/test_dataloader2.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index d600f25e3..9c2fdea96 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -46,7 +46,7 @@ mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods()) -DATAPIPE_ITERATION = 7 +EXCEPTION_ITERATION_NUM = 7 class _ReadingServiceWrapper: @@ -65,14 +65,13 @@ def return_one(): return 1 class MakeMistakeDataPipe(IterDataPipe): - def __init__(self, source_datapipe, iteration=DATAPIPE_ITERATION): + def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM): self.source_datapipe = source_datapipe - # TODO generate random num within 100 to raise expection at - self.iteration = iteration + self.exc_iteration = exc_iteration def __iter__(self): for i, x in enumerate(self.source_datapipe): - if i == self.iteration: + if i == self.exc_iteration: raise Exception("oops") yield x @@ -101,18 +100,14 @@ def test_passing_errors(self): dp = IterableWrapper(range(100)).sharding_filter() dp = MakeMistakeDataPipe(dp) for worker_prefetch_cnt in [0,5,10]: - print("== worker_prefetch_cnt ", worker_prefetch_cnt) # TODO test with multiple workers rs = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=worker_prefetch_cnt) dl = DataLoader2(dp, reading_service=rs) it = iter(dl) - for i in range(DATAPIPE_ITERATION): + for i in range(EXCEPTION_ITERATION_NUM): item = next(it) - print("item ", item) with self.assertRaises(communication.iter.WorkerException): item = next(it) - # print("=== expect exception") - # next(it) def test_dataloader2_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) From 98bcaae46352a1b595092bf6884949c4af5324ba Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Thu, 9 Feb 2023 16:51:31 -0800 Subject: [PATCH 03/10] update tests for multiple workers --- test/dataloader2/test_dataloader2.py | 20 ++++++++++---------- torchdata/dataloader2/communication/iter.py | 2 -- torchdata/dataloader2/reading_service.py | 4 ---- torchdata/datapipes/iter/util/prefetcher.py | 1 + 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 9c2fdea96..84405b5ea 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -96,18 +96,18 @@ def test_dataloader2_shutdown(self) -> None: data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) data_loader.shutdown() - def test_passing_errors(self): + def test_worker_exception_raised(self): dp = IterableWrapper(range(100)).sharding_filter() dp = MakeMistakeDataPipe(dp) - for worker_prefetch_cnt in [0,5,10]: - # TODO test with multiple workers - rs = PrototypeMultiProcessingReadingService(num_workers=1, worker_prefetch_cnt=worker_prefetch_cnt) - dl = DataLoader2(dp, reading_service=rs) - it = iter(dl) - for i in range(EXCEPTION_ITERATION_NUM): - item = next(it) - with self.assertRaises(communication.iter.WorkerException): - item = next(it) + for worker_prefetch_cnt in [0, 5, 10]: + for num_workers in [1, 4]: + rs = PrototypeMultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) + dl = DataLoader2(dp, reading_service=rs) + it = iter(dl) + for i in range(EXCEPTION_ITERATION_NUM*num_workers): + item = next(it) + with self.assertRaises(communication.iter.WorkerException): + item = next(it) def test_dataloader2_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 9fb1d4272..7ba43c86a 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -178,7 +178,6 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False): yield True break except Exception as e: - # print("==== setting resp") protocol.response_worker_exception(str(e)) return protocol.response_next(value) @@ -266,7 +265,6 @@ def __iter__(self): if isinstance(response, communication.messages.TerminateResponse): raise communication.iter.TerminateRequired if isinstance(response, communication.messages.WorkerExceptionResponse): - # print("====raising", idx, total_pipes) raise communication.iter.WorkerException(f"Exception from worker {idx}: {response.exception}") self.datapipes[idx].protocol.request_next() yield response.value diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index 2ab46e58c..8f7bbd297 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -264,11 +264,8 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: for worker_id in range(self.num_workers): worker_info = WorkerInfo(self.num_workers, worker_id) # Dispatching process for non-replicable DataPipes exists - print("=====in initialize PMPRS disp req queue for ", worker_id) dispatching_req_queue = self._dispatch_process[1][worker_id] if self._dispatch_process is not None else None - print("=====in initialize PMPRS disp req queue for ", worker_id) dispatching_res_queue = self._dispatch_process[2][worker_id] if self._dispatch_process is not None else None - print("=====in initialize PMPRS disp req queue for ", worker_id) call_on_process_init = partial( process_init_fn, worker_info=worker_info, @@ -283,7 +280,6 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: ) process.daemon = True process.start() - print("=====in initialize PMPRS proc started ", process) self._worker_processes.append((process, req_queue, res_queue)) # These queues are independent local_datapipe = communication.iter.QueueWrapper( communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue) diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 8505f51ca..fb0a1bc1e 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -76,6 +76,7 @@ def thread_worker(prefetch_data): prefetch_data.run_prefetcher = False except Exception as e: prefetch_data.prefetch_buffer.append(e) + break elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0: prefetch_data.run_prefetcher = False else: # Buffer is full, waiting for main thread to consume items From 305bfe162493be3d9a9abfeed00beb2d1a037220 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Fri, 10 Feb 2023 14:45:55 -0800 Subject: [PATCH 04/10] update exception content --- torchdata/dataloader2/communication/iter.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index fc4d3a283..0e6debb8f 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -58,14 +58,11 @@ class TerminateRequired(Exception): class WorkerException(Exception): - def __init__(self, exception): - self.exception = exception - - def __str__(self): - return self.exception + """ + Returned by DataPipe when there is a failure/exception from a worker process + """ - def __traceback__(self): - return self.exception.__traceback__ + pass class NonBlocking(IterDataPipe): @@ -191,7 +188,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, yield True break except Exception as e: - protocol.response_worker_exception(str(e)) + protocol.response_worker_exception(e) return protocol.response_next(value) yield True # Returns control @@ -278,7 +275,7 @@ def __iter__(self): if isinstance(response, communication.messages.TerminateResponse): raise communication.iter.TerminateRequired if isinstance(response, communication.messages.WorkerExceptionResponse): - raise communication.iter.WorkerException(f"Exception from worker {idx}: {response.exception}") + raise communication.iter.WorkerException(f"Exception from worker {idx}") from response.exception self.datapipes[idx].protocol.request_next() yield response.value From 1dc99cdf4ee258f5b4f45fe94cd6be30c2e2511e Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Fri, 10 Feb 2023 14:49:10 -0800 Subject: [PATCH 05/10] fix linter error --- test/dataloader2/test_dataloader2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index e4ff66a4a..2f21c6728 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -122,9 +122,9 @@ def test_worker_exception_raised(self): dl = DataLoader2(dp, reading_service=rs) it = iter(dl) for i in range(EXCEPTION_ITERATION_NUM*num_workers): - item = next(it) + next(it) with self.assertRaises(communication.iter.WorkerException): - item = next(it) + next(it) def test_dataloader2_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) From a210f3df9b4234dca81685dd1a94fc6b7095bf32 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Fri, 10 Feb 2023 16:41:02 -0800 Subject: [PATCH 06/10] fix ufmt errors --- test/dataloader2/test_dataloader2.py | 15 ++++++++------- torchdata/dataloader2/communication/iter.py | 1 + 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 2f21c6728..c263db607 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -81,6 +81,7 @@ def __next__(self): def return_one(): return 1 + class MakeMistakeDataPipe(IterDataPipe): def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM): self.source_datapipe = source_datapipe @@ -118,10 +119,12 @@ def test_worker_exception_raised(self): dp = MakeMistakeDataPipe(dp) for worker_prefetch_cnt in [0, 5, 10]: for num_workers in [1, 4]: - rs = PrototypeMultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) + rs = PrototypeMultiProcessingReadingService( + num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt + ) dl = DataLoader2(dp, reading_service=rs) it = iter(dl) - for i in range(EXCEPTION_ITERATION_NUM*num_workers): + for i in range(EXCEPTION_ITERATION_NUM * num_workers): next(it) with self.assertRaises(communication.iter.WorkerException): next(it) @@ -208,7 +211,6 @@ def test_dataloader2_iterates_correctly(self) -> None: self.assertEqual(list(range(10)), actual) def test_dataloader2_reset(self) -> None: - test_data_pipe = IterableWrapper(range(10)) reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)] @@ -338,7 +340,6 @@ def test_lazy_load(self): "fork is not supported. Dying (set die_after_fork=0 to override)", ) class TestDataLoader2EventLoop(TestCase): - # TODO: This needs fixing, see issue 624 # @skipIfNoDill # def test_basic_threading(self): @@ -753,9 +754,9 @@ def _random_fn(data): r""" Used to validate the randomness of subprocess-local RNGs are set deterministically. """ - py_random_num = random.randint(0, 2 ** 32) - np_random_num = np.random.randint(0, 2 ** 32) - torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() + py_random_num = random.randint(0, 2**32) + np_random_num = np.random.randint(0, 2**32) + torch_random_num = torch.randint(0, 2**32, size=[]).item() return (data, py_random_num, np_random_num, torch_random_num) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 0e6debb8f..5cc30a32c 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -239,6 +239,7 @@ def nonblocking_next(self): raise NotAvailable return response.value + class _IterateQueueDataPipes(IterDataPipe): r""" Takes in ``QueueWrapper``s and iterates through them in a round-robin manner to get batches one-by-one. From 80be0ef5e31c180ea49a242c299ad322158e6522 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Tue, 14 Feb 2023 16:20:29 -0800 Subject: [PATCH 07/10] precommit lint fixes --- test/dataloader2/test_dataloader2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index c263db607..f248846d4 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -754,9 +754,9 @@ def _random_fn(data): r""" Used to validate the randomness of subprocess-local RNGs are set deterministically. """ - py_random_num = random.randint(0, 2**32) - np_random_num = np.random.randint(0, 2**32) - torch_random_num = torch.randint(0, 2**32, size=[]).item() + py_random_num = random.randint(0, 2 ** 32) + np_random_num = np.random.randint(0, 2 ** 32) + torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() return (data, py_random_num, np_random_num, torch_random_num) From ee8a5fd7825e1ecab24e79e863943b6e1d628b4f Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Tue, 14 Feb 2023 16:25:00 -0800 Subject: [PATCH 08/10] remove pass --- torchdata/dataloader2/communication/iter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 5cc30a32c..2b8a5ed9b 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -62,8 +62,6 @@ class WorkerException(Exception): Returned by DataPipe when there is a failure/exception from a worker process """ - pass - class NonBlocking(IterDataPipe): not_available_hook = default_not_available_hook From 004e26415da7e5336989530d5900bd9585bc2e08 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Wed, 15 Feb 2023 11:24:38 -0800 Subject: [PATCH 09/10] use MultiProcessingReadingService --- test/dataloader2/test_dataloader2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 50f5661ff..1e7e09a00 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -27,7 +27,6 @@ DataLoader2, DistributedReadingService, MultiProcessingReadingService, - PrototypeMultiProcessingReadingService, ReadingServiceInterface, SequentialReadingService, ) @@ -119,7 +118,7 @@ def test_worker_exception_raised(self): dp = MakeMistakeDataPipe(dp) for worker_prefetch_cnt in [0, 5, 10]: for num_workers in [1, 4]: - rs = PrototypeMultiProcessingReadingService( + rs = MultiProcessingReadingService( num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt ) dl = DataLoader2(dp, reading_service=rs) From cdda7e41512912b932dde50cb55e6159d57971b3 Mon Sep 17 00:00:00 2001 From: Priya Ramani Date: Wed, 15 Feb 2023 11:43:31 -0800 Subject: [PATCH 10/10] ufmt changes --- test/dataloader2/test_dataloader2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 1e7e09a00..ef6d86cf6 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -118,9 +118,7 @@ def test_worker_exception_raised(self): dp = MakeMistakeDataPipe(dp) for worker_prefetch_cnt in [0, 5, 10]: for num_workers in [1, 4]: - rs = MultiProcessingReadingService( - num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt - ) + rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) dl = DataLoader2(dp, reading_service=rs) it = iter(dl) for i in range(EXCEPTION_ITERATION_NUM * num_workers):